Gradient Descent by Gradient Descent
In this blog post, I review the 2016 NIPS paper “Learning to Learn by Gradient Descent by Gradient Descent” (abbreviated here as LLGG) by Andrychowicz et al.
Gradient descent is an iterative optimization algorithm. Consider the problem of minimizing an objective function $f(\theta)$. Gradient descent solves this problem through a sequence of updates: [ \theta_{t+1} = \theta_{t} - \alpha_{t} \; \nabla_{\theta_{t}} f, ] where $\alpha_{t} > 0$ is the step size and $\nabla_{\theta_{t}} f = \frac{\partial f}{\partial \theta_{t}}$ is the gradient of $f$ evaluated at $\theta_{t}$.
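As a quick example, take $f(\theta) = \theta^{2}$, whose gradient is $\nabla_{\theta} f = 2\theta$. Starting from $\theta_{0} = 1$ with step size $\alpha = 0.1$, one update gives [ \theta_{1} = \theta_{0} - \alpha \, \nabla_{\theta_{0}} f = 1 - 0.1 \times 2 = 0.8, ] and repeated updates move $\theta$ toward the minimizer $\theta = 0$.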
In LLGG, the update term $-\alpha_{t} \; \nabla_{\theta_{t}} f$ is replaced with the output $g_{t}$ of a learned optimizer function $g$: [ \theta_{t+1} = \theta_{t} + g_{t}. ]
Optimizer function
For the optimizer function, a two-layer stack of LSTM networks is used:
[
\begin{align}
h_{t} &= \mathrm{LSTM}_{1}\left( \nabla_{\theta_{t}} f \right)\\
g_{t} &= \mathrm{LSTM}_{2}\left( h_{t} \right),
\end{align}
]
where each LSTM also carries its own recurrent state across time steps (suppressed in the notation above). The dimension of the hidden state is set to 20 in the LLGG paper.
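To make this concrete, here is a minimal PyTorch sketch of such an optimizer function. The class name `LSTMOptimizer`, the linear read-out layer, and the exact wiring are my own choices for illustration; the paper's reference implementation differs in detail.

```python
import torch
import torch.nn as nn

class LSTMOptimizer(nn.Module):
    """Two-layer LSTM stack mapping a gradient to an update g_t (a sketch)."""

    def __init__(self, input_size=1, hidden_size=20):
        super().__init__()
        # Two stacked LSTM layers with 20 hidden units each, as in the paper.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
        # Linear read-out turning the top hidden state into a scalar update.
        self.readout = nn.Linear(hidden_size, 1)

    def forward(self, grad, state=None):
        # grad has shape (seq_len=1, n_coordinates, input_size); the batch
        # dimension ranges over the coordinates of theta, anticipating the
        # coordinatewise design discussed below.
        h, state = self.lstm(grad, state)
        return self.readout(h), state
```

Treating the coordinates of $\theta$ as the batch dimension is what lets a single set of LSTM weights serve every parameter, as discussed further below.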
The optimizer function itself is trained at the same time, using truncated back-propagation through time (BPTT). When $\theta$ has been updated $t$ times, the loss function is [ L = \sum_{s=t - T}^{t} f\left( \theta_{s} \right), ] where the truncation length $T$ is 20 or 32 in the LLGG paper, depending on the training data.
Then, $\phi$, the set of parameters of the LSTM stack, is updated as follows: [ \phi_{s+1} = \phi_{s} - \alpha_{s} \, \frac{\partial L}{\partial \phi_{s}}, ] where the step size $\alpha_{s}$ is given by the adaptive moment estimation (Adam) algorithm, and [ \begin{align} \frac{\partial L}{\partial \phi_{s}} &= \sum_{t} \left( \frac{\partial L}{\partial g_{t}} \frac{\partial g_{t}}{\partial \phi_{s}} + \frac{\partial L}{\partial \nabla_{\theta_{t-1}} f} \frac{\partial \nabla_{\theta_{t-1}} f}{\partial \phi_{s}} \right). \end{align} ] For computational feasibility, the second term is dropped by assuming [ \frac{\partial \nabla_{\theta_{t}} f}{\partial \phi} = 0, ] so that the LSTM stack can be trained with standard back-propagation through time.
Summing up, the training procedure is:
- Initialise $\theta$ and $\phi$
- For $t = 1, 2, \dots$
- $\hspace{1cm}$ $L \gets 0$
- $\hspace{1cm}$ For $s = 1, 2, \dots, T$
- $\hspace{1cm}$ $\hspace{1cm}$ $L \gets L + f(\theta)$
- $\hspace{1cm}$ $\hspace{1cm}$ $g_{s} = m(\nabla_{\theta}f, \, \phi)$, where $m$ is the optimizer function (i.e., the stack of LSTM networks)
- $\hspace{1cm}$ $\hspace{1cm}$ $\theta \gets \theta + g_{s}$
- $\hspace{1cm}$ $\phi \gets \phi - \alpha \, \frac{\partial L}{\partial \phi}$
- $\hspace{1cm}$ Update $\alpha$
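This procedure translates fairly directly into code. Below is a hedged PyTorch sketch of the loop, reusing the hypothetical `LSTMOptimizer` from earlier on a toy quadratic objective (my choice, not a task from the paper); the `.detach()` on the gradient implements the assumption $\frac{\partial \nabla_{\theta_{t}} f}{\partial \phi} = 0$.

```python
import torch

torch.manual_seed(0)

f = lambda th: (th ** 2).sum()              # toy objective (illustration only)
theta = torch.randn(5, requires_grad=True)  # initialise theta
opt_fn = LSTMOptimizer()                    # initialise phi (the LSTM weights)
meta_opt = torch.optim.Adam(opt_fn.parameters(), lr=1e-3)

T = 20                                      # truncation length for BPTT
for t in range(100):                        # outer loop: "for t = 1, 2, ..."
    L = 0.0                                 # L <- 0
    state = None                            # fresh LSTM state for each unroll
    for s in range(T):                      # inner loop: "for s = 1, ..., T"
        loss = f(theta)
        L = L + loss                        # L <- L + f(theta)
        # Detaching drops the dependence of the gradient on phi,
        # i.e. the paper's d(grad f)/d(phi) = 0 assumption.
        grad = torch.autograd.grad(loss, theta, retain_graph=True)[0].detach()
        g, state = opt_fn(grad.view(1, -1, 1), state)  # g_s = m(grad, phi)
        theta = theta + g.view(-1)          # theta <- theta + g_s
    meta_opt.zero_grad()
    L.backward()                            # truncated BPTT through the unroll
    meta_opt.step()                         # phi update; Adam also adapts alpha
    theta = theta.detach().requires_grad_() # cut the graph between unrolls
```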
For computational feasibility again, the optimizer function operates coordinatewise: it outputs a scalar update for each parameter of $\theta$ individually. The same optimizer function, with a shared set of weights $\phi$ but a separate hidden state per coordinate, is used for every parameter.
However, the gradients of different parameters can vary greatly in magnitude, which can be
problematic when training the optimizer function. As a remedy, the LLGG paper
pre-processes the input to the optimizer function:
[
\nabla \rightarrow
\begin{cases}
\left( \frac{\ln \vert \nabla \vert}{p},\; \mathrm{sign}(\nabla) \right) &
\textrm{if } \vert \nabla \vert \geq \exp(-p)
\\
\left( -1,\; \exp(p) \, \nabla \right) &
\textrm{otherwise,}
\end{cases}
]
where $p>0$ sets the threshold $\exp(-p)$ below which gradients are treated as small. The LLGG paper uses
$p=10$.
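In code, this preprocessing might look like the sketch below (the function name is mine; `p=10` follows the paper). Note that with this encoding each coordinate contributes two input features, so `input_size=2` in the earlier `LSTMOptimizer` sketch.

```python
import math
import torch

def preprocess_grad(grad, p=10.0):
    """Encode a gradient as two features, as in the LLGG preprocessing (a sketch)."""
    threshold = math.exp(-p)            # gradients below this count as "small"
    large = grad.abs() >= threshold
    # Large gradients: (log|grad| / p, sign(grad)).
    # Small gradients: (-1, exp(p) * grad).
    # Clamping avoids log(0) in the branch that torch.where discards.
    first = torch.where(large, grad.abs().clamp_min(threshold).log() / p,
                        torch.full_like(grad, -1.0))
    second = torch.where(large, grad.sign(), math.exp(p) * grad)
    return torch.stack((first, second), dim=-1)
```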
References
Andrychowicz, M., Denil, M., Gomez, S., Hoffman, M. W., et al. (2016). Learning to learn by gradient descent by gradient descent. arXiv:1606.04474 [cs.NE]. (Code is available at their GitHub repo.)