
Self-Attention Mechanism

πŸ“š Tutorial 3 🟑 Intermediate

Master the mathematical foundations of attention - the core mechanism powering all modern AI


From Concept to Mathematics

You understand attention conceptually: "for each position, decide which other positions are important." Now let's dive deep into the mathematics, implementation details, and intuition behind how this mechanism actually works in production transformers.

Goal: Take token embeddings and transform them using attention so each token becomes "aware" of relevant context. By the end, you'll understand every matrix multiplication and design choice.

The Self-Attention Journey

Self-attention transforms a sequence of tokens from independent embeddings to context-aware representations in 6 steps:

The 6-Step Process

  1. Project to Q, K, V: Transform input embeddings into three different representations
  2. Compute Scores: Calculate similarity between all query-key pairs (QK^T)
  3. Scale: Divide by √d_k so softmax doesn't saturate (saturation leads to vanishing gradients)
  4. Mask (optional): Block attention to certain positions (e.g., future tokens)
  5. Softmax: Convert scores to probability distribution (attention weights)
  6. Apply to Values: Use weights to aggregate value vectors

🎯 Key Insight: Each step has a specific purpose. Understanding why each exists is crucial for debugging and optimizing transformer models.

Visual Overview: Information Flow

INPUT: Token Embeddings (seq_len Γ— d_model)
    "The cat sat on mat"
    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Step 1: Linear Projections           β”‚
β”‚ Q = X @ W_Q  (what I'm looking for)  β”‚
β”‚ K = X @ W_K  (what I contain)        β”‚
β”‚ V = X @ W_V  (my information)        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Step 2: Compute Similarity Matrix     β”‚
β”‚ Scores = Q @ K^T                      β”‚
β”‚ Result: (seq_len Γ— seq_len)          β”‚
β”‚                                       β”‚
β”‚       The  cat  sat  on  mat          β”‚
β”‚ The  [0.8  0.2  0.1  0.1  0.1]       β”‚
β”‚ cat  [0.3  0.9  0.4  0.1  0.2]       β”‚
β”‚ sat  [0.2  0.5  0.8  0.3  0.4]       β”‚
β”‚ ...                                   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Step 3: Scale by √d_k                 β”‚
β”‚ Scores = Scores / √64                 β”‚
β”‚ (prevents softmax saturation)         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Step 4: Softmax β†’ Probabilities       β”‚
β”‚ Weights = softmax(Scores)             β”‚
β”‚ Each row sums to 1.0                  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Step 5: Apply to Values               β”‚
β”‚ Output = Weights @ V                  β”‚
β”‚ Weighted combination of all tokens    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    ↓
OUTPUT: Context-Aware Representations
    Each token now "knows" about others!
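In code, the same flow is only a handful of tensor operations. Here is a compact preview (shapes in the comments; the dimensions are illustrative, random matrices stand in for the learned weights, and the causal mask in step 4 is optional). The rest of this tutorial unpacks each line.

import torch
import torch.nn.functional as F

batch, seq_len, d_model, d_k = 1, 5, 512, 64
X = torch.randn(batch, seq_len, d_model)                   # token embeddings [1, 5, 512]

# Random projection matrices stand in for the learned W_Q, W_K, W_V
W_Q, W_K, W_V = (torch.randn(d_model, d_k) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V                        # Step 1: project        [1, 5, 64]
scores = Q @ K.transpose(-2, -1)                           # Step 2: similarities   [1, 5, 5]
scores = scores / d_k ** 0.5                               # Step 3: scale by sqrt(d_k)
mask = torch.triu(torch.ones(seq_len, seq_len), 1).bool()  # Step 4: (optional) causal mask
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)                        # Step 5: rows sum to 1  [1, 5, 5]
output = weights @ V                                       # Step 6: weighted values [1, 5, 64]

print(output.shape)                                        # torch.Size([1, 5, 64])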

The Three Projections: Q, K, V - The Foundation

Self-attention starts by creating three different "views" of the input. This might seem redundant (why not just use X?), but each projection serves a distinct purpose.

Why Three Separate Projections?

πŸ’‘ Intuition: The Restaurant Analogy

Imagine you're at a restaurant:

  • Query (Q): Your food preferences - "I want spicy vegetarian food"
  • Key (K): Menu item descriptions - "Spicy Thai Curry (vegetarian, hot)"
  • Value (V): The actual dishes - the food itself

You compare your query against menu keys to decide which items match your preferences. Then you get the actual values (food) corresponding to your selections. You don't eat the menu description - you eat the actual dish!

Mathematical Formulation

import torch
import torch.nn as nn

# Input: embeddings of all tokens
# Shape: [batch_size, seq_len, d_model]
# Example: [2, 4, 512] (2 sequences, 4 tokens, 512D embeddings)
X = torch.randn(2, 4, 512)

print(f"Input shape: {X.shape}")
print(f"Each token is represented as a {X.shape[-1]}-dimensional vector")

# Learned linear transformations (Weight matrices)
# These are learned during training via backpropagation
d_model = 512
d_k = 64  # Typical: d_model / num_heads
d_v = 64

W_Q = nn.Linear(d_model, d_k, bias=False)  # Query projection
W_K = nn.Linear(d_model, d_k, bias=False)  # Key projection  
W_V = nn.Linear(d_model, d_v, bias=False)  # Value projection

# Project input to Q, K, V
Q = W_Q(X)  # [2, 4, 64] - Queries (what each token is looking for)
K = W_K(X)  # [2, 4, 64] - Keys (what each token represents)
V = W_V(X)  # [2, 4, 64] - Values (information each token carries)

print(f"\nQuery shape:  {Q.shape}")  # [2, 4, 64]
print(f"Key shape:    {K.shape}")    # [2, 4, 64]
print(f"Value shape:  {V.shape}")    # [2, 4, 64]

# Key observation: Q and K share the same dimension d_k,
# so Q @ K^T yields a square (seq_len Γ— seq_len) score matrix (V may use a different d_v)

Detailed Breakdown: What Are We Computing?

πŸ” Query Matrix (Q)

Dimensions: (batch, seq_len, d_k)

Purpose: Represents "what am I looking for?"

For each token, Q encodes what information that token needs from other tokens. Think of it as a search query.

Example: Token "she"
Q_she = W_Q @ embedding_she
      β‰ˆ "Looking for: [female referent, subject, recent mention]"

πŸ”‘ Key Matrix (K)

Dimensions: (batch, seq_len, d_k)

Purpose: Represents "what do I contain?"

For each token, K encodes what information that token offers. These are matched against queries.

Example: Token "Alice"
K_Alice = W_K @ embedding_Alice
        β‰ˆ "I contain: [female, proper noun, agent, human]"

πŸ’Ž Value Matrix (V)

Dimensions: (batch, seq_len, d_v)

Purpose: Actual information to combine

V contains the actual semantic content that will be mixed together based on attention weights.

Example: Token "Alice"
V_Alice = W_V @ embedding_Alice
        β‰ˆ [rich semantic representation of the "Alice" concept]

Why Learned Projections vs. Using X Directly?

Option 1: No Projections (Bad ❌)

Q = K = V = X  # Just use input directly

Problem:
- No flexibility
- Can't learn what to look for vs what to offer
- One representation must serve three purposes
- Severely limits model expressiveness

Option 2: Learned Projections (Good βœ…)

Q = X @ W_Q  # Learn what to search for
K = X @ W_K  # Learn what to advertise
V = X @ W_V  # Learn what to share

Benefits:
- W_Q, W_K, W_V learned via backpropagation
- Model discovers optimal queries and keys
- Flexibility to represent different aspects
- Dramatically improves model capacity
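One concrete consequence of skipping the projections: with Q = K = X, the raw score matrix X @ X^T is always symmetric, so "cat" would score "sat" exactly as highly as "sat" scores "cat" before softmax. Separate W_Q and W_K let those relationships be asymmetric. A quick sketch with untrained random weights, purely for illustration:

import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(1, 3, 8)                                   # 3 tokens, 8-dim embeddings

# No projections: raw scores are forced to be symmetric
scores_raw = X @ X.transpose(-2, -1)
print(torch.allclose(scores_raw, scores_raw.transpose(-2, -1)))    # True

# Learned projections: asymmetric attention relationships become possible
W_Q = nn.Linear(8, 4, bias=False)
W_K = nn.Linear(8, 4, bias=False)
scores_proj = W_Q(X) @ W_K(X).transpose(-2, -1)
print(torch.allclose(scores_proj, scores_proj.transpose(-2, -1)))  # False (in general)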

Concrete Example: Processing "The cat sat"

import torch
import torch.nn as nn

# Simplified example with small dimensions
d_model = 8  # Tiny for illustration
d_k = 4

# Simulate token embeddings (normally from embedding layer)
tokens = ["The", "cat", "sat"]
X = torch.tensor([
    [0.1, 0.2, -0.1, 0.3, 0.4, -0.2, 0.1, 0.0],  # "The"
    [0.3, -0.1, 0.5, 0.2, 0.1, 0.4, -0.3, 0.2],  # "cat"
    [-0.2, 0.4, 0.1, -0.3, 0.5, 0.1, 0.2, -0.1]   # "sat"
]).unsqueeze(0)  # Add batch dimension: [1, 3, 8]

print("Input embeddings (X):")
print(X)
print(f"Shape: {X.shape}\n")

# Create learned projections
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

# Project
Q = W_Q(X)  # [1, 3, 4]
K = W_K(X)  # [1, 3, 4]
V = W_V(X)  # [1, 3, 4]

print("Queries (Q):")
for i, token in enumerate(tokens):
    print(f"{token}: {Q[0, i, :]}")

print("\nKeys (K):")
for i, token in enumerate(tokens):
    print(f"{token}: {K[0, i, :]}")

print("\nValues (V):")
for i, token in enumerate(tokens):
    print(f"{token}: {V[0, i, :]}")

# Observation: Q, K, V are all different!
# Model learned to transform input into three specialized representations

What Do Q, K, V Represent?

Query (Q)

"What information am I looking for?"

Represents the information needs of each position

Key (K)

"What information do I contain?"

Represents what each token offers

Value (V)

"Here's my information"

Actual information to combine

Analogy: Imagine a library search.

  • Query: "I'm looking for books about machine learning"
  • Keys: Each book's tags/keywords
  • Values: The actual content of books

Computing Attention Scores - The Similarity Matrix

With Q, K, and V in hand, compute how much each query "matches" each key using the dot product. This produces an attention score matrix showing which tokens should attend to which.

The Core Computation: Q @ K^T

# Similarity between all queries and keys
scores = Q @ K.transpose(-2, -1)
# Q: [2, 4, 64]
# K^T: [2, 64, 4]
# Result: [2, 4, 4]

# Each [4, 4] matrix: similarity between each pair of tokens
# scores[0, 1, 2] = similarity between token 1's query and token 2's key
print(f"Attention scores shape: {scores.shape}")

# Visualize what we just computed:
# For batch 0, token 1, show scores to all tokens
print(f"Token 1 attention scores: {scores[0, 1, :]}")

Why Dot Product Measures Similarity

Mathematical Intuition:

The dot product between two vectors measures both their alignment AND magnitude:

# Example: Simple 2D dot products
query = torch.tensor([1.0, 0.0])  # Looking for "dimension 1"
key_1 = torch.tensor([1.0, 0.0])   # Contains "dimension 1" βœ“
key_2 = torch.tensor([0.0, 1.0])   # Contains "dimension 2" βœ—

dot_1 = (query * key_1).sum()  # = 1.0*1.0 + 0.0*0.0 = 1.0 (high score)
dot_2 = (query * key_2).sum()  # = 1.0*0.0 + 0.0*1.0 = 0.0 (low score)

# Query matches key_1 perfectly, doesn't match key_2 at all!

Geometric Interpretation:

Dot product = ||a|| Γ— ||b|| Γ— cos(ΞΈ)

Where ΞΈ is angle between vectors:
- ΞΈ = 0Β°   β†’ cos(ΞΈ) = 1   β†’ Maximum similarity
- ΞΈ = 90Β°  β†’ cos(ΞΈ) = 0   β†’ No similarity
- ΞΈ = 180Β° β†’ cos(ΞΈ) = -1  β†’ Opposite
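You can check this identity numerically in a couple of lines (a standalone sanity check, not part of the attention code):

import torch
import torch.nn.functional as F

a = torch.tensor([2.0, 1.0])
b = torch.tensor([1.0, 3.0])

dot = torch.dot(a, b)                              # 2*1 + 1*3 = 5
cos_theta = F.cosine_similarity(a, b, dim=0)       # cos of the angle between a and b
reconstructed = a.norm() * b.norm() * cos_theta    # ||a|| * ||b|| * cos(theta)

print(dot.item(), reconstructed.item())            # both β‰ˆ 5.0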

Real Example: "The cat sat on the mat"

import torch
import torch.nn as nn

# Simulate processing a real sentence
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = 6
d_model = 512
d_k = 64

# Simulate embeddings (normally from embedding layer)
X = torch.randn(1, seq_len, d_model)

# Create projection layers
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)

# Project to Q and K
Q = W_Q(X)  # [1, 6, 64]
K = W_K(X)  # [1, 6, 64]

# Compute attention scores
scores = Q @ K.transpose(-2, -1)  # [1, 6, 6]

print("Attention Score Matrix (before softmax):")
print("Rows = from token, Columns = to token\n")
print("        ", " ".join(f"{t:>6s}" for t in tokens))
for i, from_token in enumerate(tokens):
    row = scores[0, i, :]
    print(f"{from_token:>6s}:", " ".join(f"{s.item():>6.2f}" for s in row))

# Interesting observations:
# - Diagonal values (self-attention) are often high
# - "cat" might have high score with "sat" (subject-verb)
# - "on" might attend strongly to "mat" (preposition-object)

Understanding Score Magnitudes

⚠️ Problem: Scores can get very large!

# Why do scores grow with d_k? The dot product sums d_k terms:
# score = q[0]*k[0] + q[1]*k[1] + ... + q[d_k-1]*k[d_k-1]

# If the components of q and k are independent with zero mean and unit variance,
# the score has mean β‰ˆ 0 and standard deviation β‰ˆ sqrt(d_k)

# For d_k = 64:  std β‰ˆ 8    β†’ scores routinely reach Β±20
# For d_k = 512: std β‰ˆ 22.6 β†’ scores routinely reach Β±70

# This causes problems in the next step (softmax)!

Solution preview: We'll scale by 1/√d_k in the next section to keep scores in a reasonable range.

The Scaling Factor

There's a crucial detail: we scale the scores by 1/√d_k. Why?

Problem: As d_k grows, dot products get larger. Large scores push softmax into flat regions where gradients vanish.

d_k=64:  unscaled scores have std β‰ˆ 8  (softmax still works reasonably well)
d_k=768: unscaled scores have std β‰ˆ 28 (softmax becomes nearly all-or-nothing)


# Scale scores
d_k = 64
scores = Q @ K.transpose(-2, -1)  # [2, 4, 4]
scores = scores / (d_k ** 0.5)     # Divide by √64 = 8

# This brings the scores back to roughly unit variance:
# instead of values routinely reaching Β±20, we get values mostly within about [-3, 3]
# Softmax behaves well in this range: no saturation, healthy gradients

Without scaling, training becomes unstable. With it, the model learns smoothly.
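A quick experiment makes this concrete: sample one random query against a handful of random keys and compare the softmax outputs with and without scaling (exact numbers vary per run, but the pattern is consistent):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
keys = torch.randn(10, d_k)                      # 10 candidate keys

raw_scores = keys @ q                            # std grows like sqrt(d_k) β‰ˆ 22.6
scaled_scores = raw_scores / d_k ** 0.5          # back to roughly unit variance

print("raw score std:", raw_scores.std().item())
print("softmax without scaling:", F.softmax(raw_scores, dim=0))     # typically near one-hot
print("softmax with scaling:   ", F.softmax(scaled_scores, dim=0))  # much smoother distribution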

Computing Attention Weights with Softmax

Next, convert the scaled scores to a probability distribution (the attention weights) using softmax:


import torch.nn.functional as F

# Softmax converts scores to probability distribution
attention_weights = F.softmax(scores, dim=-1)
# [2, 4, 4]

# Each row sums to 1.0
# Example row (attending to 4 tokens):
# attention_weights[0, 0, :] = [0.1, 0.7, 0.15, 0.05]
#                               ↑    ↑    ↑     ↑
#                            token0 token1 token2 token3
#                          "Pay 10% attention to token 0"
#                          "Pay 70% attention to token 1"
#                          "Pay 15% attention to token 2"
#                          "Pay 5% attention to token 3"

This is beautiful: the model learns which tokens are relevant for each position!

Interpretation: attention_weights[b, i, j] tells us: "Within sequence b of the batch, how much attention should position i pay to position j?"

Applying Attention to Values

Finally, use the attention weights to combine the values:


# Apply attention weights to values
output = attention_weights @ V
# [2, 4, 4] @ [2, 4, 64] = [2, 4, 64]

# This is a weighted sum!
# output[0, 0, :] = 0.1*V[0,0,:] + 0.7*V[0,1,:] + 0.15*V[0,2,:] + 0.05*V[0,3,:]
#                 = "70% token 1 + 10% token 0 + 15% token 2 + 5% token 3"

# Each token's output is now a context-aware combination of all inputs!

Compare before and after:

Before attention:
Token 1 embedding: [0.2, -0.5, 0.8, ...]  ← Just token 1's info

After attention:
Token 1 output: [0.15, -0.3, 0.6, ...]    ← Blended with neighbors
                ↑ Now includes context!
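To convince yourself the matrix product really is this weighted sum, you can recompute one output row by hand (a standalone check with random tensors):

import torch

torch.manual_seed(0)
weights = torch.softmax(torch.randn(4, 4), dim=-1)   # [4, 4] attention weights, rows sum to 1
V = torch.randn(4, 64)                               # [4, 64] value vectors

output = weights @ V                                 # matrix form of the weighted sum

# Row 0 recomputed explicitly: weighted combination of all four value vectors
manual_row0 = sum(weights[0, j] * V[j] for j in range(4))
print(torch.allclose(output[0], manual_row0, atol=1e-6))   # True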

Putting It All Together

Here's the complete self-attention formula and implementation:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

import torch
import torch.nn.functional as F

class SelfAttention(torch.nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_Q = torch.nn.Linear(d_model, d_k)
        self.W_K = torch.nn.Linear(d_model, d_k)
        self.W_V = torch.nn.Linear(d_model, d_v)
        self.d_k = d_k
    
    def forward(self, X):
        """
        X: token embeddings [batch_size, seq_len, d_model]
        """
        # Project to Q, K, V
        Q = self.W_Q(X)  # [batch, seq_len, d_k]
        K = self.W_K(X)  # [batch, seq_len, d_k]
        V = self.W_V(X)  # [batch, seq_len, d_v]
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, seq_len, seq_len]
        scores = scores / (self.d_k ** 0.5)  # Scale
        
        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)  # [batch, seq_len, seq_len]
        
        # Apply to values
        output = torch.matmul(attention_weights, V)  # [batch, seq_len, d_v]
        
        return output, attention_weights

# Usage
d_model = 512
d_k = d_v = 64
attention = SelfAttention(d_model, d_k, d_v)

X = torch.randn(2, 4, 512)  # 2 sequences, 4 tokens, 512D
output, weights = attention(X)

print(output.shape)      # [2, 4, 64]
print(weights.shape)     # [2, 4, 4]
print(weights.sum(dim=-1))  # All rows sum to ~1.0
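Worth knowing: recent PyTorch releases (2.0 and later, as far as I'm aware) ship a fused implementation of exactly this computation, torch.nn.functional.scaled_dot_product_attention, which is faster and more memory-efficient than the manual version. A sketch of how it lines up with the math above:

import torch
import torch.nn.functional as F

Q = torch.randn(2, 4, 64)
K = torch.randn(2, 4, 64)
V = torch.randn(2, 4, 64)

# Fused equivalent of softmax(Q @ K^T / sqrt(d_k)) @ V
out_fused = F.scaled_dot_product_attention(Q, K, V)

# Manual reference computation
weights = F.softmax(Q @ K.transpose(-2, -1) / 64 ** 0.5, dim=-1)
out_manual = weights @ V

print(torch.allclose(out_fused, out_manual, atol=1e-5))   # True (up to float tolerance)

# Causal masking (covered later in this tutorial) is a single flag:
out_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)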

What Is the Model Learning?

The weight matrices (W_Q, W_K, W_V) are learned during training. The model learns:

  • W_Q: How to project tokens to "what I'm looking for"
  • W_K: How to project tokens to "what I contain" (comparable to queries)
  • W_V: How to project tokens to "information to share"

Different layers and attention heads learn different patterns. Some might learn:

Syntactic Attention

Focus on grammatical dependencies (subject-verb, modifiers)

Semantic Attention

Focus on conceptual relationships (co-reference, meaning)

Position Attention

Focus on nearby tokens (local context)

Long-Range Attention

Focus on distant relevant tokens

Visualizing Attention Patterns

One of the most powerful aspects of self-attention is that we can visualize what the model is "paying attention to". Let's build visualization tools:

Heatmap Visualization

import torch
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens, layer_name="Self-Attention"):
    """
    Visualize attention weights as a heatmap.
    
    Args:
        attention_weights: [seq_len, seq_len] tensor
        tokens: list of token strings
        layer_name: name for plot title
    """
    # Convert to numpy
    attn = attention_weights.detach().cpu().numpy()
    
    # Create heatmap
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attn, 
                annot=True,  # Show values
                fmt='.2f',   # 2 decimal places
                cmap='YlOrRd',  # Color scheme
                xticklabels=tokens,
                yticklabels=tokens,
                cbar_kws={'label': 'Attention Weight'},
                ax=ax)
    
    ax.set_xlabel('Key (attending to)', fontsize=12)
    ax.set_ylabel('Query (attending from)', fontsize=12)
    ax.set_title(f'{layer_name}: Attention Weights', fontsize=14)
    plt.tight_layout()
    plt.show()

# Example usage
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(tokens)

# Create sample attention weights (normally from model)
# Here we'll create plausible patterns
attn_weights = torch.tensor([
    [0.40, 0.15, 0.10, 0.10, 0.15, 0.10],  # "The" β†’ attends to itself, "cat"
    [0.10, 0.50, 0.25, 0.05, 0.05, 0.05],  # "cat" β†’ itself, "sat"
    [0.08, 0.35, 0.40, 0.12, 0.03, 0.02],  # "sat" β†’ "cat" (subject), itself
    [0.05, 0.05, 0.10, 0.35, 0.15, 0.30],  # "on" β†’ itself, "mat"
    [0.40, 0.08, 0.07, 0.10, 0.25, 0.10],  # "the" β†’ first "The", itself
    [0.05, 0.10, 0.05, 0.25, 0.15, 0.40],  # "mat" β†’ "on", itself
])

visualize_attention(attn_weights, tokens)

What Do Different Attention Patterns Mean?

🎯 Diagonal Pattern

[0.9, 0.05, 0.05]
[0.05, 0.9, 0.05]
[0.05, 0.05, 0.9]

Meaning: Strong self-attention. Each token primarily attends to itself.

Common in: Early layers, positional encoding layers, residual paths

πŸ“Š Uniform Pattern

[0.33, 0.33, 0.33]
[0.33, 0.33, 0.33]
[0.33, 0.33, 0.33]

Meaning: Equal attention to all tokens. No specific focus.

Common in: Untrained models, layers that aggregate global context

πŸ” Focused Pattern

[0.05, 0.90, 0.05]
[0.10, 0.10, 0.80]
[0.85, 0.10, 0.05]

Meaning: Sharp attention to specific tokens.

Common in: Deep layers, syntactic dependencies (verb→subject), coreference resolution

πŸ“ Local Pattern

[0.7, 0.2, 0.1, 0.0]
[0.2, 0.5, 0.2, 0.1]
[0.1, 0.2, 0.5, 0.2]
[0.0, 0.1, 0.2, 0.7]

Meaning: Attention to nearby tokens (band-diagonal).

Common in: Character-level models, capturing local context, some efficient transformer variants
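These patterns can also be quantified rather than eyeballed. The helper below is a rough sketch of my own (the metric names are not a standard API): it reports average self-attention, row entropy (high means close to uniform), and how much weight stays within one position of the diagonal.

import torch

def attention_diagnostics(attn):
    """attn: [seq_len, seq_len] attention weights whose rows sum to 1."""
    seq_len = attn.size(0)
    self_mass = attn.diag().mean().item()                                # diagonal strength
    entropy = -(attn * (attn + 1e-9).log()).sum(dim=-1).mean().item()    # high = near-uniform
    idx = torch.arange(seq_len)
    near = (idx[None, :] - idx[:, None]).abs() <= 1                      # band around the diagonal
    local_mass = (attn * near).sum(dim=-1).mean().item()                 # weight within Β±1 position
    return {"self": self_mass, "entropy": entropy, "local": local_mass}

# Example on random (row-normalized) weights; try it on attn_weights from the heatmap above too
example = torch.softmax(torch.randn(6, 6), dim=-1)
print(attention_diagnostics(example))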

Real Example: Analyzing BERT's Attention

from transformers import BertTokenizer, BertModel
import torch

# Load pre-trained BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()

# Tokenize sentence
text = "The cat sat on the mat"
inputs = tokenizer(text, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

print(f"Tokens: {tokens}")

# Get attention weights
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # Tuple of 12 layers
    
# Each attention is [batch, num_heads, seq_len, seq_len]
# Let's examine layer 6, head 3
layer = 6
head = 3
attn = attentions[layer][0, head, :, :].cpu().numpy()

print(f"\nLayer {layer}, Head {head} Attention Pattern:")
print(f"Shape: {attn.shape}")

# Find strongest attention connections
for i, from_token in enumerate(tokens):
    # Get top 3 attended tokens for this position
    top_indices = attn[i].argsort()[-3:][::-1]
    top_weights = attn[i][top_indices]
    
    print(f"\n{from_token} attends to:")
    for idx, weight in zip(top_indices, top_weights):
        print(f"  {tokens[idx]}: {weight:.3f}")

# Visualize specific head
visualize_attention(torch.tensor(attn), tokens, f"BERT Layer {layer} Head {head}")

Common Attention Patterns in Trained Models

Research Finding: Attention heads specialize!

  • Positional Heads: Attend to specific relative positions (Β±1, Β±2 tokens)
  • Syntactic Heads: Capture grammatical relationships (subject-verb, determiner-noun)
  • Semantic Heads: Link semantically related words (synonyms, coreferences)
  • Rare Word Heads: Attend strongly to unique/important tokens
  • Delimiter Heads: Focus on punctuation and sentence boundaries
  • Broadcast Heads: Uniform attention - aggregate global context

πŸ“š Reference: "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned" (Voita et al., 2019)

Causal Attention (For Generation)

When generating text, we can't look at future tokens. Causal attention (also called masked attention) prevents the model from "cheating" by seeing tokens that haven't been generated yet.

The Autoregressive Constraint

⚠️ Problem with bidirectional attention during generation:

Generating: "The cat sat ___"

If we allow full bidirectional attention:
- While predicting the next token, position 3 could attend to position 4, the very token it is supposed to predict
- The model would learn to simply copy the answer
- No actual language modeling would take place

Solution: Mask future positions!

Implementing Causal Masking

import torch
import torch.nn.functional as F

# Create causal mask (upper triangular matrix)
seq_len = 5
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

print("Causal Mask (True = masked/forbidden):")
print(mask)
# Output:
# [[False,  True,  True,  True,  True],
#  [False, False,  True,  True,  True],
#  [False, False, False,  True,  True],
#  [False, False, False, False,  True],
#  [False, False, False, False, False]]

# Visualize what each token can attend to
tokens = ["The", "cat", "sat", "on", "mat"]
print("\nAttention Constraints:")
for i in range(seq_len):
    allowed = [tokens[j] for j in range(seq_len) if not mask[i, j]]
    print(f"{tokens[i]:>4s} can attend to: {', '.join(allowed)}")

# Output:
# The can attend to: The
# cat can attend to: The, cat
# sat can attend to: The, cat, sat
#  on can attend to: The, cat, sat, on
# mat can attend to: The, cat, sat, on, mat

Applying the Mask During Attention Computation

def causal_attention(Q, K, V, mask=None):
    """
    Compute causal self-attention with optional masking.
    
    Args:
        Q, K, V: [batch, seq_len, d_k] tensors
        mask: [seq_len, seq_len] boolean tensor (True = mask)
    """
    d_k = Q.size(-1)
    
    # Compute scores
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    # Shape: [batch, seq_len, seq_len]
    
    if mask is not None:
        # Set masked positions to -inf BEFORE softmax
        # This makes softmax output 0 for those positions
        scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))
    
    # Softmax converts -inf to 0.0
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply to values
    output = attention_weights @ V
    
    return output, attention_weights

# Example usage
batch_size = 2
seq_len = 4
d_k = 64

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Create causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Compute with mask
output, attn = causal_attention(Q, K, V, mask=causal_mask)

print("Attention weights shape:", attn.shape)
print("\nAttention weights (batch 0):")
print(attn[0])

# Verify: each row should only have non-zero values up to its position
print("\nVerification - Position 0 (can only see position 0):")
print(attn[0, 0, :])  # Should be [1.0, 0.0, 0.0, 0.0]

print("\nPosition 2 (can see positions 0, 1, 2):")
print(attn[0, 2, :])  # Three non-zero weights that sum to 1.0; position 3 is exactly 0.0

Why -inf for Masking?

Mathematical Reasoning:

# Softmax formula: softmax(x_i) = exp(x_i) / Ξ£ exp(x_j)

# If we set x_i = -inf:
# exp(-inf) = 0
# So the softmax output for that position = 0 / (sum) = 0

# This is exactly what we want: zero attention weight!

Alternative (Wrong ❌):

# Bad: Set masked positions to 0 AFTER softmax
attention_weights = F.softmax(scores, dim=-1)
attention_weights = attention_weights.masked_fill(mask, 0.0)

# Problem: Rows no longer sum to 1.0!
# This breaks the probability distribution property
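A short comparison makes the difference visible (random scores, for illustration only):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 4)
mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()

# Correct: mask with -inf BEFORE softmax -> every row still sums to 1
good = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(good.sum(dim=-1))   # tensor([1., 1., 1., 1.])

# Wrong: zero out AFTER softmax -> rows no longer sum to 1
bad = F.softmax(scores, dim=-1).masked_fill(mask, 0.0)
print(bad.sum(dim=-1))    # values below 1 for every row except the last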

Causal vs. Bidirectional Attention Comparison

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

def compare_attention_types():
    """Visualize difference between causal and bidirectional attention."""
    seq_len = 6
    tokens = ["The", "cat", "sat", "on", "the", "mat"]
    
    # Create sample scores
    scores = torch.randn(1, seq_len, seq_len) * 2
    
    # Bidirectional (no mask)
    attn_bi = F.softmax(scores / 8, dim=-1)[0]
    
    # Causal (with mask)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores_masked = scores.clone()
    scores_masked = scores_masked.masked_fill(mask.unsqueeze(0), float('-inf'))
    attn_causal = F.softmax(scores_masked / 8, dim=-1)[0]
    
    # Plot side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Bidirectional
    sns.heatmap(attn_bi.numpy(), annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=tokens, yticklabels=tokens, ax=ax1)
    ax1.set_title('Bidirectional Attention\n(BERT-style)', fontsize=14)
    ax1.set_xlabel('Attending to')
    ax1.set_ylabel('Attending from')
    
    # Causal
    sns.heatmap(attn_causal.numpy(), annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=tokens, yticklabels=tokens, ax=ax2)
    ax2.set_title('Causal Attention\n(GPT-style)', fontsize=14)
    ax2.set_xlabel('Attending to')
    ax2.set_ylabel('Attending from')
    
    plt.tight_layout()
    plt.show()

compare_attention_types()

Key Insight: Causal attention is essential for autoregressive language models (GPT, LLaMA), while bidirectional attention is used in encoder models (BERT, RoBERTa) where we have the full context available.

Model Architecture Decision:
  • Encoder (BERT): Bidirectional attention - see full context
  • Decoder (GPT): Causal attention - only see past context
  • Encoder-Decoder (T5): Both - bidirectional in encoder, causal in decoder
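If you are building a decoder in PyTorch, note that recent versions (2.x, as far as I know) expose the causal mask as a static helper, nn.Transformer.generate_square_subsequent_mask. Unlike the boolean mask used above, it returns an additive float mask: 0 where attention is allowed and -inf where it is blocked, so you add it to the scores instead of calling masked_fill.

import torch.nn as nn

causal = nn.Transformer.generate_square_subsequent_mask(4)
print(causal)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

# Usage: scores = scores + causal   (additive mask, applied before softmax)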

Computational Complexity

Understanding the cost:

Time Complexity: O(nΒ² Γ— d)

  • n = sequence length
  • nΒ²: All pairs of tokens
  • d: Dimension of embeddings

Space Complexity: O(nΒ²)

  • Must store attention weights for all pairs

But: Both are highly parallelizable on GPUs (matrix operations)

This is why transformers need powerful hardware, but it is also why they scale so well as data and parameters grow.
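To make the quadratic memory cost concrete, here is a back-of-the-envelope calculation for just the attention-weight matrix in fp32, for a single head and a single sequence (real models multiply this by batch size and number of heads):

# Memory for one (n x n) attention-weight matrix in fp32 (4 bytes per element)
for n in [512, 2048, 8192, 32768]:
    print(f"seq_len={n:>6}: {n * n * 4 / 2**20:>8.1f} MiB")

# seq_len=   512:      1.0 MiB
# seq_len=  2048:     16.0 MiB
# seq_len=  8192:    256.0 MiB
# seq_len= 32768:   4096.0 MiB   (4 GiB, per head, per sequence!)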

Key Takeaways

  • Q, K, V Projections: Linear transformations creating different views of input
  • Dot-Product Similarity: Q@K^T measures alignment between queries and keys
  β€’ Scaling Factor: Dividing by √d_k prevents exploding scores and softmax saturation
  • Softmax Weights: Convert scores to probability distribution
  • Value Combination: Weighted sum using learned attention weights
  • Learned Patterns: Model learns what to attend to through training
  • Causal Masking: Prevents looking at future tokens in generation
  • Computational Cost: O(nΒ²) but highly parallelizable

Test Your Knowledge

Q1: What are the three components of self-attention?

Input, Output, Hidden
Encoder, Decoder, Attention
Query, Key, Value
Token, Position, Embedding

Q2: How are attention scores computed?

By averaging all vectors
By taking the dot product of Query and Key, then applying softmax
By multiplying all inputs
By random selection

Q3: Why do we apply softmax to attention scores?

To make computations faster
To reduce memory usage
To increase model size
To normalize scores into a probability distribution

Q4: What does scaled dot-product attention divide by?

The square root of the key dimension
The number of tokens
The batch size
The sequence length

Q5: What is the output of the attention mechanism?

Only the attention scores
The original input unchanged
A weighted sum of Value vectors based on attention scores
The Key vectors