Understand why recurrent neural networks struggled and what transformers solved
For years, Recurrent Neural Networks (RNNs) were the gold standard for sequence modeling. From language translation to speech recognition, RNNs powered the AI breakthroughs of the 2010s. But they had critical flaws that limited their potential. Understanding these problems is the first step to appreciating why transformers are revolutionary.
RNNs seemed like a natural fit for sequential data: they read tokens one at a time in order, maintain a hidden state that summarizes everything seen so far, and reuse the same weights at every time step.
But as models scaled and tasks became more complex, fundamental limitations emerged that no amount of engineering could fully solve.
Before diving into problems, let's understand how RNNs work:
At each time step t:
Input: x_t (current token embedding)
Previous: h_{t-1} (hidden state from previous step)
Computation:
h_t = tanh(W_hh × h_{t-1} + W_xh × x_t + b_h)
y_t = W_hy × h_t + b_y
Output: y_t (prediction)
Next State: h_t (passed to next time step)
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Weight matrices
        self.i2h = nn.Linear(input_size, hidden_size)   # input to hidden
        self.h2h = nn.Linear(hidden_size, hidden_size)  # hidden to hidden
        self.h2o = nn.Linear(hidden_size, output_size)  # hidden to output

    def forward(self, x, hidden=None):
        """
        x: (batch_size, seq_len, input_size)
        Returns: outputs, final_hidden
        """
        batch_size, seq_len, _ = x.size()
        # Initialize hidden state
        if hidden is None:
            hidden = torch.zeros(batch_size, self.hidden_size)
        outputs = []
        # CRITICAL: Sequential processing - can't parallelize!
        for t in range(seq_len):
            x_t = x[:, t, :]  # Get token at time t
            # Compute new hidden state
            hidden = torch.tanh(self.i2h(x_t) + self.h2h(hidden))
            # Compute output
            output = self.h2o(hidden)
            outputs.append(output)
        # Stack outputs
        outputs = torch.stack(outputs, dim=1)
        return outputs, hidden

# Example usage
rnn = SimpleRNN(input_size=512, hidden_size=256, output_size=10000)

# Input: batch of 32 sequences, each 100 tokens long
x = torch.randn(32, 100, 512)

# Forward pass: MUST process sequentially
outputs, final_hidden = rnn(x)
print(f"Outputs shape: {outputs.shape}")      # (32, 100, 10000)
print(f"Final hidden: {final_hidden.shape}")  # (32, 256)
⚠️ Notice the Sequential Loop: The `for t in range(seq_len)` loop is unavoidable. Even though we batch multiple sequences (batch_size=32), each sequence must be processed step-by-step. This is the core bottleneck.
Vanishing gradients are the most famous and fundamental RNN problem: when backpropagating through many time steps, gradients shrink exponentially, making it nearly impossible to learn long-range dependencies.
To understand vanishing gradients, we need to see how gradients flow backward through time:
RNN Gradient Flow (Backpropagation Through Time)
Forward Pass:
h_t = tanh(W_hh × h_{t-1} + W_xh × x_t)
Backward Pass (computing ∂L/∂h_0):
∂L/∂h_0 = ∂L/∂h_T × ∂h_T/∂h_{T-1} × ∂h_{T-1}/∂h_{T-2} × ... × ∂h_1/∂h_0
Each term ∂h_t/∂h_{t-1} involves:
∂h_t/∂h_{t-1} = W_hh × tanh'(z_t)
where tanh'(z) ∈ (0, 1], and well below 1 (often < 0.25) once units saturate
For sequence length T=100:
∂L/∂h_0 = ∂L/∂h_100 × (W_hh)^100 × ∏ tanh'(z_t)
If ||W_hh|| < 1 and tanh' < 0.25:
Gradient ≈ 0.9^100 × 0.25^100 ≈ 10^-65 (vanishes!)
If ||W_hh|| > 1 and activations don't saturate:
Gradient ≈ 1.1^100 ≈ 10^4 (explodes!)
Key Insight: The gradient is a product of many terms. If most terms are less than 1, the product shrinks exponentially. This is why RNNs struggle to learn dependencies more than 10-20 steps apart.
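To see this shrinking product concretely, here is a minimal sketch (separate from the full experiment shown later) that repeatedly multiplies a gradient vector by a Jacobian-like factor W_hh^T · diag(tanh'(z)). The weight scale and the saturation level (tanh' ≤ 0.25) are illustrative assumptions:

import torch

torch.manual_seed(0)
hidden_size = 64
W_hh = torch.randn(hidden_size, hidden_size) * 0.05  # small recurrent weights (spectral norm roughly 0.8 at this scale)
grad = torch.ones(hidden_size)                        # stand-in for dL/dh_T at the last time step

for steps_back in range(1, 101):
    tanh_prime = torch.rand(hidden_size) * 0.25       # assume saturated units: tanh'(z) <= 0.25
    grad = W_hh.t() @ (tanh_prime * grad)             # one backward step through time
    if steps_back in (1, 10, 25, 50, 100):
        print(f"{steps_back:3d} steps back | gradient norm: {grad.norm().item():.3e}")

The norm collapses toward zero within a few dozen steps, which is the vanishing-gradient effect described above.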
Consider this movie review:
"The film started with an incredibly boring first hour with terrible pacing, wooden acting, and a confusing plot. However, the second half completely redeemed it with spectacular action sequences, emotional depth, and a satisfying conclusion. Overall, I loved it!"
Sentiment: POSITIVE (but you need the word "However" in the middle of the review to understand this!)
An RNN reading this left-to-right builds up a strongly negative representation over the first clause. The pivot word "However" and the positive second half must then overwrite that state step by step, and learning this long-range interaction between the opening and the conclusion is exactly what vanishing gradients make difficult.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class RNNWithGradientTracking(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        output, hidden = self.rnn(x)
        return self.fc(output[:, -1, :])  # Last time step

# Experiment: Track gradient norms
model = RNNWithGradientTracking(input_size=10, hidden_size=50)
optimizer = torch.optim.Adam(model.parameters())

sequence_lengths = [10, 25, 50, 100, 200]
gradient_norms = []

for seq_len in sequence_lengths:
    # Generate random sequence
    x = torch.randn(16, seq_len, 10)
    target = torch.randn(16, 1)

    # Forward + backward
    optimizer.zero_grad()
    output = model(x)
    loss = nn.MSELoss()(output, target)
    loss.backward()

    # Compute gradient norm
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    gradient_norms.append(total_norm)

    print(f"Seq Length: {seq_len:3d} | Gradient Norm: {total_norm:.6f}")

# Output (typical):
# Seq Length:  10 | Gradient Norm: 2.453821
# Seq Length:  25 | Gradient Norm: 0.892341
# Seq Length:  50 | Gradient Norm: 0.123456
# Seq Length: 100 | Gradient Norm: 0.003214  ← Vanished!
# Seq Length: 200 | Gradient Norm: 0.000089  ← Nearly zero!
⚠️ Real-World Impact: This gradient decay is why pre-transformer neural machine translation systems struggled with long sentences. For sentences longer than 30-40 words, translation quality degraded noticeably because the model couldn't learn dependencies reaching back to the beginning of the sentence.
Researchers tried many approaches to address vanishing gradients:
LSTM (Long Short-Term Memory). Solution: Add a memory cell with gates (forget, input, output) that control information flow
Math:
f_t = σ(W_f · [h_{t-1}, x_t])      (forget gate)
i_t = σ(W_i · [h_{t-1}, x_t])      (input gate)
C̃_t = tanh(W_C · [h_{t-1}, x_t])   (candidate)
C_t = f_t * C_{t-1} + i_t * C̃_t    (cell state)
o_t = σ(W_o · [h_{t-1}, x_t])      (output gate)
h_t = o_t * tanh(C_t)              (hidden state)
Improvement: Can learn dependencies up to ~100 steps (vs 10-20 for vanilla RNNs)
Limitation: Still sequential, still slow, still has gradient issues at very long ranges
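As a concrete reference for the gate equations above, here is a minimal single-step LSTM cell sketch. Bias terms are omitted and the sizes are toy values; a real implementation would use nn.LSTM or nn.LSTMCell:

import torch

def lstm_cell_step(x_t, h_prev, c_prev, W_f, W_i, W_C, W_o):
    """One LSTM step following the gate equations above (biases omitted)."""
    hx = torch.cat([h_prev, x_t], dim=-1)  # [h_{t-1}, x_t]
    f_t = torch.sigmoid(hx @ W_f)          # forget gate
    i_t = torch.sigmoid(hx @ W_i)          # input gate
    c_tilde = torch.tanh(hx @ W_C)         # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde     # cell state: the additive path helps gradients survive
    o_t = torch.sigmoid(hx @ W_o)          # output gate
    h_t = o_t * torch.tanh(c_t)            # hidden state
    return h_t, c_t

# Toy usage with illustrative sizes
hidden, inp = 8, 4
W = [torch.randn(hidden + inp, hidden) * 0.1 for _ in range(4)]
h = c = torch.zeros(1, hidden)
x = torch.randn(1, inp)
h, c = lstm_cell_step(x, h, c, *W)
print(h.shape, c.shape)  # torch.Size([1, 8]) torch.Size([1, 8])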
GRU (Gated Recurrent Unit). Solution: A simplified LSTM with fewer gates (2 instead of 3)
Math:
z_t = σ(W_z · [h_{t-1}, x_t])   (update gate)
r_t = σ(W_r · [h_{t-1}, x_t])   (reset gate)
h̃_t = tanh(W · [r_t * h_{t-1}, x_t])
h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t
Improvement: Faster than LSTM, similar performance
Limitation: Doesn't fundamentally solve the sequential bottleneck
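One quick way to see the "fewer gates" difference is to compare parameter counts of PyTorch's built-in recurrent layers. This sketch assumes equal input and hidden sizes:

import torch.nn as nn

hidden = 256
for name, cls in [("RNN", nn.RNN), ("GRU", nn.GRU), ("LSTM", nn.LSTM)]:
    m = cls(input_size=hidden, hidden_size=hidden, batch_first=True)
    n_params = sum(p.numel() for p in m.parameters())
    print(f"{name:4s}: {n_params:,} parameters")

# Vanilla RNN has 1 set of recurrent weights, GRU has 3 (update, reset, candidate),
# and LSTM has 4 (input, forget, candidate, output), so counts scale roughly 1 : 3 : 4.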
Gradient Clipping. Solution: Rescale gradients whenever their norm exceeds a threshold
Code:
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0
)
Improvement: Prevents exploding gradients
Limitation: Doesn't help with vanishing gradients at all!
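For context, here is a minimal sketch of where the clipping call sits in a training step: after backward() (so gradients exist) and before step() (so the update uses the clipped values). The toy model and placeholder loss are illustrative only:

import torch
import torch.nn as nn

model = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 50, 8)              # toy batch: 4 sequences of 50 steps
output, _ = model(x)
loss = output.pow(2).mean()            # placeholder loss just to produce gradients

optimizer.zero_grad()
loss.backward()                        # 1. compute gradients
torch.nn.utils.clip_grad_norm_(        # 2. rescale them if their global norm exceeds 1.0
    model.parameters(), max_norm=1.0)
optimizer.step()                       # 3. update with the clipped gradients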
Identity Initialization. Solution: Initialize W_hh as the identity matrix
Theory: If W_hh = I, then ||W_hh|| = 1, avoiding exponential decay
Improvement: Helps slightly with gradient flow
Limitation: Doesn't address fundamental sequential architecture
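A minimal sketch of what this looks like in PyTorch, using nn.init.eye_ on the recurrent weight matrix of a vanilla RNN:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=128, hidden_size=128, batch_first=True)
with torch.no_grad():
    nn.init.eye_(rnn.weight_hh_l0)     # W_hh = I, so ||W_hh|| = 1 at the start of training
print(rnn.weight_hh_l0[:3, :3])        # top-left corner shows the identity pattern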
Bottom Line: These techniques improved RNNs significantly, but they couldn't escape the fundamental limitation: the architecture still processes tokens one at a time.
RNNs are inherently sequential. To process the 100th token, you must first process tokens 1-99. This sequential dependency creates a fundamental bottleneck that no amount of hardware can solve.
RNN Forward Pass (MUST be sequential):
Time step 1:   h_0  → process token 1   → h_1
Time step 2:   h_1  → process token 2   → h_2   ← depends on h_1
Time step 3:   h_2  → process token 3   → h_3   ← depends on h_2
...
Time step 100: h_99 → process token 100 → h_100 ← depends on h_99
⏱️ Total time: O(sequence_length) - CANNOT parallelize!
Even with 1000 GPUs, you can't speed up processing a single sequence!
Compare to transformers, which process all tokens in parallel:
Transformer Forward Pass (fully parallel):
Time step 1: Process ALL tokens simultaneously
Token 1, Token 2, ..., Token 100 → ALL computed at once
⏱️ Total time: O(1) sequential steps per layer (in terms of sequence length)
Speedup: up to 100× fewer sequential steps for seq_len=100!
With multiple GPUs, you can parallelize both batch and sequence dimensions!
Let's measure actual wall-clock time for processing sequences of different lengths:
import torch
import torch.nn as nn
import time

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
hidden_size = 512
vocab_size = 10000

class BenchmarkRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        emb = self.embedding(x)
        output, _ = self.rnn(emb)  # Sequential processing!
        return self.fc(output)

class BenchmarkTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_encoding = nn.Parameter(torch.randn(1, 5000, hidden_size))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=8,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        emb = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        output = self.transformer(emb)  # Parallel processing!
        return self.fc(output)

# Benchmark function
def benchmark(model, sequence_lengths):
    model.to(device).eval()
    results = []
    for seq_len in sequence_lengths:
        # Generate input
        x = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)

        # Warmup
        with torch.no_grad():
            for _ in range(10):
                _ = model(x)

        # Measure time
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            for _ in range(100):
                _ = model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = (time.time() - start) / 100

        results.append((seq_len, elapsed))
        print(f"Seq Length: {seq_len:4d} | Time: {elapsed*1000:.2f}ms")
    return results

# Run benchmarks
print("RNN Performance:")
rnn_results = benchmark(BenchmarkRNN(), [10, 50, 100, 200, 500, 1000])

print("\nTransformer Performance:")
transformer_results = benchmark(BenchmarkTransformer(), [10, 50, 100, 200, 500, 1000])
# Typical results (on GPU):
# RNN Performance:
# Seq Length:   10 | Time:   2.34ms
# Seq Length:   50 | Time:   8.92ms   ← scales linearly
# Seq Length:  100 | Time:  17.56ms   ← scales linearly
# Seq Length:  200 | Time:  34.78ms   ← scales linearly
# Seq Length:  500 | Time:  86.34ms   ← scales linearly
# Seq Length: 1000 | Time: 172.89ms   ← scales linearly
#
# Transformer Performance:
# Seq Length:   10 | Time:   3.12ms
# Seq Length:   50 | Time:   4.23ms   ← grows slowly (quadratic in theory)
# Seq Length:  100 | Time:   5.67ms   ← but much faster in practice
# Seq Length:  200 | Time:   9.34ms   ← 3.7x faster than RNN!
# Seq Length:  500 | Time:  28.45ms   ← 3.0x faster than RNN!
# Seq Length: 1000 | Time:  89.23ms   ← 1.9x faster than RNN!
💡 Key Observations: RNN latency grows roughly linearly with sequence length because it must execute one step per token, while the transformer's wall-clock time grows much more slowly at these lengths (despite attention's quadratic theoretical cost) because every position is processed in parallel on the GPU.
The sequential bottleneck affects training even more severely:
Scenario: Train on 1 million sequences, each 512 tokens long.
RNN (LSTM): every sequence requires 512 dependent time steps, so the GPU spends much of its time waiting on the previous step rather than doing useful parallel work.
Transformer: each sequence is handled in one parallel pass per layer, keeping the GPU's compute units busy (a rough count of dependent steps is sketched below).
🚀 Speedup: 3.4× faster training with transformers!
And this is conservative - modern transformers with optimizations (Flash Attention, etc.) can be 10× faster!
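A rough, illustrative count of dependent (non-parallelizable) steps per epoch makes the gap plausible. The batch size here is a hypothetical choice, and the count ignores layer depth and the per-step cost:

# Back-of-envelope: dependent steps per training epoch
num_sequences = 1_000_000
seq_len = 512
batch_size = 64                              # hypothetical batch size
batches = num_sequences // batch_size        # 15,625 optimizer steps

rnn_dependent_steps = batches * seq_len      # one dependent step per token per batch
transformer_dependent_steps = batches * 1    # each batch is one parallel pass (per layer)

print(f"RNN        : {rnn_dependent_steps:,} dependent steps")        # 8,000,000
print(f"Transformer: {transformer_dependent_steps:,} dependent steps") # 15,625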
Sequential processing also limits memory efficiency:
Must store hidden states for ALL time steps during training (for BPTT):
Memory = batch × seq_len × hidden_dim
+ gradients for all time steps
For batch=32, seq=512, hidden=1024:
Memory ≈ 32 × 512 × 1024 × 4 bytes
≈ 67 MB (just for hidden states!)
Plus gradients: ~134 MB total
Can't increase batch size much!
Stores attention matrices, but enables gradient checkpointing:
Memory = batch × seq_len × hidden_dim
+ attention (can checkpoint)
For batch=32, seq=512, hidden=1024:
Memory ≈ 32 × 512 × 1024 × 4 bytes
≈ 67 MB (similar base)
But can use larger batches due to
better memory utilization!
Typical batch size: 2-4× larger!
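A quick back-of-envelope script for the numbers above (fp32, 4 bytes per value; the 8-head, single-layer attention figure is an illustrative assumption):

# Memory estimate for the configuration used above
batch, seq_len, hidden = 32, 512, 1024

rnn_hidden_states = batch * seq_len * hidden * 4        # one stored hidden vector per time step
attention_matrices = batch * 8 * seq_len * seq_len * 4  # assuming 8 heads in one layer

print(f"RNN hidden states  : {rnn_hidden_states / 1e6:.0f} MB")    # ~67 MB
print(f"Attention (1 layer): {attention_matrices / 1e6:.0f} MB")   # ~268 MB, but checkpointable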
⚠️ The Cost of Sequence Length: With RNNs, increasing sequence length increases both training time (one extra dependent step per extra token) and memory (one extra stored hidden state per extra token for BPTT), so long-context RNN models quickly become impractical.
This is why transformers won: They broke the sequential constraint, enabling truly scalable training on modern parallel hardware.
In RNNs, to connect token 1 and token 100, information must flow through 99 hidden states. Each transition is a bottleneck where information can be lost:
Token 1 → Hidden State 1
Information compressed into fixed-size vector
Hidden State 1 → Hidden State 2
Information compressed again, some details lost
Hidden State 2 → Hidden State 3
More compression, more loss...
By Hidden State 100: Original token 1 info is nearly gone!
Transformers solve this with direct connections: every token can directly attend to every other token.
RNNs have a hidden state of fixed size (e.g., 512-dimensional vector). All information about the past must fit in this bottleneck:
RNN: Process 1000 tokens, compress them into a single 512-dimensional hidden state.
Information loss is inevitable. What gets kept, and what gets forgotten?
Transformer: Keep ALL token embeddings (1000 × 768 dimensions).
The attention mechanism then selectively reads whichever tokens are relevant.
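The difference shows up directly in tensor shapes. A minimal sketch with illustrative sizes (1000 tokens, a 512-dimensional RNN state, 768-dimensional transformer embeddings):

import torch
import torch.nn as nn

tokens = torch.randn(1, 1000, 768)                      # 1000 token embeddings

rnn = nn.RNN(input_size=768, hidden_size=512, batch_first=True)
_, h_final = rnn(tokens)
print(h_final.shape)   # torch.Size([1, 1, 512]) -> everything squeezed into 512 numbers

layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
out = layer(tokens)
print(out.shape)       # torch.Size([1, 1000, 768]) -> every token keeps its own vector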
Because RNNs are sequential, you can't parallelize across time steps during training. You can batch multiple sequences, but within a sequence, you're stuck:
# RNN: Sequential even in training
hidden = torch.zeros(batch_size, hidden_dim)
for t in range(sequence_length):
    # Must wait for previous time step
    hidden = rnn_cell(x[:, t], hidden)
    outputs[:, t] = hidden

# Total time: O(sequence_length)
# Even with parallelization, bottleneck remains
Transformers process all tokens at once, enabling massive parallelization. This is why they can train on modern GPUs efficiently.
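A minimal sketch of the core idea: scaled dot-product attention computed for all tokens at once with plain matrix multiplications (illustrative sizes, a single head, and no masking or output projection):

import math
import torch

batch_size, seq_len, d_model = 32, 100, 64
x = torch.randn(batch_size, seq_len, d_model)

W_q = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_k = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_v = torch.randn(d_model, d_model) / math.sqrt(d_model)

Q, K, V = x @ W_q, x @ W_k, x @ W_v                     # computed for every token at once
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_model)   # (batch, seq_len, seq_len)
attn = torch.softmax(scores, dim=-1)                    # each token attends to every token
out = attn @ V                                          # (batch, seq_len, d_model)

# A constant number of matmuls regardless of seq_len - no loop over time steps
print(out.shape)  # torch.Size([32, 100, 64])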
All RNN problems stem from one core issue:
RNNs force a SEQUENTIAL bottleneck on an inherently PARALLEL problem
You want to model relationships between tokens, but forcing sequential processing imposes costs in both efficiency (slow, unparallelizable training) and expressiveness (long-range information gets squeezed through a fixed-size state).
RNNs mix information sequentially, which limits their ability to capture complex, multi-way interactions between tokens. This becomes critical for tasks requiring global context understanding.
Example: Pronoun Resolution
"Alice told Bob that she would help him with his project because she had experience."
RNN Processing: by the time the model reads "she" and "him", the representations of "Alice" and "Bob" have already been compressed through several hidden-state updates, so resolving each pronoun depends on how much of that information survived.
Transformer Processing: "she" and "him" can attend directly to "Alice" and "Bob" in a single step, regardless of how far apart the words are.
def measure_context_utilization():
    """
    Measure how effectively RNNs vs Transformers use context.
    Task: Predict masked word given full sentence context.
    """
    # Example sentences with varying dependency distances
    sentences = [
        "The cat sat on the [MASK]",  # Short dependency
        "The old cat that belonged to my grandmother sat on the [MASK]",  # Medium
        "The old gray cat that had belonged to my late grandmother and now lives with me sat on the [MASK]",  # Long
    ]

    # Simulate RNN context representation
    def rnn_context(words, mask_position):
        """RNN only has the final hidden state: information from earlier words decays"""
        relevance = [0.95 ** (mask_position - i) for i in range(len(words))]
        return relevance

    # Simulate Transformer context representation
    def transformer_context(words, mask_position):
        """Transformer can attend to all words equally"""
        relevance = [1.0] * len(words)
        return relevance

    for sent in sentences:
        words = sent.split()
        mask_pos = words.index("[MASK]")
        print(f"\nSentence: {sent}")
        print(f"Mask position: {mask_pos}")

        rnn_rel = rnn_context(words, mask_pos)
        trans_rel = transformer_context(words, mask_pos)

        # The critical word is "cat" - find its position in each sentence
        cat_pos = words.index("cat")
        print(f"RNN access to 'cat': {rnn_rel[cat_pos]:.3f}")
        print(f"Transformer access to 'cat': {trans_rel[cat_pos]:.3f}")

# Output:
# Sentence: The cat sat on the [MASK]
# Mask position: 5
# RNN access to 'cat': 0.815 (good - short distance)
# Transformer access to 'cat': 1.000
#
# Sentence: The old cat that belonged to my grandmother sat on the [MASK]
# Mask position: 11
# RNN access to 'cat': 0.630 (degraded - medium distance)
# Transformer access to 'cat': 1.000
#
# Sentence: The old gray cat ... sat on the [MASK]
# Mask position: 19
# RNN access to 'cat': 0.440 (heavily degraded - long distance)
# Transformer access to 'cat': 1.000

measure_context_utilization()
The shift from RNNs to transformers wasn't just theoretical - it produced massive empirical improvements across all NLP benchmarks.
| Model | BLEU Score | Training Time | Parameters |
|---|---|---|---|
| RNN Seq2Seq (2014) | 20.5 | ~14 days | ~100M |
| LSTM + Attention (2015) | 25.9 | ~10 days | ~150M |
| Transformer (2017) | 28.4 | ~3.5 days | ~213M |
| Improvement: +2.5 BLEU points (10% better), 3× faster training | |||
Best RNN: ~60 perplexity
Best Transformer: ~22 perplexity (roughly 2.7× lower!)
But the real story wasn't just better performance at the same scale - transformers enabled scaling that was impossible with RNNs:
RNN Era (2014-2017): state-of-the-art models topped out around ~200M parameters with usable context of roughly 100-200 tokens, because sequential training made anything bigger impractical.
Transformer Era (2017-present): models grew to hundreds of billions of parameters and beyond (1.7T+), trained on vastly larger datasets with context windows of thousands of tokens.
🎯 Key Insight: Transformers didn't just improve performance by 10-20%. They enabled a 10,000× scale-up in model size and training data. This fundamentally changed what was possible in AI.
Transformers solved RNN problems by asking a radical question: What if we don't use recurrence at all? Instead, use attention to let every token directly access every other token.
All tokens processed simultaneously
Impact: up to 100× faster training for long sequences
Attention creates paths between any two tokens
Impact: Learn dependencies 1000+ tokens apart
Between any two tokens, the gradient crosses a single attention step (not 100 recurrent steps)
Impact: Gradient quality independent of distance
Parallelization enables massive scale
Impact: Models scaled from 200M to 1.7T parameters
| Property | RNN/LSTM | Transformer |
|---|---|---|
| Processing | Sequential | Parallel |
| Path Length | O(n) | O(1) |
| Gradient Flow | Vanishes | Stable |
| Max Context | ~100-200 tokens | 4K-128K tokens |
| Training Speed | Slow (O(n) steps) | Fast (O(1) steps) |
| Scalability | Limited (~200M params) | Massive (1.7T+ params) |
| GPU Utilization | 20-30% | 80-95% |
Q1: What is the main limitation of RNNs for long sequences?
Q2: Why can't RNNs be easily parallelized?
Q3: What is the "information bottleneck" problem in RNNs?
Q4: How do Transformers address the sequential processing limitation?
Q5: What was the key insight in the "Attention Is All You Need" paper?