📜 Free Certificate Upon Completion - Earn a verified certificate when you complete all 7 modules in the Transformers Architecture course.

Multi-Head Attention & Positional Encoding

📚 Tutorial 4 🟡 Intermediate

Understand multiple attention perspectives and how transformers track position without recurrence


Multi-Head Attention: More Perspectives = Better Understanding

Single attention is powerful, but what if we computed attention from multiple different angles simultaneously? That's the revolutionary idea behind multi-head attention - the mechanism that gives transformers their remarkable ability to understand language.

Core Insight: Just like humans see depth using two eyes, transformers understand relationships using multiple attention heads. Each head learns to focus on different aspects of language.

The Problem with Single-Head Attention

⚠️ Limitation: One attention mechanism can't capture everything!

Consider the sentence: "The lawyer questioned the witness about the crime."

  • Syntactic relationships: "lawyer" → "The", "questioned" → "lawyer" (subject)
  • Semantic relationships: "crime" → "witness", "questioned" → "about"
  • Long-range dependencies: "witness" → "lawyer" (agent-patient)
  • Discourse structure: Main action vs. details

Problem: A single attention head produces one attention distribution per token, so it has to blend these competing patterns into a single weighted average. It can't keep syntax, semantics, and discourse structure separate!

The Multi-Head Solution

Instead of forcing one attention mechanism to do everything, run multiple attention mechanisms in parallel - each can specialize in different linguistic phenomena.

The Human Vision Analogy

Your brain doesn't process vision with one system:

  • One pathway: Detects shapes and objects ("what" pathway)
  • Another pathway: Tracks motion and spatial relationships ("where" pathway)
  • Color processing: Separate neurons for different wavelengths
  • Depth perception: Combines inputs from both eyes

Multi-head attention works the same way: Different "heads" (pathways) each focus on different linguistic features, then their outputs are combined for complete understanding.

How Many Heads Do Modern Models Use?

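| Model | d_model | Num heads | d_k per head | Notes |
|---|---|---|---|---|
| BERT-base | 768 | 12 | 64 | 768 / 12 = 64 per head |
| GPT-2 (small) | 768 | 12 | 64 | Same as BERT-base |
| GPT-3 (175B) | 12,288 | 96 | 128 | 12,288 / 96 = 128 per head |
| LLaMA-2 70B | 8,192 | 64 | 128 | 8,192 / 64 = 128 per head |
| GPT-4 (est.) | ~18,000 | ~128 | ~140 | Unconfirmed; OpenAI has not published the architecture |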

📊 Pattern: The per-head dimension d_k = d_model / num_heads stays in a narrow range (64-128) across the confirmed models. This suggests there's a sweet-spot "head dimension" that balances expressiveness with computational efficiency.
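The head dimension in the table is simple division, but the division must be exact or the reshape into heads fails. A quick sketch that recomputes d_k for the confirmed rows (the dictionary below is just the table transcribed; note that LLaMA-2 70B uses grouped-query attention, so "num_heads" here counts query heads only):

```python
# (d_model, num_heads) for the confirmed rows of the table
configs = {
    "BERT-base": (768, 12),
    "GPT-2 (small)": (768, 12),
    "GPT-3 (175B)": (12288, 96),
    "LLaMA-2 70B": (8192, 64),
}

head_dims = {}
for name, (d_model, num_heads) in configs.items():
    # Heads must split d_model evenly, otherwise the reshape into heads fails
    assert d_model % num_heads == 0, f"{name}: d_model not divisible by num_heads"
    head_dims[name] = d_model // num_heads
    print(f"{name:14s} d_k = {head_dims[name]}")
```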

How Multi-Head Attention Works - Step by Step

Multi-head attention is elegantly simple: run multiple attention mechanisms in parallel, each with its own learned projections. Let's break down the process:

The 5-Step Process

  1. Linear Projection: Project input to Q, K, V for ALL heads simultaneously (one big matrix multiplication)
  2. Split into Heads: Reshape from [batch, seq, d_model] to [batch, num_heads, seq, d_k]
  3. Parallel Attention: Compute attention for all heads at once (highly parallelized on GPU)
  4. Concatenate: Merge all head outputs back to [batch, seq, d_model]
  5. Output Projection: Final linear transformation to mix information from all heads

Conceptual Example: 2 Heads Processing "The cat sat"

import torch
import torch.nn.functional as F

# Setup
tokens = ["The", "cat", "sat"]
seq_len = 3
d_model = 8  # Small for illustration
num_heads = 2
d_k = d_model // num_heads  # 8 / 2 = 4 per head

# Input embeddings (normally from embedding layer)
X = torch.tensor([
    [0.1, 0.2, -0.1, 0.3, 0.4, -0.2, 0.1, 0.0],  # "The"
    [0.3, -0.1, 0.5, 0.2, 0.1, 0.4, -0.3, 0.2],  # "cat"
    [-0.2, 0.4, 0.1, -0.3, 0.5, 0.1, 0.2, -0.1]   # "sat"
]).unsqueeze(0)  # [1, 3, 8]

print("Input shape:", X.shape)
print("We have 2 heads, each will get 4 dimensions\n")

# Step 1: Project to Q, K, V (for ALL heads at once)
W_Q = torch.nn.Linear(d_model, d_model, bias=False)
W_K = torch.nn.Linear(d_model, d_model, bias=False)
W_V = torch.nn.Linear(d_model, d_model, bias=False)

Q = W_Q(X)  # [1, 3, 8]
K = W_K(X)  # [1, 3, 8]
V = W_V(X)  # [1, 3, 8]

print("After projection:")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}\n")

# Step 2: Split into heads
# Reshape: [batch, seq, d_model] → [batch, seq, num_heads, d_k]
# Then transpose: [batch, num_heads, seq, d_k]
batch_size = X.shape[0]

Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

print("After splitting into heads:")
print(f"Q shape: {Q.shape}  # [batch, num_heads, seq, d_k]")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}\n")

print("Head 0 gets Q[:, 0, :, :] - dimensions 0-3")
print("Head 1 gets Q[:, 1, :, :] - dimensions 4-7\n")

# Step 3: Compute attention for BOTH heads in parallel
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
# [1, 2, 3, 3] - 2 attention matrices (one per head)

attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
# [1, 2, 3, 4] - 2 heads, each produces 3 tokens of 4 dimensions

print("After attention:")
print(f"Output shape: {output.shape}  # [batch, num_heads, seq, d_k]")
print(f"\nHead 0 output: {output[0, 0, :, :].shape}  # [3, 4]")
print(f"Head 1 output: {output[0, 1, :, :].shape}  # [3, 4]\n")

# Step 4: Concatenate heads
# Transpose back: [batch, num_heads, seq, d_k] → [batch, seq, num_heads, d_k]
output = output.transpose(1, 2).contiguous()
# Reshape: [batch, seq, num_heads, d_k] → [batch, seq, d_model]
output = output.view(batch_size, seq_len, d_model)

print("After concatenating heads:")
print(f"Output shape: {output.shape}  # [1, 3, 8]")
print("Back to original dimensions!\n")

# Step 5: Output projection
W_O = torch.nn.Linear(d_model, d_model, bias=False)
final_output = W_O(output)

print(f"Final output shape: {final_output.shape}  # [1, 3, 8]")
print("\nThis output has information from BOTH heads combined!")

Visual Representation of Multi-Head Splitting

Original Q matrix: [batch, seq_len, d_model=512]
                    
     Dimension:  0   64  128 192 256 320 384 448  512
                 ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓    ↓
     Token 0:  [================================================]
     Token 1:  [================================================]
     Token 2:  [================================================]
     Token 3:  [================================================]

After reshaping to [batch, num_heads=8, seq_len, d_k=64]:

     Head 0 (dims 0-63):
     Token 0:  [=======]
     Token 1:  [=======]
     Token 2:  [=======]
     Token 3:  [=======]

     Head 1 (dims 64-127):
     Token 0:          [=======]
     Token 1:          [=======]
     Token 2:          [=======]
     Token 3:          [=======]

     ... (Heads 2-6) ...

     Head 7 (dims 448-511):
     Token 0:                                  [=======]
     Token 1:                                  [=======]
     Token 2:                                  [=======]
     Token 3:                                  [=======]

Each head independently computes attention on its slice!
Then we concatenate everything back together.
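The claim that each head works on a contiguous slice can be checked directly. This small sketch (random Q, sizes matching the diagram) confirms that the `view` + `transpose` split gives head h exactly dimensions h·64 to h·64 + 63:

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 1, 4, 512, 8
d_k = d_model // num_heads  # 64 dimensions per head

Q = torch.randn(batch, seq_len, d_model)

# The split used by multi-head attention: view, then move heads to dim 1
Q_heads = Q.view(batch, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 4, 64]

# Head h sees exactly dimensions h*64 to h*64 + 63 of every token
for h in range(num_heads):
    assert torch.equal(Q_heads[:, h], Q[..., h * d_k:(h + 1) * d_k])
print("Each head is a contiguous 64-dimension slice of Q")
```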

Why This Design is Brilliant

✅ Parallelizable

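All heads are computed together as batched matrix multiplications on the GPU. Because each head works on only d_k = d_model / num_heads dimensions, the total compute is similar to single-head attention at full width.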

✅ Specialized Learning

Each head tends to learn different patterns during training, without any explicit instruction to do so.

✅ Parameter Efficient

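The Q, K, V projections have exactly as many parameters as single-head attention with d_k = d_model. Splitting them into heads adds no parameters and keeps the computation cost about the same.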

✅ Flexible Representation

Model can blend different perspectives in output projection.
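The "same cost" claim is easy to check with arithmetic. This sketch counts only the multiply-adds for the score matrix Q Kᵀ (projections and softmax are ignored, an assumption made for simplicity). The work is identical, but note that multi-head attention does store one n × n attention map per head:

```python
def score_macs(seq_len, d_model, num_heads):
    """Multiply-adds to form Q K^T across all heads (projections and softmax ignored)."""
    d_k = d_model // num_heads
    # num_heads separate products of [seq_len x d_k] @ [d_k x seq_len]
    return num_heads * seq_len * seq_len * d_k

def attention_weights_stored(seq_len, num_heads):
    """Entries in the softmax output: one seq_len x seq_len map per head."""
    return num_heads * seq_len * seq_len

n, d_model = 1024, 512
print(score_macs(n, d_model, 1), score_macs(n, d_model, 8))        # 536870912 536870912
print(attention_weights_stored(n, 1), attention_weights_stored(n, 8))  # 1048576 8388608
```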

Single-Head vs Multi-Head: Direct Comparison

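import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model = 2, 4, 512
X = torch.randn(batch, seq_len, d_model)

# Single-head attention
d_k = d_model  # Full dimension

W_Q = torch.nn.Linear(d_model, d_model, bias=False)
W_K = torch.nn.Linear(d_model, d_model, bias=False)
W_V = torch.nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_Q(X), W_K(X), W_V(X)                   # [batch, seq, 512]
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5      # [batch, seq, seq]
attn = F.softmax(scores, dim=-1)
output = attn @ V                                  # [batch, seq, 512]

# Parameters: 3 × (512 × 512) = 786,432
single_params = sum(p.numel() for W in (W_Q, W_K, W_V) for p in W.parameters())
print(f"Single-head params: {single_params:,}")

#----------------------------------------------------------------

# Multi-head attention (8 heads)
num_heads = 8
d_k = d_model // num_heads  # 64 per head

# Same projection sizes: reuse Q, K, V, then reshape to [batch, 8, seq, 64]
Q_h, K_h, V_h = (t.view(batch, seq_len, num_heads, d_k).transpose(1, 2) for t in (Q, K, V))

# Compute 8 attention patterns in parallel
scores = Q_h @ K_h.transpose(-2, -1) / d_k ** 0.5  # [batch, 8, seq, seq]
attn = F.softmax(scores, dim=-1)
heads = attn @ V_h                                 # [batch, 8, seq, 64]
output = heads.transpose(1, 2).reshape(batch, seq_len, d_model)  # concat: [batch, seq, 512]

W_O = torch.nn.Linear(d_model, d_model, bias=False)
output = W_O(output)  # Final projection mixes the heads

# Parameters: 3 × (512 × 512) + (512 × 512) = 1,048,576
multi_params = single_params + sum(p.numel() for p in W_O.parameters())
print(f"Multi-head params:  {multi_params:,}")
# The extra 33% comes entirely from W_O; in exchange, each token gets
# 8 separate attention patterns instead of one.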

Why Multiple Heads? Head Specialization in Practice

Research has shown that different attention heads spontaneously specialize in different linguistic phenomena during training. This isn't programmed - it emerges naturally through gradient descent!

Documented Head Specialization Patterns

🔤 Syntactic Heads

Focus: Grammatical dependencies

  • Determiners → Nouns: "the" → "cat"
  • Adjectives → Nouns: "big" → "dog"
  • Subject → Verb: "Alice" → "walked"
  • Verb → Object: "ate" → "apple"

💭 Semantic Heads

Focus: Meaning relationships

  • Coreference: "she" → "Alice"
  • Related concepts: "doctor" → "hospital"
  • Semantic roles: agent, patient, instrument
  • Thematic relations: cause-effect

📍 Positional Heads

Focus: Relative position

  • Previous token (n-1)
  • Next token (n+1)
  • Window of ±2-3 tokens
  • Sentence start/end markers

🌐 Global Context Heads

Focus: Broad information

  • Uniform attention to all tokens
  • Topic-level information
  • Document-wide context
  • Aggregation functions

⚠️ Rare Word Heads

Focus: Important tokens

  • Named entities: "Einstein", "Paris"
  • Numbers and dates
  • Technical terms
  • Unique identifiers

📊 Delimiter Heads

Focus: Structure markers

  • Punctuation: periods, commas
  • Special tokens: [CLS], [SEP]
  • Sentence boundaries
  • Paragraph structure

Real Example: Visualizing Head Specialization in BERT

from transformers import BertTokenizer, BertModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Load BERT and analyze attention patterns
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

# Example sentence with clear syntactic structure
text = "The elderly professor taught the enthusiastic students about quantum physics."
inputs = tokenizer(text, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

print(f"Tokens: {tokens}")
print(f"Analyzing {model.config.num_hidden_layers} layers, {model.config.num_attention_heads} heads per layer\n")

# Get attention weights
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # Tuple of 12 layers

# Analyze each head in layer 6 (middle layer often shows clear patterns)
layer_idx = 6
layer_attentions = attentions[layer_idx][0]  # [num_heads, seq_len, seq_len]

print(f"Layer {layer_idx} Head Specialization:\n")

# Heuristic classifier for a head's attention pattern.
# The thresholds are illustrative guesses, not values from a published method.
def analyze_head(attn_matrix, tokens):
    """Classify what pattern this head is focusing on."""
    
    # Check for syntactic pattern (attention mass on determiners)
    det_indices = [i for i, t in enumerate(tokens) if t in ['the', 'a', 'an']]
    if det_indices:
        # Average fraction of each token's attention that lands on determiners
        det_attention = attn_matrix[:, det_indices].sum(dim=-1).mean()
        if det_attention > 0.3:
            return "Syntactic (Determiner-Noun)"
    
    # Check for self-attention pattern (main diagonal)
    diagonal_strength = torch.diag(attn_matrix).mean()
    if diagonal_strength > 0.5:
        return "Positional (Self-attention)"
    
    # Check for next-token pattern (token i -> token i+1)
    next_token_attn = torch.diag(attn_matrix, diagonal=1).mean()
    if next_token_attn > 0.4:
        return "Positional (Next-token)"
    
    # Check for previous-token pattern (token i -> token i-1)
    prev_token_attn = torch.diag(attn_matrix, diagonal=-1).mean()
    if prev_token_attn > 0.4:
        return "Positional (Previous-token)"
    
    # Check for uniform (broadcast) pattern
    attn_std = attn_matrix.std()
    if attn_std < 0.15:
        return "Global Context (Uniform)"
    
    # Check for delimiter pattern (attention mass on punctuation/special tokens)
    delim_indices = [i for i, t in enumerate(tokens) if t in ['.', ',', '[SEP]', '[CLS]']]
    if delim_indices:
        delim_attention = attn_matrix[:, delim_indices].sum(dim=-1).mean()
        if delim_attention > 0.4:
            return "Delimiter-focused"
    
    return "Semantic/Other"

# Analyze each head
for head_idx in range(model.config.num_attention_heads):
    attn = layer_attentions[head_idx].cpu()
    head_type = analyze_head(attn, tokens)
    
    # Find strongest attention connections
    max_val = attn.max().item()
    max_pos = (attn == max_val).nonzero(as_tuple=True)
    from_token = tokens[max_pos[0][0].item()]
    to_token = tokens[max_pos[1][0].item()]
    
    print(f"Head {head_idx:2d}: {head_type:<30s} | Strongest: {from_token} → {to_token} ({max_val:.3f})")

# Visualize different head types side by side
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
head_examples = [0, 2, 5, 7, 9, 11]  # Select diverse heads

for idx, head_idx in enumerate(head_examples):
    ax = axes[idx // 3, idx % 3]
    attn = layer_attentions[head_idx].cpu().numpy()
    
    sns.heatmap(attn, annot=False, cmap='YlOrRd', 
                xticklabels=tokens, yticklabels=tokens,
                cbar_kws={'label': 'Attention'}, ax=ax)
    
    head_type = analyze_head(layer_attentions[head_idx].cpu(), tokens)
    ax.set_title(f'Head {head_idx}: {head_type}', fontsize=10)
    ax.set_xlabel('Attending to', fontsize=8)
    ax.set_ylabel('Attending from', fontsize=8)
    ax.tick_params(labelsize=6)

plt.tight_layout()
plt.show()
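Threshold rules like the ones above are brittle. A complementary, threshold-free measure (a common analysis choice, offered here as a sketch rather than the method used above) is the average entropy of each head's attention rows: focused heads score near 0, broad "global context" heads score near log(seq_len). The toy maps below are hand-built stand-ins; applied to the BERT example, `head_entropy(layer_attentions)` would give one score per head.

```python
import math
import torch

def head_entropy(attn):
    """attn: [num_heads, seq_len, seq_len] with rows summing to 1.
    Returns the mean attention entropy (in nats) for each head."""
    p = attn.clamp_min(1e-12)  # avoid log(0); zero-probability terms still contribute 0
    return -(attn * p.log()).sum(dim=-1).mean(dim=-1)

seq_len = 6
uniform = torch.full((seq_len, seq_len), 1.0 / seq_len)  # "global context" head
prev_token = torch.zeros(seq_len, seq_len)               # "previous token" head
prev_token[0, 0] = 1.0                                   # first token has no predecessor
prev_token[torch.arange(1, seq_len), torch.arange(0, seq_len - 1)] = 1.0

attn = torch.stack([uniform, prev_token])  # [2 heads, seq, seq]
ent = head_entropy(attn)
print(ent)                # uniform head ≈ log(6) ≈ 1.79, previous-token head = 0
print(math.log(seq_len))  # upper bound for any head
```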

Research Finding: Head Pruning Studies

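📚 Key Research: "Are Sixteen Heads Really Better than One?" (Michel et al., 2019)

  • Finding 1: A large percentage of heads can be removed at test time without significantly hurting performance; some layers can even be reduced to a single head.
  • Finding 2: Heads are not equally important. A related study (Voita et al., 2019, "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned") found that a few specialized heads do most of the work.
  • Finding 3: Sensitivity to pruning depends on the attention type: in their translation model, encoder-decoder attention relied on multiple heads much more than self-attention did.
  • Finding 4: Probing studies more broadly report a rough trend of positional/syntactic heads in earlier layers and more semantic heads in deeper layers.

💡 Implication: Trained transformers carry substantial redundancy, which compression methods exploit. DistilBERT, for example, uses knowledge distillation to halve the number of layers and retains about 97% of BERT's language-understanding performance with 40% fewer parameters.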
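Head-pruning studies like Michel et al. work by multiplying each head's output by a 0/1 mask before the heads are concatenated. A minimal sketch of that idea, assuming a bare-bones multi-head attention with random weights (the function name and `head_mask` argument are hypothetical, written for illustration). With a trained model you would compare task accuracy instead of raw output change:

```python
import torch
import torch.nn.functional as F

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, head_mask=None):
    """Minimal MHA where head_mask[h] = 0 silences head h (1 keeps it)."""
    batch, seq_len, d_model = X.shape
    d_k = d_model // num_heads

    def split(t):  # [batch, seq, d_model] -> [batch, heads, seq, d_k]
        return t.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

    Q, K, V = split(W_Q(X)), split(W_K(X)), split(W_V(X))
    attn = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    heads = attn @ V                                         # [batch, heads, seq, d_k]
    if head_mask is not None:
        heads = heads * head_mask.view(1, num_heads, 1, 1)   # zero out pruned heads
    concat = heads.transpose(1, 2).reshape(batch, seq_len, d_model)
    return W_O(concat)

torch.manual_seed(0)
d_model, num_heads = 64, 8
W_Q, W_K, W_V, W_O = (torch.nn.Linear(d_model, d_model, bias=False) for _ in range(4))
X = torch.randn(2, 5, d_model)

with torch.no_grad():
    full = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads)
    keep_half = torch.tensor([1., 1., 1., 1., 0., 0., 0., 0.])  # prune heads 4-7
    pruned = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads, keep_half)

# Relative change in the layer output after removing 4 of 8 heads
print((full - pruned).norm() / full.norm())
```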

Head Importance by Layer Depth

| Layer range | Typical head focus | Example patterns |
|---|---|---|
| Layers 1-3 (early) | Surface-level, positional | Previous token, next token, nearby words, punctuation |
| Layers 4-7 (middle) | Syntactic structure | Subject-verb, verb-object, determiner-noun, modifiers |
| Layers 8-12 (deep) | Semantic relationships | Coreference, semantic roles, discourse structure, topic |

This is a rough trend for 12-layer models such as BERT-base; individual heads often break it.

Key Benefit of Multi-Head Attention: A single attention head produces one weighted average per token, so it would have to blend syntax, semantics, and position into one pattern. Multiple heads can capture these in parallel, which helps transformers represent language at several levels of abstraction.

Multi-Head Implementation


import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Single set of projection matrices
        # These automatically split across heads
        self.W_Q = torch.nn.Linear(d_model, d_model)
        self.W_K = torch.nn.Linear(d_model, d_model)
        self.W_V = torch.nn.Linear(d_model, d_model)
        
        # Output projection to combine all heads
        self.W_O = torch.nn.Linear(d_model, d_model)
    
    def forward(self, X, mask=None):
        batch_size, seq_len, d_model = X.shape
        
        # Project to Q, K, V
        Q = self.W_Q(X)  # [batch, seq_len, d_model]
        K = self.W_K(X)
        V = self.W_V(X)
        
        # Split into multiple heads
        # Reshape to [batch, seq_len, num_heads, d_k]
        # Then transpose to [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Now: [batch, num_heads, seq_len, d_k]
        
        # Compute attention for all heads in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        # [batch, num_heads, seq_len, seq_len]
        
        # Apply mask if provided (e.g., causal mask)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply to values
        output = torch.matmul(attention_weights, V)
        # [batch, num_heads, seq_len, d_k]
        
        # Reshape and concatenate heads
        output = output.transpose(1, 2).contiguous()
        # [batch, seq_len, num_heads, d_k]
        output = output.view(batch_size, seq_len, d_model)
        # [batch, seq_len, d_model]
        
        # Final output projection
        output = self.W_O(output)
        
        return output, attention_weights

# Usage
d_model = 512
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)

X = torch.randn(2, 4, 512)
output, weights = mha(X)
print(output.shape)        # [2, 4, 512]
print(weights.shape)       # [2, 8, 4, 4] - one attention map per head!
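PyTorch also ships a built-in layer, `torch.nn.MultiheadAttention`. This sketch (assuming PyTorch 1.11 or later for `average_attn_weights`) applies a causal mask and requests per-head weights. Watch the mask convention: for a boolean `attn_mask`, True marks positions that may NOT be attended to, which is the opposite of the `mask == 0` convention in the class above.

```python
import torch

torch.manual_seed(0)
d_model, num_heads, seq_len = 512, 8, 4

# batch_first=True means inputs are [batch, seq, d_model]
mha = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
X = torch.randn(2, seq_len, d_model)

# Boolean causal mask: True above the diagonal = "future token, do not attend"
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    # Self-attention: query, key, and value are all X
    output, weights = mha(X, X, X, attn_mask=causal_mask,
                          need_weights=True, average_attn_weights=False)

print(output.shape)   # torch.Size([2, 4, 512])
print(weights.shape)  # torch.Size([2, 8, 4, 4]): one map per head
# Token 0 can only see itself, so all of its attention goes to position 0
print(weights[0, :, 0, 0])
```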

Positional Encoding: The Missing Ingredient

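Here's a fundamental problem with attention: it has no built-in notion of order. Self-attention is permutation-equivariant: shuffle the input tokens and the outputs are shuffled the same way, so each token's output is identical no matter where it sits. The mechanism compares token contents, but nothing in it knows where those tokens are in the sequence!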

Demonstrating the Position Problem

import torch
import torch.nn.functional as F

# Simple attention function (no position info)
def basic_attention(X):
    """X: [batch, seq_len, d_model]"""
    Q = K = V = X  # Simplified: no learned projections
    scores = Q @ K.transpose(-2, -1) / X.shape[-1] ** 0.5
    attn = F.softmax(scores, dim=-1)
    output = attn @ V
    return output

# Two sentences with the same tokens in a different order
tokens1 = ["The", "cat", "sat", "on", "the", "mat"]
tokens2 = ["mat", "the", "on", "sat", "cat", "The"]  # Reversed!

# Simulate embeddings: one random vector per token
embeddings = torch.randn(6, 512)

# Sentence 1: embeddings in order [0,1,2,3,4,5]
X1 = embeddings[[0, 1, 2, 3, 4, 5]].unsqueeze(0)

# Sentence 2: same embeddings, reversed order [5,4,3,2,1,0]
perm = [5, 4, 3, 2, 1, 0]
X2 = embeddings[perm].unsqueeze(0)

print("X1 (original order):", X1.shape)
print("X2 (reversed order):", X2.shape)

# Apply attention
out1 = basic_attention(X1)
out2 = basic_attention(X2)

# Permuting the inputs just permutes the outputs the same way
print("\nAttention is permutation-equivariant:")
print("out2 == out1 reordered:", torch.allclose(out2, out1[:, perm], atol=1e-5))
# "cat" gets exactly the same output vector in both sentences,
# so attention alone CAN'T distinguish word order!

❌ The Problem in Action

Consider these pairs, which use exactly the same words but mean different things:

1. "The dog bit the man"    ← Dog is agent
2. "The man bit the dog"    ← Man is agent

3. "Only she loves him"     ← Nobody else loves him
4. "She loves only him"     ← She loves nobody else

5. "Alice gave Bob a book"  ← Alice is the giver
6. "Bob gave Alice a book"  ← Bob is the giver

Without position information, attention can't tell the sentences in each pair apart: they contain the same tokens, just in different positions.

Why RNNs Didn't Have This Problem

RNNs naturally encode position through sequential processing:

import torch

# RNN processes tokens one at a time, left to right
rnn_cell = torch.nn.RNNCell(input_size=8, hidden_size=16)
tokens = torch.randn(3, 1, 8)   # token_0, token_1, token_2 (batch of 1)

h = torch.zeros(1, 16)          # h_0: initial state
for t in range(3):
    # h_{t+1} = RNN(h_t, token_t): each state depends on everything before it
    h = rnn_cell(tokens[t], h)

# Position is implicit in the processing order

But transformers process all tokens in parallel! This is what makes them fast, but it means we must explicitly add position information.

The Solution: Positional Encoding

Inject position information explicitly. Some methods add it to the token embeddings before the first layer; others build it into the attention computation itself. The main options:

1. Sinusoidal (Original)

Fixed sine/cosine functions at different frequencies

Used in: Original Transformer (2017)

2. Learned Embeddings

Treat position as another embedding table

Used in: BERT, GPT-2

3. Rotary (RoPE)

Rotate Q and K by position angle

Used in: LLaMA, GPT-NeoX, PaLM

4. ALiBi

Add linear bias to attention scores

Used in: BLOOM, MPT

Core Principle: Sinusoidal and learned encodings are added (not concatenated) to the token embeddings, so the model can use both content and position information from the very first attention layer. RoPE and ALiBi instead inject position inside each attention layer.
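A small sketch of why adding position fixes the problem, assuming the same simplified attention as before (no learned projections) and the sinusoidal encoding described in the next section. Without position, reordering the tokens only reorders the outputs; once position vectors are added, the same token produces a different output depending on where it sits:

```python
import torch
import torch.nn.functional as F

def sinusoidal_pe(seq_len, d_model, dtype=torch.float32):
    """Sinusoidal encodings: sin on even dimensions, cos on odd dimensions."""
    pos = torch.arange(seq_len, dtype=dtype).unsqueeze(1)
    freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2, dtype=dtype) / d_model))
    pe = torch.zeros(seq_len, d_model, dtype=dtype)
    pe[:, 0::2] = torch.sin(pos * freq)
    pe[:, 1::2] = torch.cos(pos * freq)
    return pe

def basic_attention(X):
    """Simplified self-attention with Q = K = V = X."""
    scores = X @ X.transpose(-2, -1) / X.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ X

torch.manual_seed(0)
d_model = 64
embeddings = torch.randn(6, d_model)  # one vector per token
perm = [5, 4, 3, 2, 1, 0]             # reversed word order
pe = sinusoidal_pe(6, d_model)        # position vectors for slots 0..5

# Without position: reordering the input only reorders the output
out1 = basic_attention(embeddings)
out2 = basic_attention(embeddings[perm])
print(torch.allclose(out2, out1[perm], atol=1e-5))  # True

# With position added: the same token now gets a different output in each order
out1_pe = basic_attention(embeddings + pe)
out2_pe = basic_attention(embeddings[perm] + pe)
print(torch.allclose(out2_pe, out1_pe[perm], atol=1e-5))  # False
```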

Sinusoidal Positional Encoding - Deep Dive

The original "Attention is All You Need" paper (2017) introduced sinusoidal positional encodings - a mathematically elegant solution that uses sine and cosine functions at different frequencies.

The Mathematical Formula

Positional Encoding Formula:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Where:
• pos = position in sequence (0, 1, 2, ..., seq_len-1)
• i = dimension index (0, 1, 2, ..., d_model/2-1)
• 2i = even dimensions get sine
• 2i+1 = odd dimensions get cosine
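To make the formula concrete, here is a direct transcription for a tiny d_model = 4 (a sketch using plain `math`; `pe_value` is a hypothetical helper name). Position 0 always encodes as [0, 1, 0, 1], and position 1 becomes [sin 1, cos 1, sin 0.01, cos 0.01] ≈ [0.841, 0.540, 0.010, 1.000]:

```python
import math

def pe_value(pos, dim, d_model):
    """PE(pos, dim) from the formula: even dim = 2i uses sin, odd dim = 2i+1 uses cos."""
    i = dim // 2                                 # both dims of a pair share the same i
    angle = pos / 10000 ** (2 * i / d_model)
    return math.sin(angle) if dim % 2 == 0 else math.cos(angle)

d_model = 4
for pos in range(3):
    row = [pe_value(pos, dim, d_model) for dim in range(d_model)]
    print(pos, [f"{v:.4f}" for v in row])
```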

Understanding the Design Choices

Why 10000? Why sine/cosine? Let's break it down:

1. Different Frequencies for Different Dimensions

import math
import torch

# Frequency for each dimension pair decreases exponentially
d_model = 512
dims = torch.arange(0, d_model, 2, dtype=torch.float64)  # even dimension indices 2i

# Wavelength (in positions) of the sin/cos pair starting at dimension 2i:
# 2π · 10000^(2i/d_model)
wavelengths = 2 * math.pi * 10000 ** (dims / d_model)

print("Dimension | Wavelength (positions)")
print("-" * 36)
for dim in range(0, d_model, 64):  # Sample every 64 dimensions
    print(f"  {dim:3d}     | {wavelengths[dim // 2].item():10.2f}")

# Output:
#   0     |       6.28  (fastest: a full cycle every ~6 positions)
#  64     |      19.87
# 128     |      62.83
# 192     |     198.69
# 256     |     628.32
# 320     |    1986.92
# 384     |    6283.19
# 448     |   19869.18  (slowest sampled: a cycle every ~20,000 positions)

Why this matters: Different dimensions capture position at different scales:

  • Low dimensions (high freq): Distinguish nearby positions (1 vs 2)
  • Mid dimensions (med freq): Distinguish local chunks (position 10 vs 50)
  • High dimensions (low freq): Distinguish distant regions (beginning vs end)

2. Why Sine AND Cosine?

Using both sine and cosine for each frequency provides two key benefits:

import matplotlib.pyplot as plt
import numpy as np

pos = np.arange(0, 100)
freq = 1/10  # One cycle per 10 positions

# Sine and cosine are 90° out of phase
sine_wave = np.sin(2 * np.pi * freq * pos)
cosine_wave = np.cos(2 * np.pi * freq * pos)

plt.figure(figsize=(12, 4))
plt.plot(pos, sine_wave, label='sin', linewidth=2)
plt.plot(pos, cosine_wave, label='cos', linewidth=2)
plt.xlabel('Position')
plt.ylabel('Value')
plt.title('Sine and Cosine: Orthogonal Information')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

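# Benefit 1: The (sin, cos) pair pins down the phase uniquely. Sine alone is
#            ambiguous, since sin(θ) = sin(π - θ)
# Benefit 2: Shifting by k positions is a fixed rotation of the (sin, cos) pair,
#            so any phase shift is a linear combination of sin and cos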
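The phase-shift property can be made precise with the angle-addition identities. For the frequency ω_i used by dimension pair i, shifting the position by k multiplies the (sin, cos) pair by a 2×2 rotation matrix that depends only on k, never on pos:

```latex
\begin{pmatrix} \sin(\omega_i(\mathrm{pos}+k)) \\ \cos(\omega_i(\mathrm{pos}+k)) \end{pmatrix}
=
\begin{pmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{pmatrix}
\begin{pmatrix} \sin(\omega_i \,\mathrm{pos}) \\ \cos(\omega_i \,\mathrm{pos}) \end{pmatrix},
\qquad \omega_i = \frac{1}{10000^{2i/d_{\text{model}}}}
```

The first row is sin(a + b) = sin a cos b + cos a sin b and the second is cos(a + b) = cos a cos b − sin a sin b, with a = ω_i·pos and b = ω_i·k.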

3. Why 10000 Specifically?

The value 10000 is a design choice from the original paper, which does not justify it in detail. Its effects:

  • Wavelengths form a geometric progression from 2π (fastest) to nearly 10000×2π ≈ 62,832 positions (slowest)
  • The slowest dimensions don't complete even one cycle over typical sequence lengths (512-2048), so encodings don't repeat within a sequence
  • Balance between too much variation (hard to learn) and too little (ambiguous)

Complete Implementation with Detailed Comments

import torch
import numpy as np
import matplotlib.pyplot as plt

def positional_encoding(d_model, max_seq_len=512):
    """
    Create sinusoidal positional encodings.
    
    Args:
        d_model: Dimension of model (must be even)
        max_seq_len: Maximum sequence length to precompute
        
    Returns:
        Positional encoding matrix [max_seq_len, d_model]
    """
    # Create position indices: [0, 1, 2, ..., max_seq_len-1]
    pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    # Shape: [max_seq_len, 1]
    # Example for max_seq_len=5: [[0], [1], [2], [3], [4]]
    
    # Create dimension indices: [0, 2, 4, ..., d_model-2]
    # We only need half because each pair (2i, 2i+1) shares the same frequency
    i = torch.arange(0, d_model, 2, dtype=torch.float).unsqueeze(0)
    # Shape: [1, d_model/2]
    # Example for d_model=8: [[0, 2, 4, 6]]
    
    # Calculate angular frequency for each dimension
    # freq = 1 / (10000^(2i/d_model))
    # This creates exponentially decreasing frequencies:
    # - Lower dimensions oscillate rapidly (high frequency)
    # - Higher dimensions oscillate slowly (low frequency)
    freq = 1.0 / (10000 ** (i / d_model))
    # Shape: [1, d_model/2]
    
    # Multiply position by frequency to get angles
    # Broadcasting: [max_seq_len, 1] × [1, d_model/2] → [max_seq_len, d_model/2]
    angle_rates = pos * freq
    # Shape: [max_seq_len, d_model/2]
    # Each row represents one position
    # Each column represents one frequency component
    
    # Initialize positional encoding matrix
    pe = torch.zeros(max_seq_len, d_model)
    
    # Apply sine to even indices (0, 2, 4, ...)
    pe[:, 0::2] = torch.sin(angle_rates)
    
    # Apply cosine to odd indices (1, 3, 5, ...)
    pe[:, 1::2] = torch.cos(angle_rates)
    
    # Final shape: [max_seq_len, d_model]
    # Each row is the positional encoding for one position
    # Each position has a unique pattern across all dimensions
    
    return pe

# Generate positional encodings
d_model = 512
max_seq_len = 100
pe = positional_encoding(d_model, max_seq_len)

print(f"Positional Encoding Shape: {pe.shape}")
print(f"Each position has {d_model} values")
print(f"Values range: [{pe.min():.3f}, {pe.max():.3f}]")

# Verify properties
print(f"\nPosition 0 encoding: {pe[0, :8]}")  # First 8 dimensions
print(f"Position 1 encoding: {pe[1, :8]}")
print(f"Position 2 encoding: {pe[2, :8]}")

# Show that each position has unique encoding
print(f"\nAll positions are unique: {len(torch.unique(pe, dim=0)) == max_seq_len}")

# Visualize the encoding pattern
plt.figure(figsize=(12, 8))
plt.imshow(pe.numpy().T, cmap='RdBu', aspect='auto', interpolation='nearest')
plt.xlabel('Position in Sequence', fontsize=12)
plt.ylabel('Embedding Dimension', fontsize=12)
plt.title('Sinusoidal Positional Encoding Heatmap', fontsize=14)
plt.colorbar(label='Encoding Value')
plt.tight_layout()
plt.show()

Visualizing Positional Encoding Patterns

# Visualize how different dimensions encode position differently
fig, axes = plt.subplots(3, 1, figsize=(14, 10))

positions = np.arange(0, 100)
pe_numpy = pe[:100].numpy()

# Plot high-frequency dimensions (low index)
axes[0].plot(positions, pe_numpy[:, 0], label='Dim 0 (fastest)', linewidth=2)
axes[0].plot(positions, pe_numpy[:, 1], label='Dim 1', linewidth=2)
axes[0].set_title('Low Dimensions: High Frequency (changes every position)', fontsize=12)
axes[0].set_ylabel('Encoding Value')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot mid-frequency dimensions
axes[1].plot(positions, pe_numpy[:, 64], label='Dim 64', linewidth=2)
axes[1].plot(positions, pe_numpy[:, 65], label='Dim 65', linewidth=2)
axes[1].set_title('Middle Dimensions: Medium Frequency (captures local structure)', fontsize=12)
axes[1].set_ylabel('Encoding Value')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot low-frequency dimensions (high index)
axes[2].plot(positions, pe_numpy[:, 256], label='Dim 256', linewidth=2)
axes[2].plot(positions, pe_numpy[:, 257], label='Dim 257', linewidth=2)
axes[2].set_title('High Dimensions: Low Frequency (captures global position)', fontsize=12)
axes[2].set_xlabel('Position')
axes[2].set_ylabel('Encoding Value')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Notice how:")
print("- Low dims oscillate rapidly (distinguish 1 vs 2)")
print("- Mid dims oscillate moderately (distinguish 10 vs 50)")
print("- High dims oscillate slowly (distinguish early vs late in sequence)")

Key Properties of Sinusoidal Encoding

✅ Bounded

All values in [-1, 1]. Won't cause numerical instability or gradient explosion.

✅ Deterministic

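No parameters to learn. Same encoding every time. Can be computed for any length, which the original authors hoped would let models extrapolate to longer sequences (in practice, this extrapolation is limited).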

✅ Linear Relationships

For any fixed offset k, PE(pos+k) is a linear function of PE(pos) (a rotation of each sin/cos pair), which may make it easier for the model to attend by relative position.

✅ Multi-Scale

Different dimensions capture position at different granularities (fine to coarse).
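One consequence of the linear-relationship property is that the similarity between two encodings depends only on their offset: PE(p) · PE(p + k) = Σᵢ cos(ωᵢ k), whatever p is. This sketch checks that numerically (float64 to keep the comparison tight; `sinusoidal_pe` is a hypothetical helper that mirrors `positional_encoding()` above):

```python
import torch

def sinusoidal_pe(seq_len, d_model, dtype=torch.float64):
    """Same construction as positional_encoding() above: sin on even dims, cos on odd dims."""
    pos = torch.arange(seq_len, dtype=dtype).unsqueeze(1)
    freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2, dtype=dtype) / d_model))
    pe = torch.zeros(seq_len, d_model, dtype=dtype)
    pe[:, 0::2] = torch.sin(pos * freq)
    pe[:, 1::2] = torch.cos(pos * freq)
    return pe

d_model = 128
pe = sinusoidal_pe(seq_len=200, d_model=d_model)
sims = pe @ pe.T  # sims[p, q] = PE(p) · PE(q)

# Same offset k = 7 at different absolute positions gives the same similarity
for p in (0, 50, 150):
    print(p, round(sims[p, p + 7].item(), 6))

# Every encoding has the same length: PE(p) · PE(p) = d_model / 2
print(round(sims[10, 10].item(), 6))  # 64.0
```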

Usage in Practice

# Complete example: Adding positional encoding to token embeddings
batch_size = 2
seq_len = 10
d_model = 512
vocab_size = 10000

# Token embeddings (learned during training)
token_embedding = torch.nn.Embedding(vocab_size, d_model)

# Positional encodings (fixed)
pe = positional_encoding(d_model, max_seq_len=1000)

# Input token IDs
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
print(f"Token IDs shape: {token_ids.shape}")  # [2, 10]

# Get token embeddings
embeddings = token_embedding(token_ids)
print(f"Token embeddings shape: {embeddings.shape}")  # [2, 10, 512]

# Add positional encodings
# Select the positional encodings for the current sequence length
pos_encodings = pe[:seq_len, :].unsqueeze(0)  # [1, 10, 512]
print(f"Positional encodings shape: {pos_encodings.shape}")

# Add (broadcast across batch dimension)
embeddings_with_position = embeddings + pos_encodings
print(f"Final embeddings shape: {embeddings_with_position.shape}")  # [2, 10, 512]

print("\nNow each token embedding contains:")
print("- Token semantic information (learned)")
print("- Position information (fixed sinusoidal)")
print("- Ready for attention mechanism!")

Rotary Positional Embeddings (Modern)

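Many newer models use rotary positional embeddings (RoPE), which rotate each query and key vector by an angle proportional to its position. The dot product between a rotated query and key then depends only on their relative offset: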


import torch

# Rotary embeddings rotate Q and K by position angle
def apply_rotary_pos_emb(x, freqs):
    """
    Apply rotary position embedding
    x: [batch, num_heads, seq_len, head_dim]
    freqs: [seq_len, head_dim/2] rotation angles (position × frequency)
    """
    # Extract 2D consecutive pairs and apply 2D rotation
    x1, x2 = x[..., 0::2], x[..., 1::2]
    
    # Rotation matrix for each frequency
    cos_freqs = freqs.cos()
    sin_freqs = freqs.sin()
    
    # Apply rotation: (x1, x2) → (x1*cos - x2*sin, x1*sin + x2*cos)
    x_rot1 = x1 * cos_freqs - x2 * sin_freqs
    x_rot2 = x1 * sin_freqs + x2 * cos_freqs
    
    # Interleave back
    x_rot = torch.empty_like(x)
    x_rot[..., 0::2] = x_rot1
    x_rot[..., 1::2] = x_rot2
    
    return x_rot

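# Why it's popular:
# - Attention scores depend only on the relative offset between query and key
# - No extra parameters, and nothing is added to the token embeddings
# - Plain RoPE tends to degrade beyond its training length, but frequency-scaling
#   tricks (e.g. position interpolation) can extend the context
# - Used in LLaMA, Mistral, GPT-NeoX, PaLM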
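The function above expects precomputed angles but doesn't show where they come from. Here is a self-contained sketch, assuming the usual base of 10000 (the same frequency schedule as sinusoidal encoding) and the interleaved-pair layout used above; `rope_freqs` and `apply_rope` are hypothetical helper names. It checks the key property: placing the same query and key vectors at positions (m, n) gives a score that depends only on m − n.

```python
import torch

def rope_freqs(seq_len, head_dim, base=10000.0):
    """Angles m * theta_i with theta_i = base^(-2i/head_dim): [seq_len, head_dim/2]."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float64)
    return torch.outer(positions, inv_freq)

def apply_rope(x, angles):
    """x: [..., seq_len, head_dim]; rotate each (even, odd) pair by its angle."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

torch.manual_seed(0)
head_dim, seq_len = 8, 16
angles = rope_freqs(seq_len, head_dim)

q = torch.randn(head_dim, dtype=torch.float64)
k = torch.randn(head_dim, dtype=torch.float64)

# Put the SAME query/key vectors at every position, then rotate by position
Q = apply_rope(q.expand(seq_len, head_dim), angles)
K = apply_rope(k.expand(seq_len, head_dim), angles)
scores = Q @ K.T  # scores[m, n] = <R_m q, R_n k>

# The score depends only on the offset m - n
print(torch.allclose(scores[5, 2], scores[10, 7]))  # offset 3  -> True
print(torch.allclose(scores[2, 5], scores[9, 12]))  # offset -3 -> True
```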

Comparing Positional Encoding Methods

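| Method | Pros | Cons |
|---|---|---|
| Sinusoidal | Simple, no learned parameters, defined for any length | Fixed, doesn't adapt to data |
| Learned | Adaptive, data-driven | No embeddings for positions beyond the training length |
| RoPE (Rotary) | Encodes relative position directly in attention scores; no extra parameters | Slightly more complex; plain RoPE tends to degrade beyond the training length without scaling tricks |
| ALiBi (Attention Bias) | Very simple, no position embeddings; designed to extrapolate to longer inputs | Fixed linear distance penalty; less flexibility for complex patterns |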
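ALiBi is the only method in the table without any code above, so here is a minimal sketch for a causal model. It follows the ALiBi paper's slope schedule for a power-of-two number of heads (2^(-8/n), 2^(-16/n), ..., 2^(-8)); the helper names are hypothetical and the random scores are a stand-in for Q Kᵀ / √d_k:

```python
import torch

def alibi_slopes(num_heads):
    """Head-specific slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n = num_heads."""
    return torch.tensor([2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)])

def alibi_bias(num_heads, seq_len):
    """Causal ALiBi bias [num_heads, seq_len, seq_len], added to the raw attention scores."""
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]  # j - i: 0 on the diagonal, negative for past keys
    bias = alibi_slopes(num_heads)[:, None, None] * distance  # farther back -> more negative
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return bias.masked_fill(future, float("-inf"))  # causal: no attention to future tokens

num_heads, seq_len = 8, 5
bias = alibi_bias(num_heads, seq_len)
print(alibi_slopes(num_heads))  # 1/2, 1/4, ..., 1/256
print(bias[0])                  # head 0 (slope 1/2): 0, -0.5, -1.0, ... going back in time

# Usage: add the bias to the scores before softmax
scores = torch.randn(num_heads, seq_len, seq_len)
attn = torch.softmax(scores + bias, dim=-1)
```

Heads with large slopes focus sharply on nearby tokens, while heads with small slopes can still look far back.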

Why Add, Not Concatenate?

Why add positional encoding to embeddings instead of concatenating?

Option 1 (Used): x = token_embedding + position_encoding

  • Dimensions stay same (don't grow)
  • Position and content interact from the start
  • More parameter-efficient

Option 2 (Not used): x = [token_embedding; position_embedding]

  • Doubles dimensions
  • More parameters in first layer
  • Position and content kept separate longer

Addition keeps the model width fixed and works well in practice, which is why the original Transformer, BERT, and GPT-2 all use it.
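BERT and GPT-2 combine learned position embeddings with token embeddings the same way, by addition. A minimal sketch (the class name is hypothetical; sizes are arbitrary) that also shows the "can't extrapolate" limitation from the comparison table: positions past `max_seq_len` simply have no embedding row. On CPU, `nn.Embedding` raises an `IndexError` for an out-of-range index.

```python
import torch

class LearnedPositionalEmbedding(torch.nn.Module):
    """Token + learned position embeddings, combined by addition (hypothetical name)."""
    def __init__(self, vocab_size, max_seq_len, d_model):
        super().__init__()
        self.token = torch.nn.Embedding(vocab_size, d_model)
        self.position = torch.nn.Embedding(max_seq_len, d_model)  # one trainable row per position

    def forward(self, token_ids):  # token_ids: [batch, seq_len]
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        return self.token(token_ids) + self.position(positions)  # add, not concatenate

emb = LearnedPositionalEmbedding(vocab_size=1000, max_seq_len=16, d_model=32)
ids = torch.randint(0, 1000, (2, 10))
out = emb(ids)
print(out.shape)  # torch.Size([2, 10, 32]): width unchanged

too_long = torch.randint(0, 1000, (1, 20))  # 20 > max_seq_len
try:
    emb(too_long)
except IndexError as e:
    print("No embedding for positions >= 16:", type(e).__name__)
```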

Complete Forward Pass: Embeddings + Position


import torch
import torch.nn.functional as F

class TransformerEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, d_model, max_seq_len=512):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        
        pe = self.create_positional_encoding(d_model, max_seq_len)
        # Register as buffer (not a parameter): saved in state_dict and moved
        # by .to(device), but never updated by the optimizer
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_seq_len, d_model]
    
    def create_positional_encoding(self, d_model, max_seq_len):
        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        i = torch.arange(0, d_model, 2, dtype=torch.float).unsqueeze(0)
        freq = 1.0 / (10000 ** (i / d_model))
        angle_rates = pos * freq
        
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angle_rates)
        pe[:, 1::2] = torch.cos(angle_rates)
        return pe
    
    def forward(self, token_ids):
        """
        token_ids: [batch, seq_len]
        """
        seq_len = token_ids.size(1)
        
        # Token embeddings
        embeddings = self.token_embedding(token_ids)  # [batch, seq_len, d_model]
        
        # Scale embeddings by √d_model (as in the original Transformer)
        embeddings = embeddings * (self.d_model ** 0.5)
        
        # Add positional encoding
        embeddings = embeddings + self.pe[:, :seq_len, :]
        
        return embeddings

# Usage
vocab_size = 50000
d_model = 512
embedding_layer = TransformerEmbedding(vocab_size, d_model)

token_ids = torch.randint(0, vocab_size, (2, 4))
embeddings = embedding_layer(token_ids)
print(embeddings.shape)  # [2, 4, 512]
# Now each embedding includes both token meaning AND position!

Key Takeaways

  • Multi-Head Attention: Multiple parallel attention mechanisms, each learning different patterns
  • Head Specialization: Different heads focus on syntax, semantics, discourse, etc.
  • Concatenation: Head outputs are concatenated back to full dimension, then mixed by the output projection W_O
  • Positional Encoding: Add position info because attention by itself has no notion of token order (it is permutation-equivariant)
  • Sinusoidal PE: Original approach using sine/cosine at different frequencies
  • Rotary PE (RoPE): Modern approach that rotates Q and K so attention scores depend on relative position
  • Additive Combination: Add position to embeddings (not concatenate)
  • Scaling Factor: In the original Transformer, embeddings are multiplied by √d_model before adding position

Test Your Knowledge

Q1: What is the purpose of multi-head attention?

To make the model larger
To slow down training
To allow the model to attend to different representation subspaces simultaneously
To eliminate positional encoding

Q2: Why do we need positional encodings in Transformers?

To make the model more complex
Because self-attention has no inherent notion of token order
To reduce computational cost
To eliminate the need for embeddings

Q3: How many attention heads does a multi-head attention layer typically have?

Always 1
Always 2
Always 4
Typically 8 or 12, but varies by model

Q4: What are the two main types of positional encoding?

Sinusoidal (fixed) and learned (trainable)
Random and sequential
Static and dynamic
Local and global

Q5: In multi-head attention, what happens after each head computes its attention?

The heads are discarded
Only the first head is used
The outputs are concatenated and linearly transformed
The outputs are averaged