Huh & Sejnowski (2018) - Gradient Descent for Spiking Neural Networks

You can find the paper here.

Why am I reading this paper?

I came across this one while researching stochastic vs. batch training for SNNs. My aim is to better understand how to accumulate gradients and whether there is a prominent difference between accumulating gradients in rate-based networks (ANNs) vs. spiking nets.


Motivation

One of the main limitations of SNNs is that we don’t know how to train them, and therefore we are unable to reap the benefits of spike-based computation.

ANNs can be seen as a sub-category of SNNs. In fact, an SNN can be reduced to a rate-based network in the “high firing-rate limit” (quoting the authors).

One interesting question brought up in the paper: what kind of computation becomes possible if we manage to train SNNs?


Main Contribution

A novel approach for training SNNs that is not tied to a specific neuron model, network architecture, loss function, or task; it revolves around the formulation of a differentiable SNN for which the gradient can be derived.

Note: The goal is not to have a biologically plausible learning rule but to derive an efficient learning method. Once SNNs are trained, we can then try to analyze this Pandora’s box to reveal computational processes of the brain, or come up with other algorithms that are hardware-friendly.


Previous Attempts at Supervised Learning in SNNs


Methods

A differentiable formulation of SNNs

Most models describe the synaptic current dynamics as a linear filtering process that activates instantaneously when the presynaptic membrane voltage $v$ crosses a threshold.

The authors replace this hard threshold with a smooth gate function of the membrane voltage, which is equivalent to a soft-clipping function.

Figure: the synaptic current model.
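To make this concrete, here is a minimal sketch of what such a soft-clipping gate and filtered synaptic current could look like. The sigmoid shape, the threshold value, and the exponential filter are my assumptions for illustration, not the paper’s exact formulation.

```python
import numpy as np

def gate(v, threshold=1.0, beta=5.0):
    """Soft-clipping gate: a smooth stand-in for the hard spike threshold.
    The sigmoid shape and sharpness parameter beta are assumptions."""
    return 1.0 / (1.0 + np.exp(-beta * (v - threshold)))

def synapse_step(s, v, dt=1e-3, tau_s=10e-3):
    """One Euler step of an exponentially filtered synaptic current driven by
    the gate. A hard-threshold model would use a spike indicator instead of
    gate(v), which is exactly where differentiability would be lost."""
    ds = (-s + gate(v)) / tau_s
    return s + dt * ds
```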

Define an entire network around this new formulation

Now that we have a differentiable synaptic current model, we can formulate the network input-output dynamics around it.

The next step is to optimize the weight matrices for the input layer, recurrent connections, and readout ($U$, $W$, and $O$), which can now be accomplished by gradient descent.
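As a rough illustration of how $U$, $W$, and $O$ enter the dynamics, here is a hedged sketch of the forward pass. The leaky voltage model, the time constants, and the Euler integration are assumptions for illustration, not the paper’s exact equations.

```python
import numpy as np

def gate(v, threshold=1.0, beta=5.0):
    # same soft-clipping gate as in the previous sketch
    return 1.0 / (1.0 + np.exp(-beta * (v - threshold)))

def simulate(x_seq, U, W, O, dt=1e-3, tau_v=20e-3, tau_s=10e-3):
    """Forward pass of a differentiable SNN sketch.
    x_seq: (T, n_in) input sequence; U: (n_hid, n_in) input weights;
    W: (n_hid, n_hid) recurrent weights; O: (n_out, n_hid) readout weights."""
    n_hid = W.shape[0]
    v = np.zeros(n_hid)   # membrane voltages
    s = np.zeros(n_hid)   # filtered synaptic currents (the smooth "spike" variable)
    outputs = []
    for x in x_seq:
        total_input = U @ x + W @ s           # external + recurrent drive
        v += dt * (-v + total_input) / tau_v  # leaky voltage dynamics (assumed form)
        s += dt * (-s + gate(v)) / tau_s      # differentiable synaptic activation
        outputs.append(O @ s)                 # linear readout
    return np.array(outputs)
```

Because every step is smooth, the whole simulation can be unrolled in an autodiff framework and $U$, $W$, and $O$ updated by ordinary gradient descent on whatever loss the readout defines.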

Gradient Descent

Without going through all the math (partly because I don’t like equations that much and I don’t fully grasp them for now 🙈), the weight-matrix optimization is carried out by backpropagating adjoint state variables, i.e. Pontryagin’s minimum principle.

In the end, we get a formula for the gradient that resembles reward-modulated STDP: a presynaptic input is multiplied by a postsynaptic spike-activity term and a temporal error signal.
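Reading that description literally, the update has a three-factor structure. The sketch below mirrors only that structure; the shapes, learning rate, and time-averaging are my assumptions, and the paper’s actual gradient comes from backward adjoint dynamics that are omitted here.

```python
import numpy as np

def weight_update(pre_s, post_gate, error, lr=1e-3):
    """Three-factor update sketch: presynaptic input x postsynaptic spike
    activity x temporal error signal, accumulated over time.
    Shapes assumed: pre_s (T, n_pre), post_gate (T, n_post), error (T, n_post)."""
    T = pre_s.shape[0]
    dW = np.zeros((post_gate.shape[1], pre_s.shape[1]))
    for t in range(T):
        # per-time-step outer product of the error-weighted postsynaptic term
        # and the presynaptic input
        dW += np.outer(error[t] * post_gate[t], pre_s[t])
    return -lr * dW / T
```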


Results

This approach is tested on a predictive coding task (matching, i.e. reproducing, an input-output behavior; essentially an auto-encoding task) and a delayed-memory XOR task, where the XOR operation is performed on previously stored inputs.

The latter task demonstrates the algorithm’s ability to train an SNN to perform non-linear computations over timescales much longer than the neurons’ time constants.
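For context, here is a toy sketch of how a delayed-memory XOR trial could be generated. The channel layout, pulse timings, and trial length are invented for illustration and are not the paper’s exact protocol.

```python
import numpy as np

def delayed_xor_trial(T=200, pulse_len=10, seed=None):
    """Toy delayed-memory XOR trial: two binary pulses arrive at different
    times on an input channel; after a later 'go' cue, the target output is
    the XOR of the two stored bits."""
    rng = np.random.default_rng(seed)
    a, b = rng.integers(0, 2, size=2)
    x = np.zeros((T, 2))                       # channel 0: pulses, channel 1: go cue
    x[10:10 + pulse_len, 0] = a
    x[80:80 + pulse_len, 0] = b
    x[150:150 + pulse_len, 1] = 1.0
    y = np.zeros((T, 1))
    y[150:150 + pulse_len, 0] = float(a ^ b)   # respond only after the cue
    return x, y
```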