Understand multiple attention perspectives and how transformers track position without recurrence
Single attention is powerful, but what if we computed attention from multiple different angles simultaneously? That's the revolutionary idea behind multi-head attention - the mechanism that gives transformers their remarkable ability to understand language.
⚠️ Limitation: One attention mechanism can't capture everything!
Consider the sentence: "The lawyer questioned the witness about the crime." To understand it, a model needs to track several relationships at once: the subject-verb link (lawyer → questioned), the verb-object link (questioned → witness), and the prepositional attachment (questioned → about the crime).
Problem: A single attention head produces only one attention distribution per token, so it has to blend or prioritize these patterns. It can't cleanly capture syntax, semantics, and discourse structure at the same time!
Instead of forcing one attention mechanism to do everything, run multiple attention mechanisms in parallel - each can specialize in different linguistic phenomena.
Your brain doesn't process vision with one system: separate pathways handle color, motion, depth, and shape, and your perception combines their outputs.
Multi-head attention works the same way: Different "heads" (pathways) each focus on different linguistic features, then their outputs are combined for complete understanding.
| Model | d_model | Num Heads | d_k per head | Notes |
|---|---|---|---|---|
| BERT-base | 768 | 12 | 64 | 768 / 12 = 64 per head |
| GPT-2 | 768 | 12 | 64 | Same as BERT-base |
| GPT-3 | 12,288 | 96 | 128 | 12,288 / 96 = 128 per head |
| LLaMA-2 70B | 8,192 | 64 | 128 | 8,192 / 64 = 128 per head |
| GPT-4 (est.) | ~18,000 | ~128 | ~140 | Unconfirmed architecture |
📊 Pattern: The ratio d_model / num_heads stays remarkably consistent (64-128) across models. This suggests there's an optimal "head dimension" that balances expressiveness with computational efficiency.
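As a quick sanity check, the per-head dimension in the table is just integer division of d_model by the number of heads (values taken from the rows above):

# Check the d_k arithmetic from the table above
configs = {
    "BERT-base":   (768, 12),
    "GPT-2":       (768, 12),
    "GPT-3":       (12288, 96),
    "LLaMA-2 70B": (8192, 64),
}
for name, (d_model, num_heads) in configs.items():
    print(f"{name:12s}: d_k = {d_model} // {num_heads} = {d_model // num_heads}")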
Multi-head attention is elegantly simple: run multiple attention mechanisms in parallel, each with its own learned projections. Let's break down the process:
import torch
import torch.nn.functional as F
# Setup
tokens = ["The", "cat", "sat"]
seq_len = 3
d_model = 8 # Small for illustration
num_heads = 2
d_k = d_model // num_heads # 8 / 2 = 4 per head
# Input embeddings (normally from embedding layer)
X = torch.tensor([
[0.1, 0.2, -0.1, 0.3, 0.4, -0.2, 0.1, 0.0], # "The"
[0.3, -0.1, 0.5, 0.2, 0.1, 0.4, -0.3, 0.2], # "cat"
[-0.2, 0.4, 0.1, -0.3, 0.5, 0.1, 0.2, -0.1] # "sat"
]).unsqueeze(0) # [1, 3, 8]
print("Input shape:", X.shape)
print("We have 2 heads, each will get 4 dimensions\n")
# Step 1: Project to Q, K, V (for ALL heads at once)
W_Q = torch.nn.Linear(d_model, d_model, bias=False)
W_K = torch.nn.Linear(d_model, d_model, bias=False)
W_V = torch.nn.Linear(d_model, d_model, bias=False)
Q = W_Q(X) # [1, 3, 8]
K = W_K(X) # [1, 3, 8]
V = W_V(X) # [1, 3, 8]
print("After projection:")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}\n")
# Step 2: Split into heads
# Reshape: [batch, seq, d_model] → [batch, seq, num_heads, d_k]
# Then transpose: [batch, num_heads, seq, d_k]
batch_size = X.shape[0]
Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print("After splitting into heads:")
print(f"Q shape: {Q.shape} # [batch, num_heads, seq, d_k]")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}\n")
print("Head 0 gets Q[:, 0, :, :] - dimensions 0-3")
print("Head 1 gets Q[:, 1, :, :] - dimensions 4-7\n")
# Step 3: Compute attention for BOTH heads in parallel
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
# [1, 2, 3, 3] - 2 attention matrices (one per head)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
# [1, 2, 3, 4] - 2 heads, each produces 3 tokens of 4 dimensions
print("After attention:")
print(f"Output shape: {output.shape} # [batch, num_heads, seq, d_k]")
print(f"\nHead 0 output: {output[0, 0, :, :].shape} # [3, 4]")
print(f"Head 1 output: {output[0, 1, :, :].shape} # [3, 4]\n")
# Step 4: Concatenate heads
# Transpose back: [batch, num_heads, seq, d_k] → [batch, seq, num_heads, d_k]
output = output.transpose(1, 2).contiguous()
# Reshape: [batch, seq, num_heads, d_k] → [batch, seq, d_model]
output = output.view(batch_size, seq_len, d_model)
print("After concatenating heads:")
print(f"Output shape: {output.shape} # [1, 3, 8]")
print("Back to original dimensions!\n")
# Step 5: Output projection
W_O = torch.nn.Linear(d_model, d_model, bias=False)
final_output = W_O(output)
print(f"Final output shape: {final_output.shape} # [1, 3, 8]")
print("\nThis output has information from BOTH heads combined!")
Original Q matrix: [batch, seq_len, d_model=512]
Dimension: 0 64 128 192 256 320 384 448 512
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
Token 0: [================================================]
Token 1: [================================================]
Token 2: [================================================]
Token 3: [================================================]
After reshaping to [batch, num_heads=8, seq_len, d_k=64]:
Head 0 (dims 0-63):
Token 0: [=======]
Token 1: [=======]
Token 2: [=======]
Token 3: [=======]
Head 1 (dims 64-127):
Token 0: [=======]
Token 1: [=======]
Token 2: [=======]
Token 3: [=======]
... (Heads 2-6) ...
Head 7 (dims 448-511):
Token 0: [=======]
Token 1: [=======]
Token 2: [=======]
Token 3: [=======]
Each head independently computes attention on its slice!
Then we concatenate everything back together.
Why this works so well:
• Parallel: all heads compute simultaneously on the GPU, so the wall-clock time is about the same as a single head.
• Specialized: each head learns different patterns naturally through gradient descent.
• Efficient: the total parameter count matches a single attention with the full d_k, so there is no extra computation cost.
• Combinable: the output projection lets the model blend the different perspectives.
# Single-head attention
d_model = 512
d_k = 512 # Full dimension
Q = W_Q(X) # [batch, seq, 512]
K = W_K(X) # [batch, seq, 512]
V = W_V(X) # [batch, seq, 512]
scores = Q @ K.T / sqrt(512)
attn = softmax(scores)
output = attn @ V # [batch, seq, 512]
# Parameters: 3 × (512 × 512) = 786,432
#----------------------------------------------------------------
# Multi-head attention (8 heads)
num_heads = 8
d_k = 512 // 8   # = 64, per-head dimension
Q = W_Q(X) # [batch, seq, 512] - same projection size!
K = W_K(X) # [batch, seq, 512]
V = W_V(X) # [batch, seq, 512]
# Reshape to [batch, 8, seq, 64]
# Each head operates on 64 dimensions
# Compute 8 attention patterns in parallel
outputs = [attention_head_i(Q_i, K_i, V_i) for i in range(8)]
output = concat(outputs) # [batch, seq, 512]
output = W_O(output) # Final projection
# Parameters: 3 × (512 × 512) + (512 × 512) = 1,048,576
# ~33% more parameters (the extra W_O projection), but 8 distinct attention patterns instead of 1
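You can confirm the parameter counts used in this comparison with a few bias-free nn.Linear layers; this is a quick check mirroring the projections above, not any particular library implementation:

import torch

d_model = 512

# Single-head: three d_model x d_model projections (no output projection in the comparison above)
single = torch.nn.ModuleDict({
    "W_Q": torch.nn.Linear(d_model, d_model, bias=False),
    "W_K": torch.nn.Linear(d_model, d_model, bias=False),
    "W_V": torch.nn.Linear(d_model, d_model, bias=False),
})

# Multi-head: the same three projections plus the output projection W_O
multi = torch.nn.ModuleDict({
    "W_Q": torch.nn.Linear(d_model, d_model, bias=False),
    "W_K": torch.nn.Linear(d_model, d_model, bias=False),
    "W_V": torch.nn.Linear(d_model, d_model, bias=False),
    "W_O": torch.nn.Linear(d_model, d_model, bias=False),
})

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(single))  # 786432
print(count(multi))   # 1048576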
Research has shown that different attention heads spontaneously specialize in different linguistic phenomena during training. This isn't programmed - it emerges naturally through gradient descent!
Commonly observed head types include:
• Syntactic heads. Focus: grammatical dependencies (determiner → noun, subject → verb)
• Semantic heads. Focus: meaning relationships between content words
• Positional heads. Focus: relative position (previous token, next token, self)
• Global-context heads. Focus: broad, near-uniform attention over the whole sequence
• Salient-token heads. Focus: a few important tokens
• Delimiter heads. Focus: structure markers such as punctuation, [CLS], and [SEP]
from transformers import BertTokenizer, BertModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns
# Load BERT and analyze attention patterns
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()
# Example sentence with clear syntactic structure
text = "The elderly professor taught the enthusiastic students about quantum physics."
inputs = tokenizer(text, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
print(f"Tokens: {tokens}")
print(f"Analyzing {model.config.num_hidden_layers} layers, {model.config.num_attention_heads} heads per layer\n")
# Get attention weights
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions # Tuple of 12 layers
# Analyze each head in layer 6 (middle layer often shows clear patterns)
layer_idx = 6
layer_attentions = attentions[layer_idx][0] # [num_heads, seq_len, seq_len]
print(f"Layer {layer_idx} Head Specialization:\n")
# Function to identify head type based on attention pattern
def analyze_head(attn_matrix, tokens):
"""Classify what pattern this head is focusing on."""
seq_len = len(tokens)
# Check for syntactic pattern (determiners → nouns)
det_indices = [i for i, t in enumerate(tokens) if t in ['the', 'a', 'an']]
if det_indices:
# Do following tokens attend to determiners?
det_attention = attn_matrix[:, det_indices].mean()
if det_attention > 0.3:
return "Syntactic (Determiner-Noun)"
# Check for positional pattern (diagonal/nearby)
diagonal_strength = torch.diag(attn_matrix).mean()
if diagonal_strength > 0.5:
return "Positional (Self-attention)"
# Check for next-token pattern
next_token_attn = torch.diag(attn_matrix, diagonal=1).mean()
if next_token_attn > 0.4:
return "Positional (Next-token)"
# Check for uniform (broadcast) pattern
attn_std = attn_matrix.std()
if attn_std < 0.15:
return "Global Context (Uniform)"
# Check for delimiter pattern
delim_indices = [i for i, t in enumerate(tokens) if t in ['.', ',', '[SEP]', '[CLS]']]
if delim_indices:
delim_attention = attn_matrix[:, delim_indices].mean()
if delim_attention > 0.4:
return "Delimiter-focused"
return "Semantic/Other"
# Analyze each head
for head_idx in range(model.config.num_attention_heads):
attn = layer_attentions[head_idx].cpu()
head_type = analyze_head(attn, tokens)
# Find strongest attention connections
max_val = attn.max().item()
max_pos = (attn == max_val).nonzero(as_tuple=True)
from_token = tokens[max_pos[0][0].item()]
to_token = tokens[max_pos[1][0].item()]
print(f"Head {head_idx:2d}: {head_type:<30s} | Strongest: {from_token} → {to_token} ({max_val:.3f})")
# Visualize different head types side by side
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
head_examples = [0, 2, 5, 7, 9, 11] # Select diverse heads
for idx, head_idx in enumerate(head_examples):
ax = axes[idx // 3, idx % 3]
attn = layer_attentions[head_idx].cpu().numpy()
sns.heatmap(attn, annot=False, cmap='YlOrRd',
xticklabels=tokens, yticklabels=tokens,
cbar_kws={'label': 'Attention'}, ax=ax)
head_type = analyze_head(layer_attentions[head_idx].cpu(), tokens)
ax.set_title(f'Head {head_idx}: {head_type}', fontsize=10)
ax.set_xlabel('Attending to', fontsize=8)
ax.set_ylabel('Attending from', fontsize=8)
ax.tick_params(labelsize=6)
plt.tight_layout()
plt.show()
📚 Key Research: "Are Sixteen Heads Really Better than One?" (Michel et al., 2019) showed that many attention heads can be pruned at test time with little loss in accuracy; in some layers a single head is enough.
💡 Implication: head (and layer) importance is highly uneven. This insight fed into compact models such as DistilBERT, a distilled 6-layer BERT that retains roughly 97% of BERT's language-understanding performance with about 40% fewer parameters.
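A minimal sketch of the head-pruning idea: zero out selected heads' outputs before they are concatenated, so they contribute nothing to the layer. This illustrates the concept only, not the paper's exact procedure:

import torch

# Toy per-head outputs: [batch, num_heads, seq_len, d_k]
head_outputs = torch.randn(1, 12, 5, 64)

# Keep-mask over heads: 1 = keep, 0 = prune (here we prune heads 3 and 7)
keep = torch.ones(12)
keep[[3, 7]] = 0.0

# Zero out the pruned heads; after concatenation and the output projection,
# those heads contribute nothing to the layer's output
pruned = head_outputs * keep.view(1, -1, 1, 1)
print(pruned[0, 3].abs().sum().item())  # 0.0 -- head 3 is silenced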
| Layer Range | Typical Head Focus | Example Patterns |
|---|---|---|
| Layers 1-3 (Early) | Surface-level, positional | Previous token, next token, nearby words, punctuation |
| Layers 4-7 (Middle) | Syntactic structure | Subject-verb, verb-object, determiner-noun, modifiers |
| Layers 8-12 (Deep) | Semantic relationships | Coreference, semantic roles, discourse structure, topic |
import torch
import torch.nn.functional as F
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Single set of projection matrices
# These automatically split across heads
self.W_Q = torch.nn.Linear(d_model, d_model)
self.W_K = torch.nn.Linear(d_model, d_model)
self.W_V = torch.nn.Linear(d_model, d_model)
# Output projection to combine all heads
self.W_O = torch.nn.Linear(d_model, d_model)
def forward(self, X, mask=None):
batch_size, seq_len, d_model = X.shape
# Project to Q, K, V
Q = self.W_Q(X) # [batch, seq_len, d_model]
K = self.W_K(X)
V = self.W_V(X)
# Split into multiple heads
# Reshape to [batch, seq_len, num_heads, d_k]
# Then transpose to [batch, num_heads, seq_len, d_k]
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# Now: [batch, num_heads, seq_len, d_k]
# Compute attention for all heads in parallel
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
# [batch, num_heads, seq_len, seq_len]
# Apply mask if provided (e.g., causal mask)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attention_weights = F.softmax(scores, dim=-1)
# Apply to values
output = torch.matmul(attention_weights, V)
# [batch, num_heads, seq_len, d_k]
# Reshape and concatenate heads
output = output.transpose(1, 2).contiguous()
# [batch, seq_len, num_heads, d_k]
output = output.view(batch_size, seq_len, d_model)
# [batch, seq_len, d_model]
# Final output projection
output = self.W_O(output)
return output, attention_weights
# Usage
d_model = 512
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)
X = torch.randn(2, 4, 512)
output, weights = mha(X)
print(output.shape) # [2, 4, 512]
print(weights.shape) # [2, 8, 4, 4] - one attention map per head!
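The mask argument makes it easy to turn this into causal (decoder-style) attention. A quick example continuing the usage above, with a lower-triangular mask so each token attends only to itself and earlier tokens:

seq_len = 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len))   # [4, 4], 1s on and below the diagonal
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)       # [1, 1, 4, 4], broadcasts over batch and heads

output, weights = mha(X, mask=causal_mask)
print(weights[0, 0])  # entries above the diagonal are 0 after softmax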
Here's a fundamental problem with attention: it has no built-in notion of word order. The attention mechanism computes similarities between tokens, and if you permute the input tokens, the outputs are simply permuted the same way (attention is permutation-equivariant). Nothing in the mechanism tells it where a token sits in the sequence!
import torch
import torch.nn.functional as F
# Simple attention function (no position info)
def basic_attention(X):
"""X: [batch, seq_len, d_model]"""
Q = K = V = X # Simplified
scores = Q @ K.transpose(-2, -1)
attn = F.softmax(scores, dim=-1)
output = attn @ V
return output
# Two sentences with different word order but same tokens
tokens1 = ["The", "cat", "sat", "on", "the", "mat"]
tokens2 = ["mat", "the", "on", "sat", "cat", "The"] # Scrambled!
# Simulate embeddings (same embeddings, different order)
embeddings = torch.randn(6, 512) # 6 unique token embeddings
# Sentence 1: embeddings in order [0,1,2,3,4,5]
X1 = embeddings[[0,1,2,3,4,5]].unsqueeze(0)
# Sentence 2: same embeddings, scrambled order [5,4,3,2,1,0]
X2 = embeddings[[5,4,3,2,1,0]].unsqueeze(0)
print("X1 (original order):", X1.shape)
print("X2 (scrambled order):", X2.shape)
# Apply attention
out1 = basic_attention(X1)
out2 = basic_attention(X2)
# Verify: permuting the inputs just permutes the outputs the same way.
# Attention doesn't know about position.
perm = [5, 4, 3, 2, 1, 0]
print("\nAttention is permutation-equivariant:")
print("out2 equals out1 permuted:", torch.allclose(out2, out1[:, perm, :], atol=1e-5))
print("This means attention by itself CAN'T distinguish word order!")
Consider these completely different meanings:
1. "The dog bit the man" ← Dog is agent
2. "The man bit the dog" ← Man is agent
3. "She told him not to go" ← Positive instruction
4. "She told him to not go" ← Same words, different meaning
5. "I didn't say she stole money" ← Emphasis position changes meaning!
"I didn't SAY she stole money" (I implied it)
"I didn't say SHE stole money" (someone else did)
"I didn't say she STOLE money" (she borrowed it)
Without position information, attention can't tell these apart! All have the same tokens, just in different positions.
RNNs naturally encode position through sequential processing:
# RNN processes tokens one at a time, left to right
h_0 = initial_state
h_1 = RNN(h_0, token_0) # Position 0
h_2 = RNN(h_1, token_1) # Position 1 (knows it came after token_0)
h_3 = RNN(h_2, token_2) # Position 2 (knows entire history)
# Position is implicit in the processing order
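Here is the same idea as a tiny runnable sketch using torch.nn.RNNCell (the sizes below are arbitrary and purely for illustration):

import torch

rnn_cell = torch.nn.RNNCell(input_size=8, hidden_size=16)
token_embeddings = torch.randn(3, 8)   # 3 tokens, 8-dimensional embeddings
h = torch.zeros(1, 16)                 # initial hidden state

for t, token in enumerate(token_embeddings):
    h = rnn_cell(token.unsqueeze(0), h)  # h now summarizes tokens 0..t, in order
    print(f"Step {t}: hidden state encodes positions 0..{t}")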
But transformers process all tokens in parallel! This is what makes them fast, but it means we must explicitly add position information.
Add position information to token embeddings before passing them to attention. There are several ways to do this:
• Sinusoidal encodings: fixed sine/cosine functions at different frequencies. Used in: the original Transformer (2017)
• Learned positional embeddings: treat position as another embedding table. Used in: BERT, GPT-2
• Rotary embeddings (RoPE): rotate Q and K by a position-dependent angle. Used in: LLaMA, GPT-NeoX, PaLM
• ALiBi: add a linear bias to the attention scores. Used in: BLOOM, MPT
The original "Attention is All You Need" paper (2017) introduced sinusoidal positional encodings - a mathematically elegant solution that uses sine and cosine functions at different frequencies.
Positional Encoding Formula:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where:
• pos = position in sequence (0, 1, 2, ..., seq_len-1)
• i = dimension index (0, 1, 2, ..., d_model/2-1)
• 2i = even dimensions get sine
• 2i+1 = odd dimensions get cosine
Why 10000? Why sine/cosine? Let's break it down:
# The frequency for each dimension pair decreases geometrically with the index,
# so the wavelength factor 10000^(i/d_model) grows geometrically
d_model = 512
dimensions = torch.arange(0, d_model, 2)
# Wavelength factor for every even dimension index
wavelengths = 10000 ** (dimensions / d_model)
print("Dimension | Wavelength factor")
print("-" * 30)
for i in range(0, d_model, 64):  # Sample every 64 dimensions
    print(f"   {i:3d}    | {wavelengths[i // 2].item():10.2f}")
# Output:
#     0    |       1.00   (highest frequency: one full cycle every ~6 positions)
#    64    |       3.16
#   128    |      10.00
#   192    |      31.62
#   256    |     100.00
#   320    |     316.23
#   384    |    1000.00
#   448    |    3162.28   (lowest frequency: one full cycle every ~20,000 positions)
Why this matters: different dimensions capture position at different scales. Low dimensions change quickly and distinguish neighboring positions, while high dimensions change slowly and distinguish broad regions of the sequence, so each position gets a unique combination across all scales.
Using both sine and cosine for each frequency provides two key benefits:
import matplotlib.pyplot as plt
import numpy as np
pos = np.arange(0, 100)
freq = 1/10 # One cycle per 10 positions
# Sine and cosine are 90° out of phase
sine_wave = np.sin(2 * np.pi * freq * pos)
cosine_wave = np.cos(2 * np.pi * freq * pos)
plt.figure(figsize=(12, 4))
plt.plot(pos, sine_wave, label='sin', linewidth=2)
plt.plot(pos, cosine_wave, label='cos', linewidth=2)
plt.xlabel('Position')
plt.ylabel('Value')
plt.title('Sine and Cosine: Orthogonal Information')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# Benefit 1: Linear combinations can represent any phase shift
# Benefit 2: Derivatives are well-behaved (sin' = cos, cos' = -sin)
The base value 10000 was chosen so that the wavelengths span a wide range: the fastest dimensions complete a cycle every few positions, while the slowest have periods far longer than typical training sequences, so no two positions within a realistic sequence share the same encoding.
import torch
import numpy as np
import matplotlib.pyplot as plt
def positional_encoding(d_model, max_seq_len=512):
"""
Create sinusoidal positional encodings.
Args:
d_model: Dimension of model (must be even)
max_seq_len: Maximum sequence length to precompute
Returns:
Positional encoding matrix [max_seq_len, d_model]
"""
# Create position indices: [0, 1, 2, ..., max_seq_len-1]
pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
# Shape: [max_seq_len, 1]
# Example for max_seq_len=5: [[0], [1], [2], [3], [4]]
# Create dimension indices: [0, 2, 4, ..., d_model-2]
# We only need half because each pair (2i, 2i+1) shares the same frequency
i = torch.arange(0, d_model, 2, dtype=torch.float).unsqueeze(0)
# Shape: [1, d_model/2]
# Example for d_model=8: [[0, 2, 4, 6]]
# Calculate angular frequency for each dimension
# freq = 1 / (10000^(2i/d_model))
# This creates exponentially decreasing frequencies:
# - Lower dimensions oscillate rapidly (high frequency)
# - Higher dimensions oscillate slowly (low frequency)
freq = 1.0 / (10000 ** (i / d_model))
# Shape: [1, d_model/2]
# Multiply position by frequency to get angles
# Broadcasting: [max_seq_len, 1] × [1, d_model/2] → [max_seq_len, d_model/2]
angle_rates = pos * freq
# Shape: [max_seq_len, d_model/2]
# Each row represents one position
# Each column represents one frequency component
# Initialize positional encoding matrix
pe = torch.zeros(max_seq_len, d_model)
# Apply sine to even indices (0, 2, 4, ...)
pe[:, 0::2] = torch.sin(angle_rates)
# Apply cosine to odd indices (1, 3, 5, ...)
pe[:, 1::2] = torch.cos(angle_rates)
# Final shape: [max_seq_len, d_model]
# Each row is the positional encoding for one position
# Each position has a unique pattern across all dimensions
return pe
# Generate positional encodings
d_model = 512
max_seq_len = 100
pe = positional_encoding(d_model, max_seq_len)
print(f"Positional Encoding Shape: {pe.shape}")
print(f"Each position has {d_model} values")
print(f"Values range: [{pe.min():.3f}, {pe.max():.3f}]")
# Verify properties
print(f"\nPosition 0 encoding: {pe[0, :8]}") # First 8 dimensions
print(f"Position 1 encoding: {pe[1, :8]}")
print(f"Position 2 encoding: {pe[2, :8]}")
# Show that each position has unique encoding
print(f"\nAll positions are unique: {len(torch.unique(pe, dim=0)) == max_seq_len}")
# Visualize the encoding pattern
plt.figure(figsize=(12, 8))
plt.imshow(pe.numpy().T, cmap='RdBu', aspect='auto', interpolation='nearest')
plt.xlabel('Position in Sequence', fontsize=12)
plt.ylabel('Embedding Dimension', fontsize=12)
plt.title('Sinusoidal Positional Encoding Heatmap', fontsize=14)
plt.colorbar(label='Encoding Value')
plt.tight_layout()
plt.show()
# Visualize how different dimensions encode position differently
fig, axes = plt.subplots(3, 1, figsize=(14, 10))
positions = np.arange(0, 100)
pe_numpy = pe[:100].numpy()
# Plot high-frequency dimensions (low index)
axes[0].plot(positions, pe_numpy[:, 0], label='Dim 0 (fastest)', linewidth=2)
axes[0].plot(positions, pe_numpy[:, 1], label='Dim 1', linewidth=2)
axes[0].set_title('Low Dimensions: High Frequency (changes every position)', fontsize=12)
axes[0].set_ylabel('Encoding Value')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Plot mid-frequency dimensions
axes[1].plot(positions, pe_numpy[:, 64], label='Dim 64', linewidth=2)
axes[1].plot(positions, pe_numpy[:, 65], label='Dim 65', linewidth=2)
axes[1].set_title('Middle Dimensions: Medium Frequency (captures local structure)', fontsize=12)
axes[1].set_ylabel('Encoding Value')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Plot low-frequency dimensions (high index)
axes[2].plot(positions, pe_numpy[:, 256], label='Dim 256', linewidth=2)
axes[2].plot(positions, pe_numpy[:, 257], label='Dim 257', linewidth=2)
axes[2].set_title('High Dimensions: Low Frequency (captures global position)', fontsize=12)
axes[2].set_xlabel('Position')
axes[2].set_ylabel('Encoding Value')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Notice how:")
print("- Low dims oscillate rapidly (distinguish 1 vs 2)")
print("- Mid dims oscillate moderately (distinguish 10 vs 50)")
print("- High dims oscillate slowly (distinguish early vs late in sequence)")
Why sinusoidal encodings work well:
• Bounded: all values lie in [-1, 1], so they won't cause numerical instability or gradient explosion.
• Deterministic: no parameters to learn, the same encoding every time, and it can extrapolate to sequence lengths never seen in training.
• Relative positions: PE(pos+k) can be expressed as a linear function of PE(pos), which lets the model learn relative offsets (verified numerically below).
• Multi-scale: different dimensions capture position at different granularities, from fine to coarse.
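You can check the relative-position property numerically using the positional_encoding function defined above; the values pos = 10 and k = 7 below are arbitrary choices:

pe = positional_encoding(d_model=512, max_seq_len=100)

pos, k = 10, 7
i = torch.arange(0, 512, 2, dtype=torch.float)
freq = 1.0 / (10000 ** (i / 512))     # the same frequencies used inside positional_encoding

# PE(pos) splits into sin/cos pairs; rotating each pair by angle k*freq gives PE(pos+k)
sin_p, cos_p = pe[pos, 0::2], pe[pos, 1::2]
cos_k, sin_k = torch.cos(k * freq), torch.sin(k * freq)

sin_shifted = sin_p * cos_k + cos_p * sin_k   # sin((pos+k)*freq)
cos_shifted = cos_p * cos_k - sin_p * sin_k   # cos((pos+k)*freq)

print(torch.allclose(sin_shifted, pe[pos + k, 0::2], atol=1e-5))  # True
print(torch.allclose(cos_shifted, pe[pos + k, 1::2], atol=1e-5))  # True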
# Complete example: Adding positional encoding to token embeddings
batch_size = 2
seq_len = 10
d_model = 512
vocab_size = 10000
# Token embeddings (learned during training)
token_embedding = torch.nn.Embedding(vocab_size, d_model)
# Positional encodings (fixed)
pe = positional_encoding(d_model, max_seq_len=1000)
# Input token IDs
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
print(f"Token IDs shape: {token_ids.shape}") # [2, 10]
# Get token embeddings
embeddings = token_embedding(token_ids)
print(f"Token embeddings shape: {embeddings.shape}") # [2, 10, 512]
# Add positional encodings
# Select the positional encodings for the current sequence length
pos_encodings = pe[:seq_len, :].unsqueeze(0) # [1, 10, 512]
print(f"Positional encodings shape: {pos_encodings.shape}")
# Add (broadcast across batch dimension)
embeddings_with_position = embeddings + pos_encodings
print(f"Final embeddings shape: {embeddings_with_position.shape}") # [2, 10, 512]
print("\nNow each token embedding contains:")
print("- Token semantic information (learned)")
print("- Position information (fixed sinusoidal)")
print("- Ready for attention mechanism!")
Newer models use rotary positional embeddings (RoPE), which are more efficient and generalize better:
# Rotary embeddings rotate Q and K by a position-dependent angle
def apply_rotary_pos_emb(x, freqs):
    """
    Apply rotary position embedding.
    x:     [batch, num_heads, seq_len, head_dim]
    freqs: [seq_len, head_dim/2] angles, i.e. position * inverse frequency
    """
    # Treat consecutive pairs of dimensions as 2D points and rotate each pair
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # Rotation terms for each position/frequency
    cos_freqs = freqs.cos()
    sin_freqs = freqs.sin()
    # Apply rotation: (x1, x2) → (x1*cos - x2*sin, x1*sin + x2*cos)
    x_rot1 = x1 * cos_freqs - x2 * sin_freqs
    x_rot2 = x1 * sin_freqs + x2 * cos_freqs
    # Interleave the rotated pairs back into the original layout
    x_rot = torch.empty_like(x)
    x_rot[..., 0::2] = x_rot1
    x_rot[..., 1::2] = x_rot2
    return x_rot
# Advantages over sinusoidal:
# - Better long-context extrapolation
# - More efficient GPU implementation
# - Works better with fine-tuning
# - Used in LLaMA, Mistral, PaLM, GPT-NeoX
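A hypothetical usage sketch for the function above: precompute position × inverse-frequency angles, then rotate Q and K before computing attention scores (the shapes are illustrative):

batch, num_heads, seq_len, head_dim = 1, 8, 16, 64

inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # [head_dim/2]
positions = torch.arange(seq_len).float()                                      # [seq_len]
freqs = torch.outer(positions, inv_freq)                                        # [seq_len, head_dim/2]

Q = torch.randn(batch, num_heads, seq_len, head_dim)
K = torch.randn(batch, num_heads, seq_len, head_dim)

Q_rot = apply_rotary_pos_emb(Q, freqs)
K_rot = apply_rotary_pos_emb(K, freqs)
print(Q_rot.shape)  # torch.Size([1, 8, 16, 64])
# Attention scores computed from Q_rot and K_rot now depend on relative position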
| Method | Pros | Cons |
|---|---|---|
| Sinusoidal | Simple, no learning, generalizes OK | Fixed, doesn't adapt to data |
| Learned | Adaptive, data-driven | Can't extrapolate beyond training length |
| RoPE (Rotary) | Extrapolates well, efficient, modern | Slightly more complex implementation |
| ALiBi (Attention Bias) | Super simple, no embeddings needed | Less flexibility for complex patterns |
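For intuition, here is a minimal sketch of the ALiBi idea from the table: a head-specific linear penalty proportional to query-key distance is added directly to the attention scores (the slopes below are illustrative; the method uses a fixed geometric sequence of slopes per head):

num_heads, seq_len = 4, 6
slopes = torch.tensor([1/4, 1/16, 1/64, 1/256]).view(num_heads, 1, 1)  # one slope per head

positions = torch.arange(seq_len)
dist = (positions.view(-1, 1) - positions.view(1, -1)).clamp(min=0)  # distance to earlier keys
bias = -slopes * dist                                                # larger penalty for farther-back keys

scores = torch.randn(num_heads, seq_len, seq_len)  # stand-in attention scores
scores = scores + bias                             # added before the causal mask and softmax
print(bias[0])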
Why add positional encoding to embeddings instead of concatenating?
Option 1 (Used): x = token_embedding + position_encoding
Option 2 (Not used): x = [token_embedding; position_embedding]
Addition is what's used in practice: it keeps the width at d_model (concatenation would enlarge every downstream weight matrix), and in a high-dimensional space the model can still learn to separate the token and position signals. See the shape comparison below.
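A quick shape-level illustration of the trade-off, with random tensors standing in for real embeddings:

import torch

d_model, seq_len = 512, 10
tok = torch.randn(seq_len, d_model)   # stand-in token embeddings
pos = torch.randn(seq_len, d_model)   # stand-in positional encodings

added = tok + pos                              # [10, 512]: width unchanged, no downstream changes needed
concatenated = torch.cat([tok, pos], dim=-1)   # [10, 1024]: every following weight matrix must grow
print(added.shape, concatenated.shape)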
import torch
import torch.nn.functional as F
class TransformerEmbedding(torch.nn.Module):
def __init__(self, vocab_size, d_model, max_seq_len=512):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
self.positional_encoding = self.create_positional_encoding(d_model, max_seq_len)
self.d_model = d_model
def create_positional_encoding(self, d_model, max_seq_len):
pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
i = torch.arange(0, d_model, 2, dtype=torch.float).unsqueeze(0)
freq = 1.0 / (10000 ** (i / d_model))
angle_rates = pos * freq
pe = torch.zeros(max_seq_len, d_model)
pe[:, 0::2] = torch.sin(angle_rates)
pe[:, 1::2] = torch.cos(angle_rates)
# Register as buffer (not a parameter)
self.register_buffer("pe", pe.unsqueeze(0))
return pe
def forward(self, token_ids):
"""
token_ids: [batch, seq_len]
"""
seq_len = token_ids.size(1)
# Token embeddings
embeddings = self.token_embedding(token_ids) # [batch, seq_len, d_model]
# Scale embeddings
embeddings = embeddings * (self.d_model ** 0.5)
# Add positional encoding
embeddings = embeddings + self.pe[:, :seq_len, :]
return embeddings
# Usage
vocab_size = 50000
d_model = 512
embedding_layer = TransformerEmbedding(vocab_size, d_model)
token_ids = torch.randint(0, vocab_size, (2, 4))
embeddings = embedding_layer(token_ids)
print(embeddings.shape) # [2, 4, 512]
# Now each embedding includes both token meaning AND position!
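Putting the pieces of this tutorial together: token IDs go through TransformerEmbedding (token meaning plus position), and the result feeds the MultiHeadAttention module defined earlier:

mha = MultiHeadAttention(d_model=512, num_heads=8)

embeddings = embedding_layer(token_ids)      # [2, 4, 512], from the usage above
attn_output, attn_weights = mha(embeddings)
print(attn_output.shape)   # torch.Size([2, 4, 512])
print(attn_weights.shape)  # torch.Size([2, 8, 4, 4])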
Q1: What is the purpose of multi-head attention?
Q2: Why do we need positional encodings in Transformers?
Q3: How many attention heads does a multi-head attention layer typically have?
Q4: What are the two main types of positional encoding?
Q5: In multi-head attention, what happens after each head computes its attention?