Long Short-Term Memory Network and Back-Propagation through Time
Recurrent neural networks (RNNs) are neural networks designed to model sequential data. They are often used in speech recognition and natural language processing. In this blog post, I discuss one of the most popular RNN architectures, the long short-term memory (LSTM) network. Then I briefly describe a training procedure for an LSTM.
Long short-term memory network
While RNNs can in theory use contextual information to map input sequences to output sequences, in practice they can exploit only a limited range of context. The basic problem is that error signals propagated through time tend to either vanish or explode: long-term interactions receive exponentially smaller weights than short-term ones (the vanishing gradient problem). As a remedy, the LSTM network was proposed by Hochreiter and Schmidhuber in 1997.
An LSTM is a special kind of RNN, and its core contribution is the introduction of a self-connected unit, the memory cell. The weight on this self-connection is gated rather than fixed, which lets error signals flow over long durations. An LSTM contains three gates: a forget gate and an input gate that control what information is stored in the memory cell, and an output gate that controls what information is read out of the memory cell.
The forget gate unit for time step $t$ takes $\bar{f}^{(t)}$ as an input [ \bar{f}^{(t)} = b_{f} + U_{f} \; x^{(t)} + W_{f} \; h^{(t-1)}, ] and computes its activation $f^{(t)}$ as [ f^{(t)} = \sigma_{f} \left( \bar{f}^{(t)} \right). ] Here, $\sigma_{f}$ is a sigmoid, usually the logistic function; $x^{(t)}$ is the input vector and $h^{(t)}$ is the hidden state vector for time $t$; and $b_{f}$, $U_{f}$, and $W_{f}$ are the biases, input weights, and recurrent weights of the forget gate.
The activation of the forget gate, $f^{(t)}$, can take a value between 0 and 1, where 1 signifies “retain all of the memory cell state” and 0 signifies “completely forget the memory cell state.”
The input gate unit is computed similarly to the forget gate, but with its own parameters. Its input is [ \bar{g}^{(t)} = b_{g} + U_{g} \; x^{(t)} + W_{g} \; h^{(t-1)}, ] and its activation is [ g^{(t)} = \sigma_{g} \left( \bar{g}^{(t)} \right). ] Here again, $\sigma_{g}$ is a sigmoid, usually the logistic function. The activation of the input gate, $g^{(t)}$, can take a value between 0 and 1, and determines which memory cell units are updated and by how much.
The memory cell unit also receives input from the block input unit. The input to the block input unit is [ \bar{z}^{(t)} = b_{z} + U_{z} \; x^{(t)} + W_{z} \; h^{(t-1)}, ] and its activation is [ z^{(t)} = \sigma_{z} \left( \bar{z}^{(t)} \right), ] where $\sigma_{z}$ is usually the hyperbolic tangent function.
Then, the memory cell state is updated as follows: [ s^{(t)} = f^{(t)} \odot s^{(t-1)} + g^{(t)} \odot z^{(t)}, ] where $\odot$ denotes the element-wise (Hadamard) product.
The remaining gate is the output gate. Its input is [ \bar{q}^{(t)} = b_{q} + U_{q} \; x^{(t)} + W_{q} \; h^{(t-1)}, ] and its activation is [ q^{(t)} = \sigma_{q} \left( \bar{q}^{(t)} \right). ] Here again, $\sigma_{q}$ is a sigmoid, usually the logistic function. The output gate determines which memory cell states contribute to the hidden state and the output.
Then, the hidden state is updated: [ h^{(t)} = \sigma_{h} \left( s^{(t)} \right) \odot q^{(t)}, ] where $\sigma_{h}$ is usually the hyperbolic tangent function. The output from an LSTM is this hidden state vector, $h^{(t)}$.
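To make the forward pass concrete, here is a minimal NumPy sketch of a single LSTM step implementing the equations above. The function name lstm_step_forward and the parameter dictionary layout are my own illustrative choices, not part of any particular library.

```python
import numpy as np

def sigmoid(a):
    """Logistic function, used for the gate nonlinearities."""
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step_forward(x_t, h_prev, s_prev, params):
    """One LSTM step: maps (x^{(t)}, h^{(t-1)}, s^{(t-1)}) to (h^{(t)}, s^{(t)}).

    params holds, for each unit u in {f, g, z, q}, a bias b_u,
    input weights U_u, and recurrent weights W_u.
    """
    # Pre-activations: b + U x^{(t)} + W h^{(t-1)}
    f_bar = params["b_f"] + params["U_f"] @ x_t + params["W_f"] @ h_prev
    g_bar = params["b_g"] + params["U_g"] @ x_t + params["W_g"] @ h_prev
    z_bar = params["b_z"] + params["U_z"] @ x_t + params["W_z"] @ h_prev
    q_bar = params["b_q"] + params["U_q"] @ x_t + params["W_q"] @ h_prev

    f_t = sigmoid(f_bar)   # forget gate
    g_t = sigmoid(g_bar)   # input gate
    z_t = np.tanh(z_bar)   # block input
    q_t = sigmoid(q_bar)   # output gate

    # Cell state and hidden state updates (element-wise products)
    s_t = f_t * s_prev + g_t * z_t
    h_t = np.tanh(s_t) * q_t
    return h_t, s_t
```

With an input dimension of 3 and a hidden dimension of 4, for example, each $U$ would be a $4 \times 3$ matrix, each $W$ a $4 \times 4$ matrix, and each $b$ a vector of length 4.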
Therefore, an LSTM accepts a sequence of inputs $x^{(t)}$ and produces a sequence of hidden states $h^{(t)}$, which I write symbolically as
[
h = LSTM(x).
]
In practice, LSTMs are often stacked. A 4-layer LSTM network, for example, is
[
\begin{align}
h^{(t, 1)} &= LSTM_{1}(x^{(t)})\\
h^{(t, 2)} &= LSTM_{2}(h^{(t, 1)})\\
h^{(t, 3)} &= LSTM_{3}(h^{(t, 2)})\\
h^{(t, 4)} &= LSTM_{4}(h^{(t, 3)}),
\end{align}
]
where $h^{(t, 4)}$ is used to compute the loss function at time $t$.
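As a sketch of how such a stack can be run over a sequence, reusing the lstm_step_forward function sketched above (and, as before, these names and the simplifying assumption of a common hidden dimension across layers are my own illustrative choices):

```python
def stacked_lstm_forward(xs, layer_params, hidden_dim):
    """Run a stack of LSTM layers over a sequence xs (a list of input vectors).

    layer_params is a list with one parameter dict per layer, in the format
    expected by lstm_step_forward; all layers share hidden_dim for simplicity.
    """
    num_layers = len(layer_params)
    h = [np.zeros(hidden_dim) for _ in range(num_layers)]  # h^{(0, l)}
    s = [np.zeros(hidden_dim) for _ in range(num_layers)]  # s^{(0, l)}
    top_outputs = []
    for x_t in xs:
        layer_input = x_t
        for l in range(num_layers):
            h[l], s[l] = lstm_step_forward(layer_input, h[l], s[l], layer_params[l])
            layer_input = h[l]        # h^{(t, l)} feeds layer l + 1
        top_outputs.append(h[-1])     # e.g. h^{(t, 4)} in the 4-layer example
    return top_outputs
```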
Back-propagation through time
The original LSTM training algorithm (Hochreiter and Schmidhuber, 1997) used an approximate error gradient calculated with a combination of real-time recurrent learning (RTRL) and back-propagation through time (BPTT). However, it is possible to calculate the exact LSTM gradient with BPTT alone.
For BPTT, we compute the gradients recursively, going backwards through time. Let $L$ denote the total loss, the sum of the per-step losses $L^{(t)}$, and let $\tau$ denote the sequence length. We derive the gradient with respect to $h$, $s$, $\bar{q}$, $\bar{z}$, $\bar{g}$, and $\bar{f}$. For convenience, I introduce the following notation (the sum over $i$ can start at $i = t$ because losses at earlier time steps do not depend on quantities at time $t$):
[
\begin{align}
\epsilon_{h}^{(t)}
&= \frac{\partial L}{\partial h^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial h^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial h^{(t)}}
\\
\epsilon_{s}^{(t)}
&= \frac{\partial L}{\partial s^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial s^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial s^{(t)}}
\\
\epsilon_{\bar{q}}^{(t)}
&= \frac{\partial L}{\partial \bar{q}^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{q}^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{q}^{(t)}}
\\
\epsilon_{\bar{z}}^{(t)}
&= \frac{\partial L}{\partial \bar{z}^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{z}^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{z}^{(t)}}
\\
\epsilon_{\bar{g}}^{(t)}
&= \frac{\partial L}{\partial \bar{g}^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{g}^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{g}^{(t)}}
\quad{\textrm{and}}
\\
\epsilon_{\bar{f}}^{(t)}
&= \frac{\partial L}{\partial \bar{f}^{(t)}}
= \sum_{i=1}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{f}^{(t)}}
= \sum_{i=t}^{\tau} \frac{\partial L^{(i)}}{\partial \bar{f}^{(t)}}.
\end{align}
]
We first compute
[
\epsilon_{h}^{(\tau)} = \frac{\partial L^{(\tau)}}{\partial h^{(\tau)}}
]
and
[
\begin{align}
\epsilon_{s}^{(\tau)}
&= \frac{\partial L^{(\tau)}}{\partial s^{(\tau)}}\\
&=
\frac{\partial L^{(\tau)}}{\partial h^{(\tau)}}
\frac{\partial h^{(\tau)}}{\partial \sigma_{h} \left(s^{(\tau)} \right)}
\frac{\partial \sigma_{h} \left(s^{(\tau)} \right)}{\partial s^{(\tau)}}
\\
&=
\epsilon_{h}^{(\tau)}
\odot
q^{(\tau)}
\odot
\sigma_{h}'\left(s^{(\tau)} \right).
\end{align}
]
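As a concrete sketch, suppose the per-step loss is a squared error, $L^{(t)} = \frac{1}{2} \lVert h^{(t)} - y^{(t)} \rVert^{2}$ (my choice for illustration; any differentiable per-step loss works), and $\sigma_{h} = \tanh$ as above. The recursion can then be started as follows:

```python
import numpy as np

def bptt_initialize(h_tau, s_tau, q_tau, y_tau):
    """Start the backward recursion at the last time step t = tau.

    Assumes a squared-error loss 0.5 * ||h^{(t)} - y^{(t)}||^2 and
    sigma_h = tanh, so sigma_h'(s) = 1 - tanh(s)^2.
    """
    eps_h = h_tau - y_tau                                 # dL^{(tau)} / dh^{(tau)}
    eps_s = eps_h * q_tau * (1.0 - np.tanh(s_tau) ** 2)   # dL^{(tau)} / ds^{(tau)}
    return eps_h, eps_s
```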
Then, we go backwards through time, computing the gradients for $t = \tau - 1, \tau - 2, \dots, 1$. The cell state $s^{(t)}$ affects the loss through two paths: through the hidden state $h^{(t)}$ (which in turn feeds the output at time $t$ and the gates at time $t+1$), and directly through the next cell state $s^{(t+1)}$ via the forget gate. Summing the two contributions gives
[
\begin{align}
\epsilon_{s}^{(t)}
&= \frac{\partial L}{\partial h^{(t)}}
\frac{\partial h^{(t)}}{\partial s^{(t)}} +
\frac{\partial L}{\partial s^{(t+1)}}
\frac{\partial s^{(t+1)}}{\partial s^{(t)}}
\\
&= \epsilon_{h}^{(t)} \odot q^{(t)} \odot \sigma_{h}'\left(s^{(t)} \right) +
\epsilon_{s}^{(t+1)} \odot f^{(t+1)},
\end{align}
]
and
[
\begin{align}
\epsilon_{h}^{(t)}
&= \frac{\partial L^{(t)}}{\partial h^{(t)}} +
\sum_{i=t+1}^{\tau} \frac{\partial L^{(i)}}{\partial h^{(t)}}
\\
&= \frac{\partial L^{(t)}}{\partial h^{(t)}} +
\epsilon_{\bar{q}}^{(t+1)} \frac{\partial \bar{q}^{(t+1)}}{\partial h^{(t)}} +
\epsilon_{\bar{z}}^{(t+1)} \frac{\partial \bar{z}^{(t+1)}}{\partial h^{(t)}} +
\epsilon_{\bar{g}}^{(t+1)} \frac{\partial \bar{g}^{(t+1)}}{\partial h^{(t)}} +
\epsilon_{\bar{f}}^{(t+1)} \frac{\partial \bar{f}^{(t+1)}}{\partial h^{(t)}}
\\
&= \frac{\partial L^{(t)}}{\partial h^{(t)}} +
W_{q}^{T} \epsilon_{\bar{q}}^{(t+1)} +
W_{z}^{T} \epsilon_{\bar{z}}^{(t+1)} +
W_{g}^{T} \epsilon_{\bar{g}}^{(t+1)} +
W_{f}^{T} \epsilon_{\bar{f}}^{(t+1)},
\end{align}
]
where superscript $T$ indicates transpose.
The gradients with respect to the gate and block-input pre-activations are given by
[
\begin{align}
\epsilon_{\bar{q}}^{(t+1)}
&=
\sum_{i=t+1}^{\tau} \frac{\partial L^{(i)}}{\partial h^{(t+1)}}
\frac{\partial h^{(t+1)}}{\partial q^{(t+1)}}
\frac{\partial q^{(t+1)}}{\partial \bar{q}^{(t+1)}}
=
\epsilon_{h}^{(t+1)} \odot
\sigma_{h} \left( s^{(t+1)} \right) \odot
\sigma_{q}' \left( \bar{q}^{(t+1)} \right),
\\
\epsilon_{\bar{z}}^{(t+1)}
&=
\sum_{i=t+1}^{\tau} \frac{\partial L^{(i)}}{\partial s^{(t+1)}}
\frac{\partial s^{(t+1)}}{\partial z^{(t+1)}}
\frac{\partial z^{(t+1)}}{\partial \bar{z}^{(t+1)}}
=
\epsilon_{s}^{(t+1)} \odot
g^{(t+1)} \odot
\sigma_{z}' \left( \bar{z}^{(t+1)} \right),
\\
\epsilon_{\bar{g}}^{(t+1)}
&=
\sum_{i=t+1}^{\tau} \frac{\partial L^{(i)}}{\partial s^{(t+1)}}
\frac{\partial s^{(t+1)}}{\partial g^{(t+1)}}
\frac{\partial g^{(t+1)}}{\partial \bar{g}^{(t+1)}}
=
\epsilon_{s}^{(t+1)} \odot
z^{(t+1)} \odot
\sigma_{g}' \left( \bar{g}^{(t+1)} \right), \textrm{ and }
\\
\epsilon_{\bar{f}}^{(t+1)}
&=
\sum_{i=t+1}^{\tau} \frac{\partial L^{(i)}}{\partial s^{(t+1)}}
\frac{\partial s^{(t+1)}}{\partial f^{(t+1)}}
\frac{\partial f^{(t+1)}}{\partial \bar{f}^{(t+1)}}
=
\epsilon_{s}^{(t+1)} \odot
s^{(t)} \odot
\sigma_{f}' \left( \bar{f}^{(t+1)} \right).
\end{align}
]
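Put together, a single backward time step can be sketched as follows (written at time $t$ rather than $t+1$; the derivatives of the logistic function and of $\tanh$ are expressed through the cached activations, since $\sigma'(a) = \sigma(a)(1 - \sigma(a))$ and $\tanh'(a) = 1 - \tanh(a)^{2}$). The function name, the cache layout, and the parameter dictionary follow the forward sketch above and are my own choices:

```python
import numpy as np

def lstm_step_backward(eps_h, eps_s_next, f_next, cache, params):
    """One backward step of BPTT for the LSTM cell.

    eps_h      : dL/dh^{(t)}, i.e. the local loss gradient plus the terms
                 flowing back through the gates at time t+1, already summed.
    eps_s_next : dL/ds^{(t+1)}; f_next is the forget-gate activation f^{(t+1)}.
    cache      : activations saved at time t during the forward pass:
                 f, g, z, q, s, and s_prev (= s^{(t-1)}).
    Returns eps_s^{(t)}, the four pre-activation gradients, and the
    contribution to dL/dh^{(t-1)} that flows through the gates at time t
    (the caller adds the local term dL^{(t-1)}/dh^{(t-1)} itself).
    """
    f, g, z, q, s, s_prev = (cache[k] for k in ("f", "g", "z", "q", "s", "s_prev"))

    # eps_s^{(t)} = eps_h^{(t)} . q^{(t)} . sigma_h'(s^{(t)}) + eps_s^{(t+1)} . f^{(t+1)}
    eps_s = eps_h * q * (1.0 - np.tanh(s) ** 2) + eps_s_next * f_next

    # Pre-activation gradients: output gate, block input, input gate, forget gate
    eps_q_bar = eps_h * np.tanh(s) * q * (1.0 - q)
    eps_z_bar = eps_s * g * (1.0 - z ** 2)
    eps_g_bar = eps_s * z * g * (1.0 - g)
    eps_f_bar = eps_s * s_prev * f * (1.0 - f)

    # Contribution to dL/dh^{(t-1)} through the four pre-activations at time t
    eps_h_prev = (params["W_q"].T @ eps_q_bar + params["W_z"].T @ eps_z_bar +
                  params["W_g"].T @ eps_g_bar + params["W_f"].T @ eps_f_bar)

    return eps_s, eps_q_bar, eps_z_bar, eps_g_bar, eps_f_bar, eps_h_prev
```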
The gradient for each parameter is computed from the terms above. For example, for the input weights of the forget gate, [ \frac{\partial L}{\partial U_{f}} = \sum_{t=1}^{\tau} \frac{\partial L}{\partial \bar{f}^{(t)}} \frac{\partial \bar{f}^{(t)}}{\partial U_{f}} = \sum_{t=1}^{\tau} \epsilon_{\bar{f}}^{(t)} \otimes x^{(t)}, ] where $\otimes$ indicates the outer product. Likewise, $\frac{\partial L}{\partial W_{f}} = \sum_{t=1}^{\tau} \epsilon_{\bar{f}}^{(t)} \otimes h^{(t-1)}$ for the recurrent weights, and $\frac{\partial L}{\partial b_{f}} = \sum_{t=1}^{\tau} \epsilon_{\bar{f}}^{(t)}$ for the bias.
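In code, these outer products are accumulated over the backward sweep. A minimal helper, again with my own naming and following the backward-step sketch above:

```python
import numpy as np

def accumulate_param_grads(grads, eps_bar, x_t, h_prev, name):
    """Add the time-t contribution to the gradients of one unit's parameters.

    name is one of "f", "g", "z", "q", and eps_bar is the corresponding
    pre-activation gradient epsilon_bar^{(t)}.
    """
    grads["U_" + name] += np.outer(eps_bar, x_t)     # input weights multiply x^{(t)}
    grads["W_" + name] += np.outer(eps_bar, h_prev)  # recurrent weights multiply h^{(t-1)}
    grads["b_" + name] += eps_bar                    # bias
```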
For a multi-layer (stacked) LSTM, we also need the gradient with respect to the input $x^{(t)}$, so that the error can be propagated to the layer below:
[
\begin{align}
\frac{\partial L}{\partial x^{(t)}}
&=
\frac{\partial L}{\partial \bar{f}^{(t)}}
\frac{\partial \bar{f}^{(t)}}{\partial x^{(t)}}
+
\frac{\partial L}{\partial \bar{g}^{(t)}}
\frac{\partial \bar{g}^{(t)}}{\partial x^{(t)}}
+
\frac{\partial L}{\partial \bar{z}^{(t)}}
\frac{\partial \bar{z}^{(t)}}{\partial x^{(t)}}
+
\frac{\partial L}{\partial \bar{q}^{(t)}}
\frac{\partial \bar{q}^{(t)}}{\partial x^{(t)}}
\\
&=
U_{f}^{T} \, \epsilon_{\bar{f}}^{(t)}
+
U_{g}^{T} \, \epsilon_{\bar{g}}^{(t)}
+
U_{z}^{T} \, \epsilon_{\bar{z}}^{(t)}
+
U_{q}^{T} \, \epsilon_{\bar{q}}^{(t)}.
\end{align}
]
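Putting the pieces together, a full backward sweep over a sequence might look like the following sketch, which assumes NumPy and reuses the lstm_step_backward and accumulate_param_grads helpers sketched above; the layout of the cached activations and of dL_dh_local (the local gradients $\partial L^{(t)} / \partial h^{(t)}$) is again my own illustrative choice.

```python
import numpy as np

def lstm_backward(caches, params, dL_dh_local, hidden_dim):
    """Full BPTT sweep over a sequence of length tau.

    caches      : list of per-step dicts holding f, g, z, q, s, s_prev,
                  plus the input x and the previous hidden state h_prev.
    dL_dh_local : list of the local gradients dL^{(t)}/dh^{(t)}.
    Returns the parameter gradients and the gradients dL/dx^{(t)}
    that are passed down to the layer below in a stacked LSTM.
    """
    tau = len(caches)
    grads = {k: np.zeros_like(v) for k, v in params.items()}
    dxs = [None] * tau

    eps_s_next = np.zeros(hidden_dim)       # no cell state beyond t = tau
    f_next = np.zeros(hidden_dim)
    eps_h_from_future = np.zeros(hidden_dim)

    for t in reversed(range(tau)):
        cache = caches[t]
        eps_h = dL_dh_local[t] + eps_h_from_future   # epsilon_h^{(t)}
        (eps_s, eps_q_bar, eps_z_bar, eps_g_bar, eps_f_bar,
         eps_h_from_future) = lstm_step_backward(eps_h, eps_s_next, f_next,
                                                 cache, params)

        # Parameter gradients: outer products accumulated over time
        for name, eps_bar in (("q", eps_q_bar), ("z", eps_z_bar),
                              ("g", eps_g_bar), ("f", eps_f_bar)):
            accumulate_param_grads(grads, eps_bar, cache["x"], cache["h_prev"], name)

        # Gradient with respect to the input: U^T epsilon_bar, summed over units
        dxs[t] = (params["U_q"].T @ eps_q_bar + params["U_z"].T @ eps_z_bar +
                  params["U_g"].T @ eps_g_bar + params["U_f"].T @ eps_f_bar)

        eps_s_next, f_next = eps_s, cache["f"]

    return grads, dxs
```

A convenient sanity check for such a sketch is to compare the returned gradients against finite differences on a tiny network.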
References and Bibliography
Goodfellow, I., Bengio, Y., and Courville, A. (2016). Deep Learning. MIT Press.
Graves, A. (2012). Supervised Sequence Labelling with Recurrent Neural Networks. Springer.
Greff, K., Srivastava, R. K., Koutník, J., Steunebrink, B. R., and Schmidhuber, J. (2015). LSTM: A Search Space Odyssey. arXiv:1503.04069 [cs.NE].
Hochreiter, S., and Schmidhuber, J. (1997). Long short-term memory. Neural Computation, 9(8), 1735-1780.
Jimenez, N. D. (2014). Simple LSTM. Blog post.
Olah, C. (2015). Understanding LSTM networks. Blog post.