Master the complete encoder-decoder architecture and all building blocks
You've mastered the individual components: self-attention, multi-head attention, positional encoding, and feed-forward networks. Now we'll assemble these pieces into the complete transformer architecture - the original encoder-decoder design from "Attention Is All You Need" (Vaswani et al., 2017).
Input: "The cat sat on the mat" (English)
Goal: "Le chat était assis sur le tapis" (French)
Architecture: Encoder processes the input → Decoder generates the output token-by-token
The transformer architecture revolutionized NLP because:
- Unlike RNNs, all tokens are processed simultaneously
- There are direct attention paths between any two tokens
- It scales efficiently to billions of parameters
- Attention weights show what the model "looks at", aiding interpretability
The transformer consists of two main components: an encoder that processes the input sequence and a decoder that generates the output sequence. Both are stacks of identical layers.
INPUT: "The cat sat"
Token IDs: [142, 2368, 3332]
         │
         ▼
[EMBEDDING + POSITIONAL ENCODING]
         │
         ▼
[ENCODER STACK (6x)]
    Layer 1:
      • Multi-Head Attention
      • Add & Norm
      • Feed-Forward
      • Add & Norm
    Layers 2-6: same
         │
         ▼
[Contextualized Representations]   shape: [batch, seq_len, d_model]
         │
         ▼
[DECODER STACK (6x)]
    Layer 1:
      • Masked Self-Attention      (can't see future tokens)
      • Add & Norm
      • Cross-Attention            (attends to encoder output)
      • Add & Norm
      • Feed-Forward
      • Add & Norm
    Layers 2-6: same
         │
         ▼
[Linear Layer: d_model → vocab_size]
         │
         ▼
[Softmax]
         │
         ▼
OUTPUT: "Le chat était"
Probabilities: [P(Le), P(chat), P(était), ...]
| Parameter | Value | What It Means |
|---|---|---|
| d_model | 512 | Dimension of all token representations |
| num_heads | 8 | Number of attention heads (512 ÷ 8 = 64 per head) |
| d_ff | 2048 | Feed-forward hidden dimension (4× d_model) |
| num_layers | 6 | Encoder layers + decoder layers (6 each) |
| dropout | 0.1 | Regularization rate (10% of neurons dropped) |
| vocab_size | ~37K | Shared vocabulary (byte-pair encoding) |
Everything is d_model: Token embeddings, attention outputs, feed-forward outputs - all are 512-dimensional vectors. This uniform dimensionality allows residual connections to work seamlessly (you can add outputs from different layers).
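As a minimal sketch (using a plain nn.Linear as a stand-in for a real attention or feed-forward sub-layer), the residual add below only works because the sub-layer maps d_model back to d_model:

import torch
import torch.nn as nn

d_model = 512
x = torch.randn(1, 5, d_model)            # [batch, seq_len, d_model]
sublayer = nn.Linear(d_model, d_model)    # stand-in for any sub-layer output
y = x + sublayer(x)                       # residual add: shapes must match exactly
print(y.shape)                            # torch.Size([1, 5, 512])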
import torch
# Example: Translating "Hello world" to French
# (Assumes `model` is an already-trained encoder-decoder Transformer like the
#  one we build at the end of this tutorial; token IDs are illustrative.)
batch_size = 1
src_seq_len = 2   # "Hello", "world"
tgt_seq_len = 2   # "<s>", "Bonjour"  (during training)
# Step 1: Tokenize input
src_tokens = torch.tensor([[5043, 2088]])   # "Hello", "world"
# Step 2: Encoder processes input
# src_tokens → Embedding → Positional Encoding → 6 Encoder Layers
encoder_output = model.encoder(src_tokens)
print(f"Encoder output shape: {encoder_output.shape}")
# Output: torch.Size([1, 2, 512])
# Meaning: batch=1, src_len=2, d_model=512
# Step 3: Decoder generates output (autoregressive)
tgt_tokens = torch.tensor([[1, 34523]])   # "<s>" (start token), "Bonjour"
decoder_output = model.decoder(
    tgt_tokens,
    encoder_output   # Cross-attention uses this
)
print(f"Decoder output shape: {decoder_output.shape}")
# Output: torch.Size([1, 2, 37000])
# Meaning: batch=1, tgt_len=2, vocab_size=37000
# Step 4: Take argmax to get predicted tokens
predictions = decoder_output.argmax(dim=-1)
print(f"Predicted tokens: {predictions}")
# Output: tensor([[34523, 23456]]) → "Bonjour", "monde"
The encoder's job is to process the input sequence and create rich, contextualized representations where each token's embedding incorporates information from the entire sequence. It does this through stacked layers of self-attention and feed-forward networks.
Each encoder layer has two main sub-layers, each with residual connections and layer normalization:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class EncoderLayer(nn.Module):
"""
Single encoder layer with:
1. Multi-head self-attention
2. Position-wise feed-forward network
Both with residual connections and layer normalization
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
        # Sub-layer 1: Multi-head self-attention
        # (MultiHeadAttention is the module built in the earlier multi-head attention tutorial)
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        # Sub-layer 2: Feed-forward network
        # Classic architecture: d_model → 4*d_model → d_model
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),    # Expand (512 → 2048)
            nn.ReLU(),                   # Non-linearity
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),    # Contract (2048 → 512)
            nn.Dropout(dropout)
        )
# Layer normalization (applied after residual connection)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: [batch, seq_len, d_model] input tensor
mask: [batch, seq_len, seq_len] attention mask (optional)
Returns:
[batch, seq_len, d_model] processed tensor
"""
# Sub-layer 1: Multi-head self-attention
# Pattern: x = LayerNorm(x + Sublayer(x))
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output)) # Residual + Norm
# Sub-layer 2: Feed-forward
ff_output = self.feed_forward(x)
x = self.norm2(x + ff_output) # Residual + Norm
return x
# Let's trace what happens to a single token through one encoder layer:
print("=" * 70)
print("Token Flow Through One Encoder Layer")
print("=" * 70)
batch, seq_len, d_model = 1, 5, 512
x = torch.randn(batch, seq_len, d_model)
print(f"Input shape: {x.shape}") # [1, 5, 512]
# Create encoder layer
layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
# Forward pass
output = layer(x)
print(f"Output shape: {output.shape}") # [1, 5, 512] - same!
print("\nKey observation: Input and output shapes are IDENTICAL.")
print("This allows stacking many layers: output of layer N β input of layer N+1")
class TransformerEmbedding(nn.Module):
"""
Combines token embeddings with positional encoding.
"""
def __init__(self, vocab_size, d_model, max_seq_len=5000, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = self._create_positional_encoding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
self.d_model = d_model
def _create_positional_encoding(self, max_seq_len, d_model):
"""Sinusoidal positional encoding from 'Attention Is All You Need'."""
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # Even indices
pe[:, 1::2] = torch.cos(position * div_term) # Odd indices
return pe.unsqueeze(0) # [1, max_seq_len, d_model]
def forward(self, x):
"""
Args:
x: [batch, seq_len] token IDs
Returns:
[batch, seq_len, d_model] embeddings + positional encoding
"""
seq_len = x.size(1)
# Token embedding (scaled by sqrt(d_model) - from paper)
x = self.token_embedding(x) * math.sqrt(self.d_model)
# Add positional encoding
x = x + self.positional_encoding[:, :seq_len, :].to(x.device)
return self.dropout(x)
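A quick shape check before wiring this embedding block into the encoder (the token IDs below are made up for illustration):

emb = TransformerEmbedding(vocab_size=50000, d_model=512)
token_ids = torch.tensor([[142, 2368, 3332]])   # [batch=1, seq_len=3]
print(emb(token_ids).shape)                     # torch.Size([1, 3, 512])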
class Encoder(nn.Module):
"""
Complete transformer encoder: embedding + N encoder layers.
"""
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff,
dropout=0.1, max_seq_len=5000):
super().__init__()
# Embedding layer (tokens + positions)
self.embedding = TransformerEmbedding(vocab_size, d_model,
max_seq_len, dropout)
# Stack of N identical encoder layers
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.num_layers = num_layers
def forward(self, src, mask=None):
"""
Args:
src: [batch, src_len] source token IDs
mask: [batch, 1, src_len] or [batch, src_len, src_len] mask
Returns:
[batch, src_len, d_model] encoded representations
"""
# Step 1: Embed tokens and add positional encoding
x = self.embedding(src)
# Step 2: Pass through encoder layers sequentially
for layer in self.layers:
x = layer(x, mask)
return x
# Example: Building and using an encoder
print("\n" + "=" * 70)
print("Complete Encoder Example")
print("=" * 70)
encoder = Encoder(
vocab_size=50000, # English vocabulary size
d_model=512, # Representation dimension
num_layers=6, # 6 encoder layers (standard)
num_heads=8, # 8 attention heads
    d_ff=2048,           # Feed-forward dimension (4× d_model)
dropout=0.1
)
# Input: batch of English sentences
# "The cat sat" β token IDs [142, 2368, 3332]
# "Hello world" β token IDs [5043, 2088, 0] (padded)
src_tokens = torch.tensor([
[142, 2368, 3332],
[5043, 2088, 0]
])
print(f"Input tokens shape: {src_tokens.shape}") # [2, 3]
# Forward pass
encoder_output = encoder(src_tokens)
print(f"Encoder output shape: {encoder_output.shape}") # [2, 3, 512]
print("\nEach token now has a 512-dim representation incorporating")
print("information from ALL tokens in the sentence via self-attention.")
# Total parameters
total_params = sum(p.numel() for p in encoder.parameters())
print(f"\nTotal parameters: {total_params:,}")
# Roughly 44-45 million parameters for this configuration
# (about 25.6M of them in the 50,000 x 512 embedding table)
Result: Each token's final representation incorporates information from the entire input sequence through 6 layers of self-attention.
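We can sanity-check that claim directly. The sketch below reuses the encoder built above (eval() disables dropout so the comparison is deterministic): swapping a single input token changes the representations at the *other* positions too, because self-attention mixes information across the sequence.

encoder.eval()   # turn off dropout for a deterministic comparison
with torch.no_grad():
    out_a = encoder(torch.tensor([[142, 2368, 3332]]))
    out_b = encoder(torch.tensor([[142, 9999, 3332]]))   # only the middle token differs
# Positions 0 and 2 change as well, not just position 1:
print(torch.allclose(out_a[:, 0], out_b[:, 0]))   # False
print(torch.allclose(out_a[:, 2], out_b[:, 2]))   # False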
# Visualize what the encoder learns to attend to
import matplotlib.pyplot as plt
def visualize_encoder_attention(sentence, attention_weights):
"""
Show attention patterns in encoder.
Args:
sentence: List of tokens
attention_weights: [num_layers, num_heads, seq_len, seq_len]
"""
num_layers = attention_weights.shape[0]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()
for layer_idx in range(min(num_layers, 6)):
# Average across heads for this layer
layer_attn = attention_weights[layer_idx].mean(axis=0)
ax = axes[layer_idx]
im = ax.imshow(layer_attn, cmap='Blues', aspect='auto')
# Labels
ax.set_xticks(range(len(sentence)))
ax.set_yticks(range(len(sentence)))
ax.set_xticklabels(sentence, rotation=45)
ax.set_yticklabels(sentence)
ax.set_title(f'Layer {layer_idx + 1} Attention')
ax.set_xlabel('Attended to')
ax.set_ylabel('Attending from')
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
# Example output shows:
# Layer 1: Mostly attends to adjacent tokens (local patterns)
# Layer 3: Attends to syntactically related tokens (verb → subject)
# Layer 6: Attends to semantically related tokens (pronoun → antecedent)
Layer normalization is critical for training deep transformers. Without it, you'll face exploding/vanishing gradients and unstable training. Let's understand why it's essential and how it works.
import torch
import torch.nn as nn
# Simulate deep network without normalization
d_model = 512
num_layers = 12
x = torch.randn(1, 10, d_model) # [batch, seq_len, d_model]
print(f"Initial: mean={x.mean():.4f}, std={x.std():.4f}")
# Pass through multiple layers
for i in range(num_layers):
# Simulate attention + FFN
x = nn.Linear(d_model, d_model)(x) + x # Residual
if (i + 1) % 3 == 0:
print(f"Layer {i+1}: mean={x.mean():.4f}, std={x.std():.4f}")
# Typical output without normalization (exact values vary from run to run):
# Initial:  mean ≈ 0.00, std ≈ 1.0
# Layer 3:  std ≈ 1.5   ← growing
# Layer 6:  std ≈ 2.4   ← still growing
# Layer 9:  std ≈ 3.6
# Layer 12: std ≈ 5.6   ← the drift compounds with depth; deeper stacks blow up and training fails
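For contrast, here is the same toy stack with a LayerNorm applied after every residual add (a short sketch reusing d_model and num_layers from above). The activation scale now stays roughly constant no matter how many layers we stack.

norm = nn.LayerNorm(d_model)
x = torch.randn(1, 10, d_model)
for i in range(num_layers):
    x = norm(nn.Linear(d_model, d_model)(x) + x)   # residual + normalization
print(f"After {num_layers} normalized layers: mean={x.mean():.4f}, std={x.std():.4f}")
# std stays close to 1.0 instead of drifting upward layer after layer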
class LayerNorm(nn.Module):
"""
Layer normalization: normalize across the feature dimension.
"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
# Learnable parameters (initialized to 1 and 0)
self.gamma = nn.Parameter(torch.ones(d_model)) # Scale
self.beta = nn.Parameter(torch.zeros(d_model)) # Shift
self.eps = eps # Small constant for numerical stability
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model] normalized tensor
"""
        # Compute mean and std across the d_model dimension (dim=-1)
        mean = x.mean(dim=-1, keepdim=True)                    # [batch, seq_len, 1]
        std = x.std(dim=-1, keepdim=True, unbiased=False)      # [batch, seq_len, 1]
        # Normalize: zero mean, unit variance
        # (unbiased=False uses the population std, matching nn.LayerNorm)
        x_norm = (x - mean) / (std + self.eps)
# Apply learnable affine transformation
# gamma and beta are learned during training
output = self.gamma * x_norm + self.beta
return output
# Demonstration
print("=" * 70)
print("Layer Normalization in Action")
print("=" * 70)
batch, seq_len, d_model = 2, 5, 512
x = torch.randn(batch, seq_len, d_model) * 10 # Deliberately large values
print(f"Before LayerNorm:")
print(f" Shape: {x.shape}")
print(f" Token 0 stats: mean={x[0,0].mean():.4f}, std={x[0,0].std():.4f}")
print(f" Token 1 stats: mean={x[0,1].mean():.4f}, std={x[0,1].std():.4f}")
ln = LayerNorm(d_model)
x_norm = ln(x)
print(f"\nAfter LayerNorm:")
print(f" Shape: {x_norm.shape}") # Same shape
print(f" Token 0 stats: mean={x_norm[0,0].mean():.4f}, std={x_norm[0,0].std():.4f}")
print(f" Token 1 stats: mean={x_norm[0,1].mean():.4f}, std={x_norm[0,1].std():.4f}")
# Output:
# Before LayerNorm:
# Token 0 stats: mean=0.5234, std=9.8765
# Token 1 stats: mean=-1.2345, std=12.3456
#
# After LayerNorm:
#   Token 0 stats: mean ≈ 0.0000, std ≈ 1.0000
#   Token 1 stats: mean ≈ 0.0000, std ≈ 1.0000
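As a sanity check (a sketch, assuming the population std used above), the from-scratch implementation should agree with PyTorch's built-in nn.LayerNorm up to tiny numerical differences in how eps is applied:

builtin_ln = nn.LayerNorm(d_model)
print(torch.allclose(ln(x), builtin_ln(x), atol=1e-4))   # True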
| Aspect | Batch Norm | Layer Norm |
|---|---|---|
| Normalize Over | Batch dimension (across examples) | Feature dimension (across d_model) |
| Statistics | Mean/std computed per feature across batch | Mean/std computed per example across features |
| Batch Size Dependent? | ✗ Yes (requires large batches) | ✓ No (works with batch_size=1) |
| Inference | Uses running statistics from training | Computes stats on-the-fly |
| Sequential Models | ✗ Breaks RNNs/autoregressive generation | ✓ Perfect for transformers |
| Typical Use | CNNs (computer vision) | Transformers (NLP) |
# Batch Norm: normalize across the batch (vertically)
# Tensor shape: [batch=4, seq_len=3, d_model=512]
#
#               Feature 0    Feature 1    ...   Feature 511
# Example 0:     x[0,0]       x[0,1]      ...    x[0,511]
# Example 1:     x[1,0]       x[1,1]      ...    x[1,511]
# Example 2:     x[2,0]       x[2,1]      ...    x[2,511]
# Example 3:     x[3,0]       x[3,1]      ...    x[3,511]
#                   ↓            ↓
#             normalize      normalize    ...
#             across the     across the
#             examples       examples
#             (batch dim)    (batch dim)

# Layer Norm: normalize across the features (horizontally)
# Tensor shape: [batch=4, seq_len=3, d_model=512]
#
#               Feature 0    Feature 1    ...   Feature 511
# Example 0:     x[0,0]   →   x[0,1]    → ... →  x[0,511]
#               └──────── normalize across all 512 features ────────┘
#
# Example 1:     x[1,0]   →   x[1,1]    → ... →  x[1,511]
#               └──────── normalized independently of Example 0 ────┘
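The same point in runnable form (a small sketch; nn.BatchNorm1d in training mode stands in for batch normalization over a batch of token vectors): LayerNorm's output for a given example does not depend on the rest of the batch, while BatchNorm's does.

torch.manual_seed(0)
feat = 8
ln_demo = nn.LayerNorm(feat)
bn_demo = nn.BatchNorm1d(feat)           # training mode: uses batch statistics

batch = torch.randn(4, feat)             # a batch of 4 "tokens"
pair = batch[:2]                         # the same first two tokens, on their own

print(torch.allclose(ln_demo(batch)[0], ln_demo(pair)[0]))   # True  (per-example stats)
print(torch.allclose(bn_demo(batch)[0], bn_demo(pair)[0]))   # False (stats depend on batchmates)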
def encoder_layer_with_norms(x, attention, ffn, norm1, norm2):
    """
    Original transformer style: Post-Norm (normalize AFTER the residual add).
    """
# Sub-layer 1: Attention
attn_out = attention(x)
x = norm1(x + attn_out) # Residual + Norm
# Sub-layer 2: Feed-forward
ffn_out = ffn(x)
x = norm2(x + ffn_out) # Residual + Norm
return x
# Note: Original paper used Post-Norm
# Modern variants often use Pre-Norm (norm BEFORE sub-layer)
# Pre-Norm is easier to train but slightly less expressive
def encoder_layer_pre_norm(x, attention, ffn, norm1, norm2):
"""
Pre-norm variant (easier to train, used in GPT-2/3).
"""
# Sub-layer 1: Attention
attn_out = attention(norm1(x)) # Norm BEFORE attention
x = x + attn_out # Residual
# Sub-layer 2: Feed-forward
ffn_out = ffn(norm2(x)) # Norm BEFORE FFN
x = x + ffn_out # Residual
return x
Residual connections (also called skip connections) are the second critical ingredient that enables training very deep transformers. They create a "gradient superhighway" that allows information and gradients to flow directly through many layers.
import torch
import torch.nn as nn
# Simulate gradient flow through deep network WITHOUT residuals
def gradient_flow_no_residual(num_layers=12):
"""
Show how gradients vanish in deep networks without residuals.
"""
# Start with gradient = 1.0 at output
grad = 1.0
print("Gradient magnitude flowing backward through layers:")
print("=" * 60)
for layer in range(num_layers, 0, -1):
# Each layer's backward pass multiplies gradient by ~0.8
# (This is typical for networks with sigmoid/tanh activations)
grad *= 0.8
if layer % 3 == 0:
print(f"Layer {layer}: gradient = {grad:.6f}")
print(f"\nFinal gradient at input layer: {grad:.10f}")
print(f"Gradient has vanished! ({grad:.2e})")
gradient_flow_no_residual()
# Output:
# Layer 12: gradient = 0.800000
# Layer 9:  gradient = 0.409600
# Layer 6:  gradient = 0.209715
# Layer 3:  gradient = 0.107374
#
# Final gradient at input layer: 0.0687194767
# Gradient has vanished! (6.87e-02)
#
# With only 12 layers the gradient is already ~15x smaller; at 50 layers
# (0.8^50 ≈ 1.4e-05) learning in the early layers becomes impossibly slow.
In deep networks without residual connections, the gradient shrinks multiplicatively at every layer, so the earliest layers receive almost no learning signal and effectively stop training.
class ResidualBlock(nn.Module):
"""
Basic residual block: output = input + F(input)
"""
def __init__(self, d_model):
super().__init__()
self.layer = nn.Linear(d_model, d_model)
    def forward(self, x):
        # Key insight: ADD the input to the transformed output
        return x + self.layer(x)
        #      ↑        └── learned transformation
        #      └── identity shortcut (skip connection)
# Without residual
class NormalBlock(nn.Module):
def __init__(self, d_model):
super().__init__()
self.layer = nn.Linear(d_model, d_model)
def forward(self, x):
return self.layer(x) # Just the transformation
# Compare gradient flow
print("\n" + "=" * 70)
print("Gradient Flow: With vs. Without Residuals")
print("=" * 70)
d_model = 512
x = torch.randn(1, 10, d_model, requires_grad=True)
# Without residual
normal_block = NormalBlock(d_model)
out_normal = normal_block(x)
loss_normal = out_normal.sum()
loss_normal.backward()
print(f"Without residual: βloss/βx max gradient = {x.grad.abs().max():.6f}")
# With residual
x.grad = None # Reset gradient
residual_block = ResidualBlock(d_model)
out_residual = residual_block(x)
loss_residual = out_residual.sum()
loss_residual.backward()
print(f"With residual: βloss/βx max gradient = {x.grad.abs().max():.6f}")
# Output:
# Without residual: βloss/βx max gradient = 0.134567
# With residual: βloss/βx max gradient = 1.234567 β Much larger!
Forward pass:

    y = x + F(x)

Here F is the learned transformation (attention, FFN, etc.), and the bare x is the identity connection: the gradient superhighway.

Backward pass (chain rule):

    ∂loss/∂x = ∂loss/∂y · ∂(x + F(x))/∂x = ∂loss/∂y · (1 + ∂F/∂x)

The ∂F/∂x term is the learned gradient (which might be small); the 1 is the direct gradient through the identity path.

Key insight: Even if ∂F/∂x ≈ 0 (vanishing), the "+1" ensures gradients always flow!
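A tiny numerical check of that "+1" (a sketch, with a deliberately weak transformation standing in for a sub-layer whose own gradient has nearly vanished):

x = torch.tensor(2.0, requires_grad=True)
f_weak = lambda t: 1e-6 * torch.tanh(t)   # gradient of f_weak is ~1e-6, essentially vanished

y = x + f_weak(x)                          # residual form: y = x + F(x)
y.backward()
print(x.grad)                              # ≈ 1.0: the identity path keeps the gradient alive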
class EncoderLayerWithResiduals(nn.Module):
"""
Encoder layer showing explicit residual connections.
"""
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
    def forward(self, x):
        # Residual #1: around attention
        residual1 = x                       # Save the input
        x, _ = self.attention(x, x, x)      # Multi-head self-attention
        x = self.norm1(residual1 + x)       # Add residual, then normalize
        #              ↑           ↑
        #              │           └── attention output
        #              └── original input (skip connection)
        # Residual #2: around feed-forward
        residual2 = x                       # Save the input
        x = self.ffn(x)
        x = self.norm2(residual2 + x)       # Add residual, then normalize
        #              ↑           ↑
        #              │           └── FFN output
        #              └── output from the previous sub-layer
        return x
# Visualize information flow
print("\n" + "=" * 70)
print("Information Flow with Residuals")
print("=" * 70)
x = torch.randn(1, 5, 512)
print(f"Input shape: {x.shape}")
# Track what happens
layer = EncoderLayerWithResiduals(512, 8, 2048)
# The input has TWO paths through the layer:
#   Path 1 (Direct):  x → (skip around attention) → (skip around FFN) → output
#   Path 2 (Learned): x → attention → FFN → output
#
# Both paths are summed together!
output = layer(x)
print(f"Output shape: {output.shape}") # Same as input
print("\nKey observation:")
print("- Input can flow DIRECTLY to output (via residual connections)")
print("- OR transform through attention and FFN")
print("- Final output = combination of both paths")
print("- Gradients can flow backward through either path!")
def compare_deep_networks(num_layers=50):
"""
Show that residual networks can be trained much deeper.
"""
d_model = 512
print(f"\n{'='*70}")
print(f"Training {num_layers}-Layer Networks")
print(f"{'='*70}")
# Simulate gradient flow backward through many layers
# WITHOUT residuals
grad_no_residual = 1.0
for _ in range(num_layers):
grad_no_residual *= 0.9 # Each layer multiplies by 0.9
print(f"\nWithout residuals:")
print(f" Final gradient: {grad_no_residual:.15f}")
print(f" Scientific notation: {grad_no_residual:.2e}")
print(f" Gradient has vanished! Too small to train.")
# WITH residuals
# Residual path: gradient = 1 (direct path)
# Learned path: gradient multiplies
grad_residual = 1.0 # Always 1 from residual path!
grad_learned = 1.0
for _ in range(num_layers):
grad_learned *= 0.9
grad_total = grad_residual + grad_learned
print(f"\nWith residuals:")
print(f" Gradient via residual path: {grad_residual:.2f} (always 1.0!)")
print(f" Gradient via learned path: {grad_learned:.2e}")
print(f" Total gradient: {grad_total:.4f}")
print(f" β
Gradient is strong! Network can train!")
print(f"\n{'='*70}")
print(f"Residuals enable {num_layers}Γ deeper networks!")
print(f"{'='*70}")
compare_deep_networks(num_layers=50)
# Output shows:
# Without residuals: gradient ≈ 5.15e-03 (vanished!)
# With residuals:    gradient ≈ 1.0052 (strong!)
class ModernTransformerLayer(nn.Module):
"""
Modern best practice: Residual + LayerNorm together.
"""
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
    def forward(self, x):
        # Sub-layer 1: attention, with residual add then normalization
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        #              ↑   └── transformation
        #              └── residual
        # Sub-layer 2: feed-forward, with residual add then normalization
        x = self.norm2(x + self.ffn(x))
        #              ↑   └── transformation
        #              └── residual
        return x
# Why this works so well:
# 1. Residual ensures gradients flow
# 2. LayerNorm keeps activations stable
# 3. Together: can train 100+ layers reliably
# 4. Used in ALL modern transformers (BERT, GPT, T5, etc.)
After attention (which enables tokens to communicate), each token passes through a position-wise feed-forward network. This is where transformers store learned knowledge, facts, and complex transformations.
import torch.nn as nn
class FeedForward(nn.Module):
"""
Position-wise feed-forward network (applied independently to each token).
    Classic architecture: d_model → 4*d_model → d_model
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),    # Expand: 512 → 2048
            nn.ReLU(),                   # Non-linearity
            nn.Dropout(dropout),         # Regularization
            nn.Linear(d_ff, d_model),    # Contract: 2048 → 512
            nn.Dropout(dropout)
        )
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model] (same shape!)
"""
return self.net(x)
# Example usage
print("=" * 70)
print("Feed-Forward Network Example")
print("=" * 70)
batch, seq_len, d_model = 2, 5, 512
d_ff = 2048   # 4× expansion
x = torch.randn(batch, seq_len, d_model)
print(f"Input shape: {x.shape}")
ffn = FeedForward(d_model, d_ff)
output = ffn(x)
print(f"Output shape: {output.shape}") # Same as input!
# Count parameters
params = sum(p.numel() for p in ffn.parameters())
print(f"\nFFN Parameters: {params:,}")
# First layer: 512 Γ 2048 + 2048 bias = 1,050,624
# Second layer: 2048 Γ 512 + 512 bias = 1,049,088
# Total: ~2.1 million parameters
# Most of the transformer's parameters live in the FFN layers!
1. Expand (Linear, 512 → 2048): creates a high-dimensional space. More dimensions = more capacity to learn complex transformations.
2. ReLU: the non-linearity is crucial. Without it, the whole FFN collapses to a single linear transformation.
3. Contract (Linear, 2048 → 512): projects back to d_model so the output can be added via the residual connection to the input.
# CRITICAL: FFN processes each token INDEPENDENTLY
# No information is exchanged between tokens (unlike attention)
x = torch.randn(1, 3, 512) # 3 tokens
# Token 0: [0, 0.5, -0.2, ...] 512 values
# Token 1: [1.2, -0.8, 0.4, ...] 512 values
# Token 2: [0.3, 0.1, -1.0, ...] 512 values
ffn = FeedForward(512, 2048)
ffn.eval()   # disable dropout so the comparison below is deterministic
# Process all tokens at once
output = ffn(x)
# Verify: processing tokens individually gives same result
output_token0 = ffn(x[:, 0:1, :]) # Just token 0
output_token1 = ffn(x[:, 1:2, :]) # Just token 1
output_token2 = ffn(x[:, 2:3, :]) # Just token 2
combined = torch.cat([output_token0, output_token1, output_token2], dim=1)
print("Are they equal?", torch.allclose(output, combined))
# Output: True
# This is different from attention, where:
# - Token 0's output depends on tokens 0, 1, 2 (via attention)
# - FFN: Token 0's output depends ONLY on token 0's input
Why is d_ff typically 4× d_model? It is largely a historical and empirical choice:
| Model | d_model | d_ff | Ratio |
|---|---|---|---|
| Original Transformer | 512 | 2048 | 4× |
| BERT-Base | 768 | 3072 | 4× |
| GPT-2 | 768 | 3072 | 4× |
| GPT-3 | 12288 | 49152 | 4× |
| LLaMA-2 70B | 8192 | 28672 | 3.5× |
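To see what those ratios mean in practice, here is a rough per-layer FFN weight count for a few of the rows above (a sketch counting only the expand and contract matrices, ignoring biases and gated variants):

for name, dm, dff in [("Original", 512, 2048),
                      ("BERT-Base", 768, 3072),
                      ("GPT-3", 12288, 49152)]:
    weights = 2 * dm * dff          # expand matrix + contract matrix
    print(f"{name:>10}: {weights / 1e6:6.1f}M weights per FFN block")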
import torch.nn.functional as F
# Original: ReLU
class FFN_ReLU(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.fc2(F.relu(self.fc1(x)))
# Modern: GELU (Gaussian Error Linear Unit)
# Used in BERT, GPT-2/3, most modern transformers
class FFN_GELU(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
# Advanced: SwiGLU (Swish-Gated Linear Unit)
# Used in PaLM, LLaMA - best performance but more parameters
class FFN_SwiGLU(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
# SwiGLU needs TWO projections for gating
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_model, d_ff)
self.fc3 = nn.Linear(d_ff, d_model)
def forward(self, x):
# Gating mechanism: multiply two paths
gate = F.silu(self.fc1(x)) # Swish activation
value = self.fc2(x)
return self.fc3(gate * value)
# Comparison
x = torch.randn(1, 5, 512)
ffn_relu = FFN_ReLU(512, 2048)
ffn_gelu = FFN_GELU(512, 2048)
ffn_swiglu = FFN_SwiGLU(512, 2048)
out_relu = ffn_relu(x)
out_gelu = ffn_gelu(x)
out_swiglu = ffn_swiglu(x)
print("Output shapes:")
print(f" ReLU: {out_relu.shape}")
print(f" GELU: {out_gelu.shape}")
print(f" SwiGLU: {out_swiglu.shape}")
# GELU and SwiGLU generally give 1-2% better performance
# but are slightly more expensive to compute
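One concrete cost difference is parameter count: at the same d_ff, the gated SwiGLU block has a third projection and therefore roughly 50% more parameters, which is why SwiGLU-based models often shrink d_ff to compensate. A quick check using the modules defined above:

for name, module in [("ReLU", ffn_relu), ("GELU", ffn_gelu), ("SwiGLU", ffn_swiglu)]:
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name:>7}: {n_params:,} parameters")
# ReLU / GELU: ~2.1M each; SwiGLU: ~3.2M at the same d_ff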
Research suggests FFN layers act as key-value memories: the first linear layer behaves like a set of keys that fire when particular input patterns appear (specific phrases, topics, or syntactic cues), and the second linear layer holds the corresponding values that get added back into the token's representation.
The high-dimensional intermediate space (d_ff) allows storing millions of such patterns as distributed representations in the weight matrices.
The decoder generates output tokens one at a time, attending to both:
1. Previously generated tokens (masked/causal self-attention)
2. The encoder output (cross-attention)
class DecoderLayer(torch.nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# Masked self-attention (can't look at future tokens)
self.self_attention = MultiHeadAttention(d_model, num_heads)
# Cross-attention (attends to encoder output)
self.cross_attention = MultiHeadAttention(d_model, num_heads)
# Feed-forward
self.ff = torch.nn.Sequential(
torch.nn.Linear(d_model, d_ff),
torch.nn.ReLU(),
torch.nn.Linear(d_ff, d_model)
)
# Normalization and dropout
self.norm1 = torch.nn.LayerNorm(d_model)
self.norm2 = torch.nn.LayerNorm(d_model)
self.norm3 = torch.nn.LayerNorm(d_model)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, x, encoder_output, self_attention_mask=None,
cross_attention_mask=None):
# Masked self-attention (can't see future)
        self_attn_out, _ = self.self_attention(x, x, x, self_attention_mask)
x = self.norm1(x + self.dropout(self_attn_out))
# Cross-attention to encoder
cross_attn_out, _ = self.cross_attention(
query=x,
key=encoder_output,
value=encoder_output,
mask=cross_attention_mask
)
x = self.norm2(x + self.dropout(cross_attn_out))
# Feed-forward
ff_out = self.ff(x)
x = self.norm3(x + self.dropout(ff_out))
return x
class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff,
                 dropout=0.1):
        super().__init__()
        self.embedding = TransformerEmbedding(vocab_size, d_model, dropout=dropout)
        self.layers = torch.nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
self.output_projection = torch.nn.Linear(d_model, vocab_size)
def forward(self, decoder_input_ids, encoder_output,
decoder_mask=None, cross_mask=None):
x = self.embedding(decoder_input_ids)
for layer in self.layers:
x = layer(x, encoder_output, decoder_mask, cross_mask)
logits = self.output_projection(x) # [batch, seq_len, vocab_size]
return logits
# Key difference from encoder:
# - Self-attention is MASKED (can't look forward)
# - Added cross-attention (attends to encoder)
# - Output projection to vocabulary
During generation, the decoder can't peek at future tokens:
def create_causal_mask(seq_len):
"""
Lower triangular matrix: position i can attend to 0...i
"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# False = can attend, True = mask out
return mask
# For inference with 1 token:
mask = create_causal_mask(1)
# Masks no future tokens (there are none)
# For 2 tokens:
mask = create_causal_mask(2)
# Token 0 can attend to: 0       ✓
# Token 1 can attend to: 0, 1    ✓
# Both see all past + present tokens, nothing in the future
# In the attention computation (Q, K are this head's query/key matrices):
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)   # softmax turns the -inf positions into 0
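A tiny worked example (a sketch with random scores for a 3-token sequence) shows the effect: after masking and softmax, every position's attention over future tokens is exactly zero.

scores = torch.randn(3, 3)
mask = create_causal_mask(3)
weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(weights)
# Row 0: [1.0, 0.0, 0.0]  - token 0 can only attend to itself
# Row 1: [ w ,  w , 0.0]  - token 1 attends to tokens 0-1
# Row 2: [ w ,  w ,  w ]  - token 2 attends to all previous tokens and itself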
class Transformer(torch.nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
num_layers=6, num_heads=8, d_ff=2048, dropout=0.1):
super().__init__()
self.encoder = Encoder(src_vocab_size, d_model, num_layers,
num_heads, d_ff, dropout)
self.decoder = Decoder(tgt_vocab_size, d_model, num_layers,
num_heads, d_ff, dropout)
def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None,
cross_mask=None):
"""
src_ids: [batch, src_len] - source tokens
tgt_ids: [batch, tgt_len] - target tokens
"""
# Encoder
encoder_output = self.encoder(src_ids, src_mask)
# Decoder
decoder_output = self.decoder(tgt_ids, encoder_output,
tgt_mask, cross_mask)
return decoder_output # [batch, tgt_len, vocab_size]
# Training example
model = Transformer(
src_vocab_size=50000,
tgt_vocab_size=50000,
d_model=512,
num_layers=6,
num_heads=8,
d_ff=2048
)
src_ids = torch.randint(0, 50000, (2, 10)) # English
tgt_ids = torch.randint(0, 50000, (2, 12)) # French (shifted)
logits = model(src_ids, tgt_ids)
print(logits.shape) # [2, 12, 50000]
# For generation, decode greedily or with sampling
# (covered in next module)
What happens as information flows through the 6-layer encoder stack?
- Early layers (1-2): local syntactic patterns
- Middle layers (3-4): syntactic and semantic structure
- Late layers (5-6): high-level semantics
| Component | Why Used |
|---|---|
| Multi-Head Attention | Different representation subspaces, parallel attention patterns |
| Layer Norm | Stabilize training, normalization across features (not batch) |
| Residual Connections | Enable deep networks, gradient flow through many layers |
| Feed-Forward FFN | Non-linearity, increased model capacity, knowledge storage |
| Positional Encoding | Add position since attention is permutation-invariant |
| Causal Masking | Prevent looking at future tokens during generation |
| Cross-Attention | Connect decoder to encoder, provide context for generation |
Q1: What are the two main components of the original Transformer architecture?
Q2: What is the purpose of the feedforward network in each Transformer block?
Q3: What type of attention does the decoder use when looking at encoder outputs?
Q4: What is the purpose of layer normalization in Transformers?
Q5: What are residual connections (skip connections) used for?