Master the mathematical foundations of attention - the core mechanism powering modern transformer models
You understand attention conceptually: "for each position, decide which other positions are important." Now let's dive deep into the mathematics, implementation details, and intuition behind how this mechanism actually works in production transformers.
Self-attention transforms a sequence of tokens from independent embeddings to context-aware representations in five steps:
Key Insight: Each step has a specific purpose. Understanding why each exists is crucial for debugging and optimizing transformer models.
INPUT: Token Embeddings (seq_len × d_model)
       "The cat sat on mat"
                 ↓
Step 1: Linear Projections
        Q = X @ W_Q   (what I'm looking for)
        K = X @ W_K   (what I contain)
        V = X @ W_V   (my information)
                 ↓
Step 2: Compute Similarity Matrix
        Scores = Q @ K^T
        Result: (seq_len × seq_len)

                The   cat   sat   on    mat
        The  [ 0.8   0.2   0.1   0.1   0.1 ]
        cat  [ 0.3   0.9   0.4   0.1   0.2 ]
        sat  [ 0.2   0.5   0.8   0.3   0.4 ]
        ...
                 ↓
Step 3: Scale by √d_k
        Scores = Scores / √64
        (prevents softmax saturation)
                 ↓
Step 4: Softmax → Probabilities
        Weights = softmax(Scores)
        Each row sums to 1.0
                 ↓
Step 5: Apply to Values
        Output = Weights @ V
        Weighted combination of all tokens
                 ↓
OUTPUT: Context-Aware Representations
        Each token now "knows" about the others!
Self-attention starts by creating three different "views" of the input. This might seem redundant (why not just use X?), but each projection serves a distinct purpose.
Intuition: The Restaurant Analogy
Imagine you're at a restaurant:
You compare your query against menu keys to decide which items match your preferences. Then you get the actual values (food) corresponding to your selections. You don't eat the menu description - you eat the actual dish!
import torch
import torch.nn as nn
# Input: embeddings of all tokens
# Shape: [batch_size, seq_len, d_model]
# Example: [2, 4, 512] (2 sequences, 4 tokens, 512D embeddings)
X = torch.randn(2, 4, 512)
print(f"Input shape: {X.shape}")
print(f"Each token is represented as a {X.shape[-1]}-dimensional vector")
# Learned linear transformations (Weight matrices)
# These are learned during training via backpropagation
d_model = 512
d_k = 64 # Typical: d_model / num_heads
d_v = 64
W_Q = nn.Linear(d_model, d_k, bias=False) # Query projection
W_K = nn.Linear(d_model, d_k, bias=False) # Key projection
W_V = nn.Linear(d_model, d_v, bias=False) # Value projection
# Project input to Q, K, V
Q = W_Q(X) # [2, 4, 64] - Queries (what each token is looking for)
K = W_K(X) # [2, 4, 64] - Keys (what each token represents)
V = W_V(X) # [2, 4, 64] - Values (information each token carries)
print(f"\nQuery shape: {Q.shape}") # [2, 4, 64]
print(f"Key shape: {K.shape}") # [2, 4, 64]
print(f"Value shape: {V.shape}") # [2, 4, 64]
# Key observation: Q and K share the same d_k, so Q @ K^T is well-defined
# and produces a square (seq_len x seq_len) similarity matrix
Query (Q)
Dimensions: (batch, seq_len, d_k)
Purpose: Represents "what am I looking for?"
For each token, Q encodes what information that token needs from other tokens. Think of it as a search query.
Example: Token "she"
Q_she = W_Q @ embedding_she
→ "Looking for: [female referent, subject, recent mention]"

Key (K)
Dimensions: (batch, seq_len, d_k)
Purpose: Represents "what do I contain?"
For each token, K encodes what information that token offers. Keys are matched against queries.
Example: Token "Alice"
K_Alice = W_K @ embedding_Alice
→ "I contain: [female, proper noun, agent, human]"

Value (V)
Dimensions: (batch, seq_len, d_v)
Purpose: Actual information to combine
V contains the actual semantic content that will be mixed together based on the attention weights.
Example: Token "Alice"
V_Alice = W_V @ embedding_Alice
→ [rich semantic representation of the "Alice" concept]
Option 1: No Projections (Bad ✗)
Q = K = V = X  # Just use the input directly
Problems:
- No flexibility
- Can't learn what to look for vs. what to offer
- One representation must serve three purposes
- Severely limits model expressiveness
Option 2: Learned Projections (Good ✓)
Q = W_Q @ X  # Learn what to search for
K = W_K @ X  # Learn what to advertise
V = W_V @ X  # Learn what to share
Benefits:
- W_Q, W_K, W_V are learned via backpropagation
- The model discovers optimal queries and keys
- Flexibility to represent different aspects of each token
- Dramatically improves model capacity
import torch
import torch.nn as nn
# Simplified example with small dimensions
d_model = 8 # Tiny for illustration
d_k = 4
# Simulate token embeddings (normally from embedding layer)
tokens = ["The", "cat", "sat"]
X = torch.tensor([
[0.1, 0.2, -0.1, 0.3, 0.4, -0.2, 0.1, 0.0], # "The"
[0.3, -0.1, 0.5, 0.2, 0.1, 0.4, -0.3, 0.2], # "cat"
[-0.2, 0.4, 0.1, -0.3, 0.5, 0.1, 0.2, -0.1] # "sat"
]).unsqueeze(0) # Add batch dimension: [1, 3, 8]
print("Input embeddings (X):")
print(X)
print(f"Shape: {X.shape}\n")
# Create learned projections
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)
# Project
Q = W_Q(X) # [1, 3, 4]
K = W_K(X) # [1, 3, 4]
V = W_V(X) # [1, 3, 4]
print("Queries (Q):")
for i, token in enumerate(tokens):
    print(f"{token}: {Q[0, i, :]}")

print("\nKeys (K):")
for i, token in enumerate(tokens):
    print(f"{token}: {K[0, i, :]}")

print("\nValues (V):")
for i, token in enumerate(tokens):
    print(f"{token}: {V[0, i, :]}")
# Observation: Q, K, V are all different!
# Model learned to transform input into three specialized representations
"What information am I looking for?"
Represents the information needs of each position
"What information do I contain?"
Represents what each token offers
"Here's my information"
Actual information to combine
Analogy: Imagine a library search.
Step 1: Compute how much each query "matches" with each key using the dot product. This creates an attention score matrix showing which tokens should attend to which.
# Similarity between all queries and keys
scores = Q @ K.transpose(-2, -1)
# Q: [2, 4, 64]
# K^T: [2, 64, 4]
# Result: [2, 4, 4]
# Each [4, 4] matrix: similarity between each pair of tokens
# scores[0, 1, 2] = similarity between token 1's query and token 2's key
print(f"Attention scores shape: {scores.shape}")
# Visualize what we just computed:
# For batch 0, token 1, show scores to all tokens
print(f"Token 1 attention scores: {scores[0, 1, :]}")
Mathematical Intuition:
The dot product between two vectors measures both their alignment AND magnitude:
# Example: Simple 2D dot products
query = torch.tensor([1.0, 0.0])  # Looking for "dimension 1"
key_1 = torch.tensor([1.0, 0.0])  # Contains "dimension 1"  ✓
key_2 = torch.tensor([0.0, 1.0])  # Contains "dimension 2"  ✗
dot_1 = (query * key_1).sum() # = 1.0*1.0 + 0.0*0.0 = 1.0 (high score)
dot_2 = (query * key_2).sum() # = 1.0*0.0 + 0.0*1.0 = 0.0 (low score)
# Query matches key_1 perfectly, doesn't match key_2 at all!
Geometric Interpretation:
Dot product = ||a|| × ||b|| × cos(θ)
Where θ is the angle between the vectors:
- θ = 0°   → cos(θ) = 1  → Maximum similarity
- θ = 90°  → cos(θ) = 0  → No similarity
- θ = 180° → cos(θ) = -1 → Opposite directions
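As a quick sanity check (a standalone snippet with illustrative vectors, not part of the attention code), the identity can be verified numerically:

import torch
import torch.nn.functional as F

# Two example 2D vectors (arbitrary values chosen for illustration)
a = torch.tensor([2.0, 1.0])
b = torch.tensor([1.0, 3.0])

dot = torch.dot(a, b)  # direct dot product: 2*1 + 1*3 = 5.0
cos_theta = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
reconstructed = a.norm() * b.norm() * cos_theta  # ||a|| * ||b|| * cos(theta)

print(dot.item())            # 5.0
print(reconstructed.item())  # ~5.0 -- same value, as the identity predicts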
import torch
import torch.nn as nn
# Simulate processing a real sentence
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = 6
d_model = 512
d_k = 64
# Simulate embeddings (normally from embedding layer)
X = torch.randn(1, seq_len, d_model)
# Create projection layers
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
# Project to Q and K
Q = W_Q(X) # [1, 6, 64]
K = W_K(X) # [1, 6, 64]
# Compute attention scores
scores = Q @ K.transpose(-2, -1) # [1, 6, 6]
print("Attention Score Matrix (before softmax):")
print("Rows = from token, Columns = to token\n")
print(" ", " ".join(f"{t:>6s}" for t in tokens))
for i, from_token in enumerate(tokens):
    row = scores[0, i, :]
    print(f"{from_token:>6s}:", " ".join(f"{s.item():>6.2f}" for s in row))
# Interesting observations:
# - Diagonal values (self-attention) are often high
# - "cat" might have high score with "sat" (subject-verb)
# - "on" might attend strongly to "mat" (preposition-object)
Problem: Scores can get very large!
# When d_k = 64, raw scores can reach +/-20
# When d_k = 512, they can reach +/-60 or more!

# Why? The dot product sums d_k terms:
#   score = q[0]*k[0] + q[1]*k[1] + ... + q[d_k-1]*k[d_k-1]
#         = sum of d_k products

# With random initialization (components roughly unit variance):
#   Expected value of a score ≈ 0
#   Standard deviation       ≈ sqrt(d_k)
#   For d_k=64:  std ≈ 8
#   For d_k=512: std ≈ 22.6

# This causes problems in the next step (softmax)!
Solution preview: We'll scale by 1/√d_k in the next section to keep scores in a reasonable range.
There's a crucial detail: we scale the scores by 1/√d_k. Why?
Problem: As d_k grows, dot products get larger. Large scores push softmax into flat regions where gradients vanish.
d_k=64: unscaled scores have a spread (std) of roughly 8 - softmax still behaves reasonably
d_k=768: unscaled scores have a spread of roughly 28 - softmax becomes nearly all-or-nothing
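You can verify this growth empirically. The following standalone sketch (random vectors only, no learned weights) prints the standard deviation of raw and scaled dot products for a few head sizes:

import torch

torch.manual_seed(0)
for d_k in [64, 256, 768]:
    # Draw many random query/key pairs with unit-variance components
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)      # unscaled dot products
    scaled = raw / (d_k ** 0.5)    # divided by sqrt(d_k)
    print(f"d_k={d_k:4d}  raw std={raw.std():6.2f}  "
          f"(sqrt(d_k)={d_k**0.5:5.2f})  scaled std={scaled.std():4.2f}")
# Expected: raw std grows roughly as sqrt(d_k); scaled std stays near 1.0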
# Scale scores
d_k = 64
scores = Q @ K.transpose(-2, -1) # [2, 4, 4]
scores = scores / (d_k ** 0.5)  # Divide by sqrt(64) = 8
# This brings the scores back to roughly unit variance
# (std ~8 before scaling, std ~1 after),
# the range where softmax has useful gradients
Without scaling, training becomes unstable. With it, the model learns smoothly.
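Here is a tiny illustration of that instability (the score values are made-up, but they have roughly the spread you would expect for d_k = 64 without scaling):

import torch
import torch.nn.functional as F

raw = torch.tensor([8.0, 2.0, -3.0, 0.5])  # unscaled scores with a wide spread
scaled = raw / (64 ** 0.5)                 # divide by sqrt(d_k) = 8

print(F.softmax(raw, dim=-1))     # ~[0.997, 0.002, ...] -- nearly one-hot, tiny gradients
print(F.softmax(scaled, dim=-1))  # a smoother distribution with useful gradients everywhere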
Step 2: Convert scores to probabilities (weights) using softmax:
import torch.nn.functional as F
# Softmax converts scores to probability distribution
attention_weights = F.softmax(scores, dim=-1)
# [2, 4, 4]
# Each row sums to 1.0
# Example row (attending to 4 tokens):
# attention_weights[0, 0, :] = [0.1, 0.7, 0.15, 0.05]
# β β β β
# token0 token1 token2 token3
# "Pay 10% attention to token 0"
# "Pay 70% attention to token 1"
# "Pay 15% attention to token 2"
# "Pay 5% attention to token 3"
This is beautiful: the model learns which tokens are relevant for each position!
Step 3: Use attention weights to combine the values:
# Apply attention weights to values
output = attention_weights @ V
# [2, 4, 4] @ [2, 4, 64] = [2, 4, 64]
# This is a weighted sum!
# output[0, 0, :] = 0.1*V[0,0,:] + 0.7*V[0,1,:] + 0.15*V[0,2,:] + 0.05*V[0,3,:]
# = "70% token 1 + 10% token 0 + 15% token 2 + 5% token 3"
# Each token's output is now a context-aware combination of all inputs!
Compare before and after:
Before attention:
Token 1 embedding: [0.2, -0.5, 0.8, ...]   ← just token 1's own info

After attention:
Token 1 output:    [0.15, -0.3, 0.6, ...]  ← blended with its neighbors
                                              Now includes context!
Here's the complete self-attention formula and implementation:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
import torch
import torch.nn.functional as F
class SelfAttention(torch.nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_Q = torch.nn.Linear(d_model, d_k)
        self.W_K = torch.nn.Linear(d_model, d_k)
        self.W_V = torch.nn.Linear(d_model, d_v)
        self.d_k = d_k

    def forward(self, X):
        """
        X: token embeddings [batch_size, seq_len, d_model]
        """
        # Project to Q, K, V
        Q = self.W_Q(X)  # [batch, seq_len, d_k]
        K = self.W_K(X)  # [batch, seq_len, d_k]
        V = self.W_V(X)  # [batch, seq_len, d_v]

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, seq_len, seq_len]
        scores = scores / (self.d_k ** 0.5)  # Scale

        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)  # [batch, seq_len, seq_len]

        # Apply to values
        output = torch.matmul(attention_weights, V)  # [batch, seq_len, d_v]

        return output, attention_weights
# Usage
d_model = 512
d_k = d_v = 64
attention = SelfAttention(d_model, d_k, d_v)
X = torch.randn(2, 4, 512) # 2 sequences, 4 tokens, 512D
output, weights = attention(X)
print(output.shape) # [2, 4, 64]
print(weights.shape) # [2, 4, 4]
print(weights.sum(dim=-1)) # All rows sum to ~1.0
The weight matrices (W_Q, W_K, W_V) are learned during training via backpropagation. Different layers and attention heads learn different patterns. Some might:
- Focus on grammatical dependencies (subject-verb, modifiers)
- Focus on conceptual relationships (co-reference, meaning)
- Focus on nearby tokens (local context)
- Focus on distant but relevant tokens
One of the most powerful aspects of self-attention is that we can visualize what the model is "paying attention to". Let's build visualization tools:
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens, layer_name="Self-Attention"):
    """
    Visualize attention weights as a heatmap.

    Args:
        attention_weights: [seq_len, seq_len] tensor
        tokens: list of token strings
        layer_name: name for plot title
    """
    # Convert to numpy
    attn = attention_weights.detach().cpu().numpy()

    # Create heatmap
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attn,
                annot=True,        # Show values
                fmt='.2f',         # 2 decimal places
                cmap='YlOrRd',     # Color scheme
                xticklabels=tokens,
                yticklabels=tokens,
                cbar_kws={'label': 'Attention Weight'},
                ax=ax)
    ax.set_xlabel('Key (attending to)', fontsize=12)
    ax.set_ylabel('Query (attending from)', fontsize=12)
    ax.set_title(f'{layer_name}: Attention Weights', fontsize=14)
    plt.tight_layout()
    plt.show()
# Example usage
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(tokens)
# Create sample attention weights (normally from model)
# Here we'll create plausible patterns
attn_weights = torch.tensor([
    [0.40, 0.15, 0.10, 0.10, 0.15, 0.10],  # "The" → attends to itself, "cat"
    [0.10, 0.50, 0.25, 0.05, 0.05, 0.05],  # "cat" → itself, "sat"
    [0.08, 0.35, 0.40, 0.12, 0.03, 0.02],  # "sat" → "cat" (subject), itself
    [0.05, 0.05, 0.10, 0.35, 0.15, 0.30],  # "on"  → itself, "mat"
    [0.40, 0.08, 0.07, 0.10, 0.25, 0.10],  # "the" → first "The", itself
    [0.05, 0.10, 0.05, 0.25, 0.15, 0.40],  # "mat" → "on", itself
])
visualize_attention(attn_weights, tokens)
[0.9, 0.05, 0.05]
[0.05, 0.9, 0.05]
[0.05, 0.05, 0.9]
Meaning: Strong self-attention. Each token primarily attends to itself.
Common in: Early layers, positional encoding layers, residual paths
[0.33, 0.33, 0.33]
[0.33, 0.33, 0.33]
[0.33, 0.33, 0.33]
Meaning: Equal attention to all tokens. No specific focus.
Common in: Untrained models, layers that aggregate global context
[0.05, 0.90, 0.05]
[0.10, 0.10, 0.80]
[0.85, 0.10, 0.05]
Meaning: Sharp attention to specific tokens.
Common in: Deep layers, syntactic dependencies (verb→subject), coreference resolution
[0.7, 0.2, 0.1, 0.0]
[0.2, 0.5, 0.2, 0.1]
[0.1, 0.2, 0.5, 0.2]
[0.0, 0.1, 0.2, 0.7]
Meaning: Attention to nearby tokens (band-diagonal).
Common in: Character-level models, capturing local context, some efficient transformer variants
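These patterns can also be told apart numerically. The sketch below is a simple diagnostic of our own making (the `describe_pattern` helper and its thresholds are illustrative choices, not a standard API): it computes the average row entropy and the attention-weighted query-to-key distance of a weight matrix.

import torch

def describe_pattern(attn: torch.Tensor) -> str:
    """Roughly characterize a [seq_len, seq_len] attention matrix."""
    seq_len = attn.size(0)
    # Row-wise entropy: high = near-uniform, low = focused
    entropy = -(attn * (attn + 1e-9).log()).sum(dim=-1).mean()
    max_entropy = torch.log(torch.tensor(float(seq_len)))
    # Attention-weighted |query index - key index|: small = diagonal/local
    idx = torch.arange(seq_len, dtype=torch.float)
    dist = (attn * (idx.view(-1, 1) - idx.view(1, -1)).abs()).sum(dim=-1).mean()
    if entropy > 0.9 * max_entropy:
        return "near-uniform"
    if dist < 1.0:
        return "diagonal / local"
    return "focused on specific (possibly distant) tokens"

# Example: the diagonal pattern shown above
diag = torch.tensor([[0.90, 0.05, 0.05],
                     [0.05, 0.90, 0.05],
                     [0.05, 0.05, 0.90]])
print(describe_pattern(diag))  # diagonal / local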
from transformers import BertTokenizer, BertModel
import torch
# Load pre-trained BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()
# Tokenize sentence
text = "The cat sat on the mat"
inputs = tokenizer(text, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
print(f"Tokens: {tokens}")
# Get attention weights
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # Tuple of 12 layers
# Each attention is [batch, num_heads, seq_len, seq_len]
# Let's examine layer 6, head 3
layer = 6
head = 3
attn = attentions[layer][0, head, :, :].cpu().numpy()
print(f"\nLayer {layer}, Head {head} Attention Pattern:")
print(f"Shape: {attn.shape}")
# Find strongest attention connections
for i, from_token in enumerate(tokens):
    # Get top 3 attended tokens for this position
    top_indices = attn[i].argsort()[-3:][::-1]
    top_weights = attn[i][top_indices]
    print(f"\n{from_token} attends to:")
    for idx, weight in zip(top_indices, top_weights):
        print(f"  {tokens[idx]}: {weight:.3f}")
# Visualize specific head
visualize_attention(torch.tensor(attn), tokens, f"BERT Layer {layer} Head {head}")
Research Finding: Attention heads specialize!
Reference: "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned" (Voita et al., 2019)
When generating text, we can't look at future tokens. Causal attention (also called masked attention) prevents the model from "cheating" by seeing tokens that haven't been generated yet.
Problem with bidirectional attention during generation:
Generating: "The cat sat ___"
If we allow full bidirectional attention:
- Position 3 could attend to position 4 (future!)
- Model would learn to just copy the answer
- No actual language modeling
Solution: Mask future positions!
import torch
import torch.nn.functional as F
# Create causal mask (upper triangular matrix)
seq_len = 5
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print("Causal Mask (True = masked/forbidden):")
print(mask)
# Output:
# [[False, True, True, True, True],
# [False, False, True, True, True],
# [False, False, False, True, True],
# [False, False, False, False, True],
# [False, False, False, False, False]]
# Visualize what each token can attend to
tokens = ["The", "cat", "sat", "on", "mat"]
print("\nAttention Constraints:")
for i in range(seq_len):
    allowed = [tokens[j] for j in range(seq_len) if not mask[i, j]]
    print(f"{tokens[i]:>4s} can attend to: {', '.join(allowed)}")
# Output:
# The can attend to: The
# cat can attend to: The, cat
# sat can attend to: The, cat, sat
# on can attend to: The, cat, sat, on
# mat can attend to: The, cat, sat, on, mat
def causal_attention(Q, K, V, mask=None):
    """
    Compute causal self-attention with optional masking.

    Args:
        Q, K, V: [batch, seq_len, d_k] tensors
        mask: [seq_len, seq_len] boolean tensor (True = mask)
    """
    d_k = Q.size(-1)

    # Compute scaled scores
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    # Shape: [batch, seq_len, seq_len]

    if mask is not None:
        # Set masked positions to -inf BEFORE softmax
        # This makes softmax output 0 for those positions
        scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))

    # Softmax turns the -inf entries into exactly 0.0
    attention_weights = F.softmax(scores, dim=-1)

    # Apply to values
    output = attention_weights @ V
    return output, attention_weights
# Example usage
batch_size = 2
seq_len = 4
d_k = 64
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)
# Create causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# Compute with mask
output, attn = causal_attention(Q, K, V, mask=causal_mask)
print("Attention weights shape:", attn.shape)
print("\nAttention weights (batch 0):")
print(attn[0])
# Verify: each row should only have non-zero values up to its position
print("\nVerification - Position 0 (can only see position 0):")
print(attn[0, 0, :]) # Should be [1.0, 0.0, 0.0, 0.0]
print("\nPosition 2 (can see positions 0, 1, 2):")
print(attn[0, 2, :]) # Should be [~0.33, ~0.33, ~0.33, 0.0]
Mathematical Reasoning:
# Softmax formula: softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
# If we set x_i = -inf:
#   exp(-inf) = 0
# so the softmax output for that position is 0 / (sum) = 0
# This is exactly what we want: zero attention weight!
Alternative (Wrong ✗):
# Bad: Set masked positions to 0 AFTER softmax
attention_weights = F.softmax(scores, dim=-1)
attention_weights = attention_weights.masked_fill(mask, 0.0)
# Problem: Rows no longer sum to 1.0!
# This breaks the probability distribution property
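A quick check with toy scores makes the difference concrete (the numbers are arbitrary; only the row sums matter):

import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 1.5, 2.5],
                       [2.0, 1.0, 0.0]])
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()  # True above the diagonal

# Correct: mask with -inf before softmax -> every row still sums to 1.0
correct = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(correct.sum(dim=-1))  # tensor([1., 1., 1.])

# Wrong: zero out after softmax -> masked rows no longer sum to 1.0
wrong = F.softmax(scores, dim=-1).masked_fill(mask, 0.0)
print(wrong.sum(dim=-1))    # first two rows are well below 1.0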
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

def compare_attention_types():
    """Visualize the difference between causal and bidirectional attention."""
    seq_len = 6
    tokens = ["The", "cat", "sat", "on", "the", "mat"]

    # Create sample scores
    scores = torch.randn(1, seq_len, seq_len) * 2

    # Bidirectional (no mask)
    attn_bi = F.softmax(scores / 8, dim=-1)[0]

    # Causal (with mask)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores_masked = scores.clone()
    scores_masked = scores_masked.masked_fill(mask.unsqueeze(0), float('-inf'))
    attn_causal = F.softmax(scores_masked / 8, dim=-1)[0]

    # Plot side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Bidirectional
    sns.heatmap(attn_bi.numpy(), annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=tokens, yticklabels=tokens, ax=ax1)
    ax1.set_title('Bidirectional Attention\n(BERT-style)', fontsize=14)
    ax1.set_xlabel('Attending to')
    ax1.set_ylabel('Attending from')

    # Causal
    sns.heatmap(attn_causal.numpy(), annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=tokens, yticklabels=tokens, ax=ax2)
    ax2.set_title('Causal Attention\n(GPT-style)', fontsize=14)
    ax2.set_xlabel('Attending to')
    ax2.set_ylabel('Attending from')

    plt.tight_layout()
    plt.show()

compare_attention_types()
Key Insight: Causal attention is essential for autoregressive language models (GPT, LLaMA), while bidirectional attention is used in encoder models (BERT, RoBERTa) where we have the full context available.
Understanding the cost:
Time Complexity: O(n² × d)
Space Complexity: O(n²)
But: Both are highly parallelizable on GPUs (matrix operations)
This is why transformers need powerful hardware but scale with data/parameters efficiently.
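For a rough sense of the quadratic term, the following back-of-the-envelope sketch (counting only the float32 score matrix of a single head and ignoring every other activation) shows how memory grows with sequence length:

# Memory for one attention score matrix per head, in float32 (4 bytes per entry)
for n in [512, 2048, 8192, 32768]:
    entries = n * n                  # O(n^2) scores
    megabytes = entries * 4 / 1e6
    print(f"seq_len={n:6d}  scores={entries:>13,d}  ~{megabytes:9.1f} MB per head")
# 512 -> ~1 MB, 2048 -> ~17 MB, 8192 -> ~268 MB, 32768 -> ~4295 MB (about 4.3 GB)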
Q1: What are the three components of self-attention?
Q2: How are attention scores computed?
Q3: Why do we apply softmax to attention scores?
Q4: What does scaled dot-product attention divide by?
Q5: What is the output of the attention mechanism?