
Clearly Explaining: Long Short-Term Memory (LSTM) [1/16]

· 8 min read

TL;DR

The Problem: Vanilla RNNs look like they have memory, but they’re hard to train to remember things for a long time. When you train them, the “learning signal” (the gradient) often becomes too small (vanishes) or too large (explodes) as it travels backward through many time steps.

The Fix: LSTM adds a separate memory lane called the cell state c_t plus gates (valves) that control what to keep, write, and show. This creates a stable path for learning long-range dependencies.


The Problem

When you read a sentence, you keep a running idea of what’s happening.

Models that read sequences (text, audio, time series) need the same ability to carry context forward. That’s exactly what Recurrent Neural Networks (RNNs) were designed to do.

Imagine you’re reading one word at a time.

At each step t, you have:

  1. the current input x_t (the word you just read)
  2. your previous summary h_{t-1} (everything read so far, compressed into one vector)

An RNN updates its summary like this:

h_t = \phi(W_{xh} x_t + W_{hh} h_{t-1} + b)

The key idea being:

new summary = function(current input + previous summary)
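As a minimal sketch of this update (hypothetical sizes, random untrained weights, and tanh standing in for the squashing function φ), one RNN step in NumPy looks like:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 4-dim input, 8-dim hidden summary.
W_xh = rng.normal(scale=0.1, size=(8, 4))   # input -> hidden
W_hh = rng.normal(scale=0.1, size=(8, 8))   # hidden -> hidden (shared across time)
b = np.zeros(8)

def rnn_step(x_t, h_prev):
    """new summary = tanh(current input + previous summary)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b)

h = np.zeros(8)                    # empty summary before reading anything
for x in rng.normal(size=(5, 4)):  # a toy sequence of 5 "words"
    h = rnn_step(x, h)
print(h.shape)  # (8,)
```

Note that the same `W_hh` is reused at every step; that weight sharing is what makes the training instability below possible.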

So why did we need LSTMs at all…?

Because RNNs have two different problems:

Issue 1: One vector has to do two jobs at once

A single summary vector h_t (given above) is forced to do two jobs at once:

  1. store useful past info
  2. produce useful present prediction

That’s already a tough balancing act, and it gets worse during training.

Issue 2: Training signal becomes unstable over long sequences

Training works like this: the model makes a prediction, we measure the error (the loss), and a learning signal (the gradient) travels backward through the network to adjust the weights.

For sequences, that backward signal has to pass through every time step in reverse. This is called Backpropagation Through Time (BPTT).

To send the learning signal back 50 steps, you multiply the gradient 50 times, once per step.

So the model either:

can’t learn long-range dependencies because the signal dies → vanishing gradients

or

becomes unstable because the signal blows up → exploding gradients

This is known as the vanishing/exploding gradient problem, which is why vanilla RNNs tended to learn only short-term patterns.

Note: For RNNs the model weights are shared across time.
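The arithmetic is easy to verify directly. This toy loop multiplies a signal by a fixed per-step factor 50 times (the factors 0.9 and 1.1 are illustrative, not taken from any real model):

```python
# Repeated multiplication over 50 steps: the core of the instability.
shrink, grow = 1.0, 1.0
for _ in range(50):
    shrink *= 0.9   # per-step factor slightly below 1
    grow *= 1.1     # per-step factor slightly above 1

print(f"{shrink:.4f}")  # 0.0052 -> the signal has vanished
print(f"{grow:.1f}")    # 117.4  -> the signal has exploded
```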

[Figure: backpropagation through time]

Deep Dive: the math of vanishing/exploding gradients

A vanilla RNN’s hidden state is given by:

h_t = \phi(W_{hh} h_{t-1} + W_{xh} x_t + b)

Let L be the loss at the end of the sequence. The gradient w.r.t. an earlier state h_{t-k} is:

\frac{\partial L}{\partial h_{t-k}} = \left(\prod_{j=t-k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}}\right) \frac{\partial L}{\partial h_t}

The one-step Jacobian is:

\frac{\partial h_j}{\partial h_{j-1}} = \mathrm{diag}(\phi'(a_j)) W_{hh}

where a_j = W_{hh} h_{j-1} + W_{xh} x_j + b is the pre-activation.

So the gradient magnitude is controlled by repeated multiplication of matrices that include W_{hh}.

If the spectral norm / dominant singular value of this Jacobian is < 1, the product shrinks exponentially with k → vanishing gradients.

If it’s > 1, the product grows exponentially → exploding gradients.

This is the core numerical instability of training vanilla RNNs across long horizons.
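This spectral-norm claim can be sketched numerically. In the toy helper below (`grad_norm_after` is a made-up name, and the diag(φ′) factor is ignored for simplicity), we build the Jacobian as scale × a random orthogonal matrix, so its spectral norm is exactly the chosen scale:

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_norm_after(k, scale, n=8):
    """Push a gradient back k steps through Jacobians of spectral norm `scale`."""
    # scale * (random orthogonal matrix) has spectral norm exactly `scale`.
    Q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    W = scale * Q
    g = np.ones(n)
    for _ in range(k):
        g = W.T @ g   # one backward step through time
    return np.linalg.norm(g)

print(grad_norm_after(50, 0.5))  # vanishingly small (~1e-15)
print(grad_norm_after(50, 1.5))  # astronomically large (~1e9)
```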


The Solution: Long Short-Term Memory (LSTM)

LSTM’s key idea is simple but powerful:

Don’t force one vector to do everything. Give the model a dedicated memory lane.

An LSTM has two states:

  1. The cell state c_t: the long-term memory lane, kept mostly untouched
  2. The hidden state h_t: the working summary shown at each step (and used for predictions)

Think of it like this: the cell state is a notebook you carry through the whole sequence, and the hidden state is what you say out loud at each step.

Instead of constantly overwriting memory, LSTM updates memory like an editor:

c_t = \text{(keep some old)} + \text{(write some new)}

In the standard modern form:

c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t

Where:

  1. f_t is the forget gate: how much of the old memory to keep
  2. i_t is the input gate: how much of the new candidate to write
  3. \tilde{c}_t is the candidate memory: the proposed new content
  4. \odot is element-wise multiplication

So LSTM explicitly learns:

  1. what to keep (forget)
  2. what to write (input)
  3. what to show (output)

[Figure: LSTM architecture]

What are gates?

A gate is just a small neural network layer that outputs values in (0, 1) using a sigmoid.

LSTM uses three gates:

  1. Forget gate f_t: what to erase from memory
  2. Input gate i_t: what to write into memory
  3. Output gate o_t: what to show from memory

A simple way to remember this is: LSTM = memory + 3 valves.
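Concretely, a gate is one affine layer followed by a sigmoid. A toy sketch with made-up sizes and random weights:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
W = rng.normal(scale=0.5, size=(3, 5))  # input weights (hypothetical sizes)
U = rng.normal(scale=0.5, size=(3, 3))  # recurrent weights
b = np.zeros(3)

x_t = rng.normal(size=5)      # current input
h_prev = rng.normal(size=3)   # previous hidden state
gate = sigmoid(W @ x_t + U @ h_prev + b)

# Every entry is a valve opening in (0, 1): near 0 blocks, near 1 passes.
print(np.all((gate > 0) & (gate < 1)))  # True
```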


The Full Architecture (Step-by-Step)

Let’s track the data as it moves through the cell.

At time t, we receive input x_t and the previous states h_{t-1} and c_{t-1}. Here is what happens inside the cell:

Step 1: Decide what to forget

First, the Forget Gate looks at the previous context (h_{t-1}) and the new input (x_t). It outputs a number between 0 and 1 for each number in the cell state.

f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)

[Figure: forget gate]

Step 2: Decide what to store

Next, we decide what new information to store in the cell state. This has two parts:

  1. Input Gate (i_t): A sigmoid layer decides which values we’ll update.
  2. Candidate Values (\tilde{c}_t): A tanh layer creates a vector of new candidate values that could be added to the state.

i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)
\tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)

[Figure: input gate]

Step 3: Update the long-term memory

This is the critical line where the magic happens. We update the old cell state c_{t-1} into the new cell state c_t.

We multiply the old state by f_t (forgetting the things we decided to forget earlier). Then we add i_t \odot \tilde{c}_t (adding the new candidate values, scaled by how much we decided to update each state value).

c_t = \underbrace{f_t \odot c_{t-1}}_{\text{Retention}} + \underbrace{i_t \odot \tilde{c}_t}_{\text{New Info}}

Step 4: Produce the output

Finally, we decide what we’re going to output. This output (h_t) will be based on our newly updated cell state, but filtered.

  1. Output Gate (o_t): Decides which parts of the cell state to output.
  2. Filtering: We put the cell state through tanh (to push the values to be between -1 and 1) and multiply it by the output gate.

o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)

[Figure: output gate]

Where:

  1. W_* and U_* are the input and recurrent weight matrices for each gate
  2. b_* are the bias vectors
  3. \sigma is the sigmoid function and \odot is element-wise multiplication

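The four steps can be collected into one NumPy function. This is a minimal sketch, not an optimized implementation: the sizes, the `lstm_step` helper, and the random untrained weights are all illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM time step, following the four steps above."""
    f = sigmoid(p["W_f"] @ x_t + p["U_f"] @ h_prev + p["b_f"])        # Step 1: forget gate
    i = sigmoid(p["W_i"] @ x_t + p["U_i"] @ h_prev + p["b_i"])        # Step 2: input gate
    c_tilde = np.tanh(p["W_c"] @ x_t + p["U_c"] @ h_prev + p["b_c"])  #         candidates
    c = f * c_prev + i * c_tilde                                      # Step 3: additive update
    o = sigmoid(p["W_o"] @ x_t + p["U_o"] @ h_prev + p["b_o"])        # Step 4: output gate
    h = o * np.tanh(c)                                                # filtered view of memory
    return h, c

rng = np.random.default_rng(42)
n_in, n_hid = 4, 8                 # hypothetical sizes
p = {}
for g in "fico":                   # one W, U, b per gate/candidate
    p[f"W_{g}"] = rng.normal(scale=0.1, size=(n_hid, n_in))
    p[f"U_{g}"] = rng.normal(scale=0.1, size=(n_hid, n_hid))
    p[f"b_{g}"] = np.zeros(n_hid)

h, c = np.zeros(n_hid), np.zeros(n_hid)
for x in rng.normal(size=(6, n_in)):   # toy sequence of 6 steps
    h, c = lstm_step(x, h, c, p)
print(h.shape, c.shape)  # (8,) (8,)
```

Notice that the only thing happening to `c` between steps is an element-wise scale and an addition; that is the conveyor belt discussed next.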

How does LSTM fix the issues with vanilla RNNs?

The key to the LSTM’s success is that it changes the fundamental mathematical operation of memory.

RNNs operate via Multiplication. In a standard RNN, the hidden state is constantly being multiplied by a weight matrix W. If you multiply a number by a fraction (say, 0.9) fifty times, it approaches zero. If you multiply by a large number (say, 1.1) fifty times, it explodes. The gradient has to fight through this multiplicative gauntlet.

LSTMs operate via Addition. Look closely at the cell state update equation again:

c_t = \underbrace{f_t \odot c_{t-1}}_{\text{Retention}} + \underbrace{i_t \odot \tilde{c}_t}_{\text{New Info}}

The cell state is a linear “conveyor belt.” Information flows straight down the entire chain with only minor linear interactions.

When we do backpropagation, the gradient of an addition is 1. This means the error signal can flow backward through time without being squashed or exploded, provided the forget gate f_t is active. This property is often called the Constant Error Carousel (CEC).

By defaulting to “remembering” (additive updates) rather than “transforming” (multiplicative updates), LSTMs make it much easier for the gradient to find a path from step 100 back to step 1.

Deep Dive: The math of the additive gradient

Let’s look at the gradient of the cell state c_t with respect to the previous cell state c_{t-1}. Since c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, we can apply the product rule to find the derivative:

\frac{\partial c_t}{\partial c_{t-1}} = \underbrace{f_t}_{\text{Linear Term}} + \underbrace{c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} + \tilde{c}_t \frac{\partial i_t}{\partial c_{t-1}} + i_t \frac{\partial \tilde{c}_t}{\partial c_{t-1}}}_{\text{Gate Dependencies}}

The beauty of the LSTM lies in that first term: f_t.

While the “Gate Dependencies” exist (because the gates themselves technically depend on the previous state via h_{t-1}), they are complex and often negligible during backpropagation. The Linear Term f_t, however, provides a direct path for the gradient.

If we look at how the gradient travels back k steps, we can approximate it by chaining these linear terms:

\frac{\partial c_t}{\partial c_{t-k}} \approx \prod_{j=t-k+1}^{t} f_j

In a vanilla RNN, this product involves the matrix W_{hh} (which causes explosion or vanishing). In an LSTM, it involves the gate f_j, which acts element-wise like a scalar.

  • If f_j = 1, the gradient passes through perfectly (identity mapping).
  • If f_j = 0, the gradient is cut off (forgetting).

This allows the network to learn exactly when to let the error signal flow backward and when to stop it, without suffering from the numerical instability of repeated matrix multiplication.
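The approximation above is easy to check numerically. Compare fifty mostly-open forget gates with fifty half-closed ones (the gate values 0.99 and 0.5 are illustrative assumptions):

```python
import numpy as np

f_open = np.full(50, 0.99)   # gates mostly open: keep remembering
f_closed = np.full(50, 0.5)  # gates half closed: actively forget

print(np.prod(f_open))            # ~0.605: the signal survives 50 steps
print(np.prod(f_closed) < 1e-15)  # True: the signal is deliberately cut off
```

Unlike the RNN case, where the shrink/grow factor is baked into the shared weight matrix, here the per-step factor is an input-dependent gate the network itself controls.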


Impact

The LSTM is arguably the most commercially successful neural network architecture for sequences prior to the Transformer era.

For nearly a decade (roughly 2012–2018), LSTMs were the state-of-the-art engine behind:

  1. machine translation (sequence-to-sequence models and Google’s GNMT system)
  2. speech recognition (deep recurrent acoustic models)
  3. language modeling and text generation

While Transformers (like BERT and GPT) have largely replaced LSTMs for massive natural language tasks due to their ability to parallelize training, LSTMs remain highly relevant in:

  1. Time-series forecasting (stock prediction, weather, IoT sensor data).
  2. Low-latency environments where the O(N^2) complexity of Transformers is too expensive.
  3. Reinforcement Learning agents that need memory (like OpenAI Five).

Understanding the LSTM is understanding the bridge between simple neural networks and modern reasoning systems.

References

  1. Sepp Hochreiter & Jürgen Schmidhuber — Long Short-Term Memory (Neural Computation, 1997)
  2. Yoshua Bengio, Patrice Simard, Paolo Frasconi — Learning Long-Term Dependencies with Gradient Descent is Difficult (IEEE TNN, 1994)
  3. Paul Werbos — Backpropagation Through Time: What It Does and How to Do It (1990)
  4. Razvan Pascanu, Tomas Mikolov, Yoshua Bengio — On the difficulty of training Recurrent Neural Networks (ICML, 2013)
  5. Klaus Greff et al. — LSTM: A Search Space Odyssey (2015)
  6. Ilya Sutskever, Oriol Vinyals, Quoc V. Le — Sequence to Sequence Learning with Neural Networks (NeurIPS, 2014)
  7. Google (GNMT authors) — Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation (2016)
  8. Alex Graves, Abdel-rahman Mohamed, Geoffrey Hinton — Speech Recognition with Deep Recurrent Neural Networks (ICASSP, 2013)
  9. Ashish Vaswani et al. — Attention Is All You Need (2017)
  10. Jacob Devlin et al. — BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (NAACL, 2019; arXiv 2018)
  11. OpenAI (OpenAI Five authors) — Dota 2 with Large Scale Deep Reinforcement Learning (2019)
