Synaptic metaplasticity in binarized neural networks

Axel Laborieux et al.

Nat Commun. 2021 May 5;12(1):2549. doi: 10.1038/s41467-021-22768-y.

Abstract

While deep neural networks have surpassed human performance in multiple situations, they are prone to catastrophic forgetting: upon training a new task, they rapidly forget previously learned ones. Neuroscience studies, based on idealized tasks, suggest that in the brain, synapses overcome this issue by adjusting their plasticity depending on their past history. However, such "metaplastic" behaviors do not transfer directly to mitigate catastrophic forgetting in deep neural networks. In this work, we interpret the hidden weights used by binarized neural networks, a low-precision version of deep neural networks, as metaplastic variables, and modify their training technique to alleviate forgetting. Building on this idea, we propose and demonstrate experimentally, in situations of multitask and stream learning, a training technique that reduces catastrophic forgetting without needing previously presented data or formal boundaries between datasets, and with performance approaching that of more mainstream techniques that rely on task boundaries. We support our approach with a theoretical analysis on a tractable task. This work bridges computational neuroscience and deep learning, and presents significant assets for future embedded and neuromorphic systems, especially when using novel nanodevices featuring physics analogous to metaplasticity.


Conflict of interest statement

The authors declare no competing interests.

Figures

Fig. 1
Fig. 1. Problem setting and illustration of our approach.
a Problem setting: two training sets (here MNIST and Fashion-MNIST) are presented sequentially to a fully connected neural network. When learning MNIST (epochs 0–50), the MNIST test accuracy reaches 97%, while the Fashion-MNIST accuracy stays around 10%. When learning Fashion-MNIST (epochs 50–100), the associated test accuracy reaches 85% while the MNIST test accuracy collapses to ~20% within 25 epochs: this phenomenon is known as “catastrophic forgetting.” b Illustration of our approach: in a binarized neural network, each synapse incorporates a hidden weight Wh used for learning and a binary weight Wb = sign(Wh) used for inference. Our method, inspired by neuroscience works in the literature, amounts to regarding hidden weights as metaplastic states that can encode memory across tasks and thereby alleviate forgetting. Compared with the conventional training technique for binarized neural networks, it consists of modulating some hidden weight updates by a function fmeta(Wh) whose shape is indicated in c. This modulation is applied to negative updates of positive hidden weights, and to positive updates of negative hidden weights. Because fmeta(|Wh|) is a decreasing function, this modulation makes hidden weight signs less likely to switch back as the hidden weights grow in absolute value.
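To make the modulated update described above concrete, the following is a minimal sketch, not the authors' implementation: it assumes a decreasing modulation function of the form 1 − tanh²(m·Wh) (the shape actually used is the one shown in panel c), and the names fmeta and metaplastic_update are hypothetical.

```python
import torch

def fmeta(w_hidden_abs, m):
    # Illustrative decreasing function of |Wh|; assumed form, see Fig. 1c
    # for the shape used in the paper.
    return 1.0 - torch.tanh(m * w_hidden_abs) ** 2

def metaplastic_update(w_hidden, grad, lr, m):
    """Sketch of the modulated hidden-weight update.

    The plain update would be w_hidden -= lr * grad. Updates that push a
    hidden weight back toward zero (negative updates of positive weights,
    positive updates of negative weights) are scaled by fmeta(|Wh|), so
    weights that have grown large in absolute value become harder to flip.
    """
    update = -lr * grad
    shrinking = (w_hidden * update) < 0          # update opposes the current sign
    modulation = torch.where(shrinking,
                             fmeta(w_hidden.abs(), m),
                             torch.ones_like(w_hidden))
    return w_hidden + modulation * update

# Inference always uses the binary weights: w_binary = torch.sign(w_hidden)
```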
Fig. 2
Fig. 2. Permuted MNIST learning task.
a–d Binarized neural network learning six tasks sequentially for several values of the metaplastic parameter m. a m = 0 corresponds to a conventional binarized neural network, b m = 0.5, c m = 1.0, d m = 1.35. Curves are averaged over five runs and shadows correspond to one standard deviation. e, f Final test accuracy on each task after the last task has been learned. The dots indicate the mean values over five runs, and the shaded zone one standard deviation. e corresponds to a binarized neural network and f corresponds to our method applied to a deep neural network with real-valued weights and the same architecture. g, h Hidden weight distributions of a binarized neural network with two hidden layers of 4096 units and m = 1.35, after learning one permuted MNIST (g) and two permuted MNISTs (h) for 40 epochs each.
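For readers unfamiliar with the benchmark, each permuted-MNIST task applies one fixed random pixel permutation to every image while leaving the labels unchanged. A minimal sketch (the function name and seeding are illustrative assumptions, not from the paper):

```python
import numpy as np

def make_permuted_tasks(x, n_tasks, seed=0):
    """Build permuted-MNIST tasks from flattened images x of shape (n, 784).

    Each task uses its own fixed pixel permutation; labels are unchanged.
    """
    rng = np.random.default_rng(seed)
    tasks = []
    for _ in range(n_tasks):
        perm = rng.permutation(x.shape[1])   # one fixed permutation per task
        tasks.append(x[:, perm])
    return tasks
```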
Fig. 3
Fig. 3. Influence of the network size on the number of tasks learned.
a, b Mean test accuracy over tasks learned so far for up to ten tasks. Each task is a permuted version of MNIST learned for 40 epochs. The binarized neural network architecture consists of two hidden layers with a variable number of hidden units ranging from 512 to 4096. a uses metaplasticity with parameter m = 1.35 and b uses elastic weight consolidation with λEWC = 5000. The decrease in mean test accuracy stems from the inability to learn new tasks once too many weights have been consolidated. Results for non-sequential (interleaved) training for c a non-metaplastic and d a metaplastic binarized neural network. In this situation, each point is an independent training experiment performed on the corresponding number of tasks. All curves are averaged over five runs and shadow areas denote one standard deviation.
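As a reference point for the comparison in panel b, elastic weight consolidation adds a quadratic penalty that pulls each parameter toward its value after the previous task, weighted by a diagonal Fisher information estimate. A minimal sketch of that standard penalty (names are hypothetical; this is not code from the paper):

```python
import torch

def ewc_penalty(params, star_params, fisher, ewc_lambda=5000.0):
    """Standard elastic weight consolidation penalty (sketch).

    params:      current parameters (list of tensors)
    star_params: parameters saved after the previous task
    fisher:      diagonal Fisher information estimates, one per parameter
    ewc_lambda:  penalty strength (lambda_EWC = 5000 in panel b)
    """
    penalty = 0.0
    for p, p_star, f in zip(params, star_params, fisher):
        penalty = penalty + (f * (p - p_star) ** 2).sum()
    return 0.5 * ewc_lambda * penalty
```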
Fig. 4
Fig. 4. Sequential learning on various datasets.
Binarized neural network learning MNIST and Fashion-MNIST sequentially a without metaplasticity and b with metaplasticity. c Sequential training on the MNIST and USPS datasets of handwritten digits. The baselines correspond to the accuracy reached by non-metaplastic networks with half the number of neurons, trained independently on each task. d presents the same experiment as c, with a metaplastic network featuring twice as many parameters as the baselines. e, f Test accuracy when learning sequentially two subsets of CIFAR-10 classes from features extracted by a ResNet pretrained on ImageNet (see Supplementary Note 11). g, h Same experiment with CIFAR-100 features. All curves except c and d are averaged over five runs. c and d are averaged over fifty runs due to the small amount of data (see Supplementary Note 10). Shadows correspond to one standard deviation.
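The CIFAR experiments in panels e–h train on frozen features rather than raw pixels. A minimal sketch of such a feature extractor, assuming a torchvision resnet18 backbone and 224×224 inputs purely for illustration (Supplementary Note 11 describes the setup actually used):

```python
import torch
import torchvision

# Pretrained ImageNet backbone; the weights argument follows recent
# torchvision versions and may differ in older ones.
resnet = torchvision.models.resnet18(weights="IMAGENET1K_V1").eval()

# Drop the 1000-way classifier head; keep everything up to the global
# average pool as a frozen feature extractor.
feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])

with torch.no_grad():
    images = torch.randn(8, 3, 224, 224)    # placeholder batch of upsampled CIFAR images
    feats = feature_extractor(images)       # shape (8, 512, 1, 1)
    feats = feats.flatten(start_dim=1)      # shape (8, 512), fed to the downstream network
```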
Fig. 5
Fig. 5. Stream learning experiments.
a Progressive learning of the Fashion-MNIST dataset. The dataset is split into 60 parts, each consisting of only 1000 examples and containing all ten classes. Each sub-dataset is learned for 20 epochs. The dashed lines represent the accuracies reached when the training is done on the full dataset for 20 epochs, so that all curves are obtained with the same number of optimization steps. b Progressive learning of the CIFAR-10 dataset. The dataset is split into 20 parts, each consisting of only 2500 examples. Each sub-dataset is learned for 200 epochs. The dashed lines represent the accuracies reached when the training is done on the full dataset for 200 epochs. Shadows correspond to one standard deviation around the mean over five runs.
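To illustrate the stream setting, one way of producing such sub-datasets is to shuffle the training set once and partition it. The sketch below assumes NumPy arrays and an illustrative function name:

```python
import numpy as np

def split_into_stream(x, y, n_parts, seed=0):
    """Shuffle a dataset once and split it into n_parts sub-datasets
    (e.g. 60 parts of 1000 Fashion-MNIST examples each)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    return [(x[part], y[part]) for part in np.array_split(idx, n_parts)]
```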
Fig. 6
Fig. 6. Interpretation of the meaning of hidden weights.
a Example of a hidden-weight trajectory in a two-dimensional quadratic binary task. One hidden weight Wxh diverges because the optimal hidden weight vector W* has uniform norm greater than one (Lemma 2 of Supplementary Note 5). b Mean increase in the loss incurred by switching the sign of a hidden weight, as a function of the normalized value of the hidden weight, for a 500-dimensional quadratic binary task. The mean is taken by assigning hidden weights to bins of increasing absolute value, and the error bars denote one standard deviation around the mean. The leftmost point corresponds to hidden weights that stay bounded. c–e Increase in the loss incurred by switching the sign of hidden weights as a function of the normalized absolute value of the hidden weight, in a binarized neural network trained on MNIST. Each dot is the mean increase over 100 realizations of weights to be switched, and the error bars denote one standard deviation. The scales differ because the layers have different numbers of weights and thus different relative importance. See “Methods” for implementation details.
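The quantity plotted in panels b–e can be estimated by flipping the binary sign derived from one hidden weight and re-evaluating the loss. A minimal sketch, with a hypothetical loss_from_binary_weights callable standing in for a forward pass over a batch (see “Methods” for the procedure actually used):

```python
import torch

def loss_increase_from_flip(loss_from_binary_weights, w_hidden, index):
    """Loss increase incurred by switching the sign associated with one
    hidden weight (loss_from_binary_weights and index are illustrative)."""
    w_binary = torch.sign(w_hidden)
    base_loss = loss_from_binary_weights(w_binary)
    flipped = w_binary.clone()
    flipped[index] = -flipped[index]
    return (loss_from_binary_weights(flipped) - base_loss).item()
```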
