Analyze and visualize what transformer attention heads learn using pretrained models
In this project, you'll peek inside pretrained transformer models (BERT and GPT-2) to visualize and understand what different attention heads learn. You'll discover that different heads specialize in different linguistic patterns!
pip install transformers
pip install matplotlib seaborn

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel
# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Load BERT
print("\nLoading BERT-base-uncased...")
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
bert_model.to(device)
bert_model.eval()
print(f"BERT loaded: {bert_model.config.num_hidden_layers} layers, "
      f"{bert_model.config.num_attention_heads} heads per layer")
# Load GPT-2
print("\nLoading GPT-2...")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', output_attentions=True)
gpt2_model.to(device)
gpt2_model.eval()
print(f"GPT-2 loaded: {gpt2_model.config.n_layer} layers, "
      f"{gpt2_model.config.n_head} heads per layer")
print("\nModels ready for analysis!")
def get_bert_attention(text, tokenizer, model, device):
"""
Extract attention weights from BERT.
Args:
text: Input text string
tokenizer: BERT tokenizer
model: BERT model
device: torch device
Returns:
tokens: List of tokens
attentions: Tuple of attention tensors (one per layer)
Each tensor: [batch=1, num_heads, seq_len, seq_len]
"""
# Tokenize
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids'].to(device)
# Get tokens for visualization
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# Forward pass
with torch.no_grad():
outputs = model(input_ids)
attentions = outputs.attentions # Tuple of [1, num_heads, seq_len, seq_len]
return tokens, attentions
def get_gpt2_attention(text, tokenizer, model, device):
"""
Extract attention weights from GPT-2.
Args:
text: Input text string
tokenizer: GPT-2 tokenizer
model: GPT-2 model
device: torch device
Returns:
tokens: List of tokens
attentions: Tuple of attention tensors (one per layer)
"""
# Tokenize
inputs = tokenizer(text, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)
# Get tokens
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# Forward pass
with torch.no_grad():
outputs = model(input_ids)
attentions = outputs.attentions
return tokens, attentions
# Test with example sentence
example_text = "The cat sat on the mat and looked at the bird."
print("\nExtracting attention from BERT...")
bert_tokens, bert_attentions = get_bert_attention(example_text, bert_tokenizer, bert_model, device)
print(f"Tokens: {bert_tokens}")
print(f"Attention shape: {len(bert_attentions)} layers, "
      f"each {bert_attentions[0].shape}")
print("\nExtracting attention from GPT-2...")
gpt2_tokens, gpt2_attentions = get_gpt2_attention(example_text, gpt2_tokenizer, gpt2_model, device)
print(f"Tokens: {gpt2_tokens}")
print(f"Attention shape: {len(gpt2_attentions)} layers, "
      f"each {gpt2_attentions[0].shape}")
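Before plotting anything, it is worth confirming what these tensors contain. Here is a minimal sanity check (a sketch reusing the bert_attentions tuple extracted above): every query row of an attention head is a softmax distribution, so it should sum to roughly 1.
# Sanity check: each attention row is a softmax distribution and should sum to ~1
first_layer = bert_attentions[0][0]        # [num_heads, seq_len, seq_len]
row_sums = first_layer.sum(dim=-1)         # [num_heads, seq_len]
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4), \
    "Each attention row should sum to 1"
print("Sanity check passed: every attention row sums to 1")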
def plot_attention_head(attention, tokens, layer, head, ax=None):
"""
Plot attention weights for a single head.
Args:
attention: [seq_len, seq_len] attention matrix
tokens: List of token strings
layer: Layer number (for title)
head: Head number (for title)
ax: Matplotlib axis (optional)
"""
if ax is None:
fig, ax = plt.subplots(figsize=(10, 8))
# Plot heatmap
sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens,
cmap='YlOrRd', vmin=0, vmax=1,
square=True, linewidths=0.5, cbar_kws={'label': 'Attention Weight'},
ax=ax)
ax.set_title(f'Layer {layer}, Head {head}', fontsize=14, fontweight='bold')
ax.set_xlabel('Key Tokens', fontsize=12)
ax.set_ylabel('Query Tokens', fontsize=12)
# Rotate labels
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
plt.tight_layout()
# Visualize specific attention head
layer_idx = 5  # A middle layer
head_idx = 3   # An arbitrary head to inspect
# Get attention for this layer and head
attention = bert_attentions[layer_idx][0, head_idx].cpu().numpy()
# Plot
fig, ax = plt.subplots(figsize=(12, 10))
plot_attention_head(attention, bert_tokens, layer_idx, head_idx, ax)
plt.savefig(f'bert_attention_layer{layer_idx}_head{head_idx}.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Visualized BERT Layer {layer_idx}, Head {head_idx}")
def plot_all_heads_in_layer(attentions, tokens, layer_idx, model_name='BERT'):
"""
Plot all attention heads in a specific layer.
Args:
attentions: Tuple of attention tensors
tokens: List of tokens
layer_idx: Which layer to visualize
model_name: Model name for title
"""
# Get attention for this layer: [1, num_heads, seq_len, seq_len]
layer_attention = attentions[layer_idx][0].cpu().numpy()
num_heads = layer_attention.shape[0]
# Create grid of subplots
cols = 4
rows = (num_heads + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(20, rows * 4))
axes = axes.flatten()
for head_idx in range(num_heads):
ax = axes[head_idx]
attention = layer_attention[head_idx]
sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens,
cmap='YlOrRd', vmin=0, vmax=1, square=True,
linewidths=0.2, cbar=False, ax=ax)
ax.set_title(f'Head {head_idx}', fontsize=10, fontweight='bold')
ax.set_xlabel('')
ax.set_ylabel('')
ax.tick_params(labelsize=7)
# Only show labels on outer edges
if head_idx % cols != 0:
ax.set_yticklabels([])
if head_idx < num_heads - cols:
ax.set_xticklabels([])
else:
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
# Hide extra subplots
for idx in range(num_heads, len(axes)):
axes[idx].axis('off')
plt.suptitle(f'{model_name} - Layer {layer_idx} - All Attention Heads',
fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(f'{model_name.lower()}_layer{layer_idx}_all_heads.png', dpi=150, bbox_inches='tight')
plt.show()
# Visualize all heads in layer 5
plot_all_heads_in_layer(bert_attentions, bert_tokens, layer_idx=5, model_name='BERT')
print("Visualized all heads in BERT layer 5")
def analyze_attention_patterns(attentions, tokens):
"""
Analyze what patterns different heads learn.
Args:
attentions: Tuple of attention tensors
tokens: List of tokens
Returns:
Dictionary with pattern analysis
"""
patterns = {
'local_attention': [], # Attends to nearby tokens
'global_attention': [], # Attends to all tokens uniformly
'positional_attention': [], # Attends to specific positions
'special_token_attention': [] # Attends to special tokens ([CLS], [SEP])
}
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
for layer_idx in range(num_layers):
for head_idx in range(num_heads):
attention = attentions[layer_idx][0, head_idx].cpu().numpy()
seq_len = attention.shape[0]
# Measure locality (how much attention goes to nearby tokens)
locality_score = 0
for i in range(seq_len):
for j in range(max(0, i-2), min(seq_len, i+3)):
locality_score += attention[i, j]
locality_score /= seq_len
# Measure entropy (uniform distribution has high entropy)
entropy = -np.sum(attention * np.log(attention + 1e-10), axis=-1).mean()
# Measure attention to first token ([CLS])
cls_attention = attention[:, 0].mean()
# Categorize patterns
if locality_score > 0.7:
patterns['local_attention'].append((layer_idx, head_idx, locality_score))
if entropy > 2.0:
patterns['global_attention'].append((layer_idx, head_idx, entropy))
if cls_attention > 0.3:
patterns['special_token_attention'].append((layer_idx, head_idx, cls_attention))
return patterns
# Analyze BERT patterns
print("\nAnalyzing BERT attention patterns...")
bert_patterns = analyze_attention_patterns(bert_attentions, bert_tokens)
print("\nPattern Analysis Results:")
print(f" Local attention heads: {len(bert_patterns['local_attention'])}")
for layer, head, score in bert_patterns['local_attention'][:5]:
print(f" Layer {layer}, Head {head}: {score:.3f}")
print(f"\n Global attention heads: {len(bert_patterns['global_attention'])}")
for layer, head, entropy in bert_patterns['global_attention'][:5]:
print(f" Layer {layer}, Head {head}: entropy={entropy:.3f}")
print(f"\n Special token attention heads: {len(bert_patterns['special_token_attention'])}")
for layer, head, score in bert_patterns['special_token_attention'][:5]:
print(f" Layer {layer}, Head {head}: [CLS] attention={score:.3f}")
def plot_attention_flow(attentions, tokens, token_idx):
"""
Show how a specific token's attention changes across layers.
Args:
attentions: Tuple of attention tensors
tokens: List of tokens
token_idx: Index of token to track
"""
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
seq_len = len(tokens)
# Extract attention from all layers for this token
attention_by_layer = np.zeros((num_layers, seq_len))
for layer_idx in range(num_layers):
# Average across heads
layer_attn = attentions[layer_idx][0, :, token_idx, :].cpu().numpy()
attention_by_layer[layer_idx] = layer_attn.mean(axis=0)
# Plot
fig, ax = plt.subplots(figsize=(14, 8))
im = ax.imshow(attention_by_layer.T, aspect='auto', cmap='YlOrRd',
vmin=0, vmax=attention_by_layer.max())
ax.set_xlabel('Layer', fontsize=14)
ax.set_ylabel('Token', fontsize=14)
ax.set_title(f'Attention Flow for Token: "{tokens[token_idx]}"',
fontsize=16, fontweight='bold', pad=20)
ax.set_xticks(range(num_layers))
ax.set_xticklabels(range(num_layers))
ax.set_yticks(range(seq_len))
ax.set_yticklabels(tokens)
# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Attention Weight', fontsize=12)
plt.tight_layout()
plt.savefig(f'attention_flow_token_{token_idx}.png', dpi=150, bbox_inches='tight')
plt.show()
# Visualize attention flow for the word "cat"
cat_idx = bert_tokens.index('cat') if 'cat' in bert_tokens else 2
plot_attention_flow(bert_attentions, bert_tokens, cat_idx)
print(f"Visualized attention flow for token '{bert_tokens[cat_idx]}'")
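As a complementary, text-only view of the same flow (a sketch reusing bert_attentions, bert_tokens, and cat_idx from above), you can print which token "cat" attends to most strongly at each layer, averaged across heads:
# For each layer, report the head-averaged top attention target of the tracked token
for layer_num, layer_attn in enumerate(bert_attentions):
    head_avg = layer_attn[0, :, cat_idx, :].mean(dim=0)   # [seq_len]
    top_idx = int(head_avg.argmax().item())
    print(f"Layer {layer_num:2d}: '{bert_tokens[cat_idx]}' -> '{bert_tokens[top_idx]}' "
          f"({head_avg[top_idx].item():.2f})")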
def compare_models_attention(bert_attentions, gpt2_attentions,
bert_tokens, gpt2_tokens, layer_idx=5):
"""
Compare attention patterns between BERT and GPT-2.
Args:
bert_attentions: BERT attention tensors
gpt2_attentions: GPT-2 attention tensors
bert_tokens: BERT tokens
gpt2_tokens: GPT-2 tokens
layer_idx: Which layer to compare
"""
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
# BERT (encoder, bidirectional)
bert_attn = bert_attentions[layer_idx][0, 0].cpu().numpy() # Head 0
sns.heatmap(bert_attn, xticklabels=bert_tokens, yticklabels=bert_tokens,
cmap='Blues', square=True, linewidths=0.5, ax=axes[0])
axes[0].set_title(f'BERT (Bidirectional Encoder)\nLayer {layer_idx}, Head 0',
                  fontsize=14, fontweight='bold')
axes[0].set_xlabel('Key Tokens')
axes[0].set_ylabel('Query Tokens')
# GPT-2 (decoder, causal)
gpt2_attn = gpt2_attentions[layer_idx][0, 0].cpu().numpy() # Head 0
sns.heatmap(gpt2_attn, xticklabels=gpt2_tokens, yticklabels=gpt2_tokens,
cmap='Reds', square=True, linewidths=0.5, ax=axes[1])
axes[1].set_title(f'GPT-2 (Causal Decoder)\nLayer {layer_idx}, Head 0',
                  fontsize=14, fontweight='bold')
axes[1].set_xlabel('Key Tokens')
axes[1].set_ylabel('Query Tokens')
plt.suptitle('BERT vs GPT-2: Attention Patterns', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('bert_vs_gpt2_attention.png', dpi=150, bbox_inches='tight')
plt.show()
print("\nKey Differences:")
print(" BERT (Blue):")
print(" - Bidirectional: can attend to all tokens (past and future)")
print(" - Useful for understanding context from both directions")
print(" - Applications: classification, named entity recognition")
print()
print(" GPT-2 (Red):")
print(" - Causal (triangular pattern): only attends to past tokens")
print(" - Prevents information leakage during generation")
print(" - Applications: text generation, autocompletion")
# Compare models
compare_models_attention(bert_attentions, gpt2_attentions,
bert_tokens, gpt2_tokens, layer_idx=5)
print("Model comparison complete")
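The causal-versus-bidirectional difference described above can also be verified numerically. A minimal check (a sketch assuming the attention tuples extracted earlier): GPT-2's attention matrices should put essentially zero weight on future positions (above the diagonal), while BERT's should not.
# Causality check: total attention mass on strictly-upper-triangular (future) positions
gpt2_layer0 = gpt2_attentions[0][0].cpu()   # [num_heads, seq_len, seq_len]
bert_layer0 = bert_attentions[0][0].cpu()
gpt2_future = torch.triu(gpt2_layer0, diagonal=1).sum().item()
bert_future = torch.triu(bert_layer0, diagonal=1).sum().item()
print(f"Attention on future tokens: GPT-2={gpt2_future:.4f}, BERT={bert_future:.4f}")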
def interactive_attention_explorer(text, model_type='bert'):
"""
Interactive tool to explore attention for any text.
Args:
text: Input text
model_type: 'bert' or 'gpt2'
"""
# Get model and tokenizer
if model_type == 'bert':
tokenizer = bert_tokenizer
model = bert_model
tokens, attentions = get_bert_attention(text, tokenizer, model, device)
else:
tokenizer = gpt2_tokenizer
model = gpt2_model
tokens, attentions = get_gpt2_attention(text, tokenizer, model, device)
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
print(f"\n{'='*70}")
print(f"Attention Explorer - {model_type.upper()}")
print(f"{'='*70}")
print(f"Text: {text}")
print(f"Tokens: {tokens}")
print(f"Layers: {num_layers}, Heads per layer: {num_heads}")
print(f"{'='*70}\n")
# Find interesting heads
print("Finding interesting attention patterns...\n")
for layer_idx in range(min(3, num_layers)): # Check first 3 layers
for head_idx in range(min(4, num_heads)): # Check first 4 heads
attention = attentions[layer_idx][0, head_idx].cpu().numpy()
# Check if this head has interesting patterns
# Pattern 1: Diagonal (local attention)
diagonal_sum = sum([attention[i, i] for i in range(len(tokens))])
# Pattern 2: First column (attention to [CLS])
first_col_sum = attention[:, 0].sum()
# Pattern 3: Last row (last token attending to all)
last_row_entropy = -np.sum(attention[-1] * np.log(attention[-1] + 1e-10))
if diagonal_sum > 0.5:
print(f" Layer {layer_idx}, Head {head_idx}: LOCAL pattern "
f"(diagonal={diagonal_sum:.2f})")
if first_col_sum > len(tokens) * 0.3:
print(f" Layer {layer_idx}, Head {head_idx}: CLS/START pattern "
f"(first_col={first_col_sum:.2f})")
if last_row_entropy > 2.0:
print(f" Layer {layer_idx}, Head {head_idx}: GLOBAL pattern "
f"(entropy={last_row_entropy:.2f})")
print(f"\n{'='*70}")
# Visualize most interesting head
layer_idx = 2
head_idx = 0
attention = attentions[layer_idx][0, head_idx].cpu().numpy()
fig, ax = plt.subplots(figsize=(12, 10))
plot_attention_head(attention, tokens, layer_idx, head_idx, ax)
plt.savefig(f'{model_type}_attention_explorer.png', dpi=150, bbox_inches='tight')
plt.show()
# Try different sentences
test_sentences = [
"The quick brown fox jumps over the lazy dog.",
"She sells seashells by the seashore.",
"To be or not to be, that is the question.",
"The cat and the dog are friends, but the bird is scared.",
]
print("\nExploring Attention for Different Sentences")
print("="*70)
for sentence in test_sentences[:2]: # Analyze first 2
interactive_attention_explorer(sentence, model_type='bert')
print("\n")
def analyze_head_specialization(attentions, tokens):
"""
Discover what linguistic properties different heads capture.
Args:
attentions: Attention tensors
tokens: Token list
"""
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
specializations = {
'positional': [], # Attends based on position
'syntactic': [], # Attends to grammatical relationships
'semantic': [], # Attends to meaning-related tokens
'delimiter': [], # Attends to punctuation/boundaries
}
print("\nHead Specialization Analysis")
print("="*70)
for layer_idx in range(num_layers):
for head_idx in range(num_heads):
attention = attentions[layer_idx][0, head_idx].cpu().numpy()
# Positional pattern: attention depends mainly on distance
distance_correlation = 0
for i in range(len(tokens)):
for j in range(len(tokens)):
distance = abs(i - j)
distance_correlation += attention[i, j] * (1.0 / (distance + 1))
distance_correlation /= len(tokens) ** 2
# Punctuation attention
punct_attention = 0
punct_tokens = ['.', ',', '!', '?', ';', ':']
for i, token in enumerate(tokens):
if token in punct_tokens:
punct_attention += attention[:, i].mean()
punct_attention /= max(sum(1 for t in tokens if t in punct_tokens), 1)
# Store findings
if distance_correlation > 0.15:
specializations['positional'].append((layer_idx, head_idx, distance_correlation))
if punct_attention > 0.2:
specializations['delimiter'].append((layer_idx, head_idx, punct_attention))
# Early layers = syntactic, late layers = semantic (heuristic)
if layer_idx < num_layers // 3:
specializations['syntactic'].append((layer_idx, head_idx))
elif layer_idx > 2 * num_layers // 3:
specializations['semantic'].append((layer_idx, head_idx))
# Print findings
print("\nPOSITIONAL Heads (attend based on token distance):")
for layer, head, score in sorted(specializations['positional'],
key=lambda x: x[2], reverse=True)[:5]:
print(f" Layer {layer:2d}, Head {head}: score={score:.3f}")
print("\nDELIMITER Heads (attend to punctuation):")
for layer, head, score in sorted(specializations['delimiter'],
key=lambda x: x[2], reverse=True)[:5]:
print(f" Layer {layer:2d}, Head {head}: score={score:.3f}")
print("\nSYNTACTIC Heads (early layers, grammatical structure):")
for layer, head in specializations['syntactic'][:5]:
print(f" Layer {layer:2d}, Head {head}")
print("\nSEMANTIC Heads (late layers, meaning relationships):")
for layer, head in specializations['semantic'][:5]:
print(f" Layer {layer:2d}, Head {head}")
print("\n" + "="*70)
return specializations
# Analyze BERT head specialization
bert_specializations = analyze_head_specialization(bert_attentions, bert_tokens)
print("Specialization analysis complete")
def create_attention_summary(attentions, tokens, model_name='BERT'):
"""
Create comprehensive summary visualization.
Args:
attentions: Attention tensors
tokens: Token list
model_name: Model name
"""
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
# Compute statistics
avg_attention_by_layer = []
entropy_by_layer = []
for layer_idx in range(num_layers):
layer_attn = attentions[layer_idx][0].cpu().numpy() # [num_heads, seq_len, seq_len]
# Average attention weight
avg_attn = layer_attn.mean()
avg_attention_by_layer.append(avg_attn)
# Average entropy (measure of attention dispersion)
entropy = -np.sum(layer_attn * np.log(layer_attn + 1e-10), axis=-1).mean()
entropy_by_layer.append(entropy)
# Create summary plot
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
# Plot 1: Average attention by layer
axes[0, 0].plot(range(num_layers), avg_attention_by_layer,
marker='o', linewidth=2, markersize=8)
axes[0, 0].set_xlabel('Layer', fontsize=12)
axes[0, 0].set_ylabel('Average Attention Weight', fontsize=12)
axes[0, 0].set_title('Average Attention Across Layers', fontsize=14, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3)
# Plot 2: Entropy by layer
axes[0, 1].plot(range(num_layers), entropy_by_layer,
marker='s', linewidth=2, markersize=8, color='orange')
axes[0, 1].set_xlabel('Layer', fontsize=12)
axes[0, 1].set_ylabel('Attention Entropy', fontsize=12)
axes[0, 1].set_title('Attention Entropy Across Layers', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)
# Plot 3: Attention heatmap (average across all heads and layers)
avg_attention = torch.stack([a[0] for a in attentions]).mean(dim=(0, 1)).cpu().numpy()
sns.heatmap(avg_attention, xticklabels=tokens, yticklabels=tokens,
cmap='YlOrRd', square=True, linewidths=0.5, ax=axes[1, 0])
axes[1, 0].set_title('Average Attention (All Layers & Heads)',
fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Key Tokens')
axes[1, 0].set_ylabel('Query Tokens')
# Plot 4: Token-wise attention received
attention_received = avg_attention.sum(axis=0)
axes[1, 1].bar(range(len(tokens)), attention_received, color='steelblue', alpha=0.7)
axes[1, 1].set_xlabel('Token', fontsize=12)
axes[1, 1].set_ylabel('Total Attention Received', fontsize=12)
axes[1, 1].set_title('Attention Received by Each Token', fontsize=14, fontweight='bold')
axes[1, 1].set_xticks(range(len(tokens)))
axes[1, 1].set_xticklabels(tokens, rotation=45, ha='right')
axes[1, 1].grid(True, alpha=0.3, axis='y')
plt.suptitle(f'{model_name} Attention Analysis Summary',
fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(f'{model_name.lower()}_attention_summary.png', dpi=150, bbox_inches='tight')
plt.show()
# Print summary
print(f"\n{model_name} Attention Summary")
print("="*70)
print(f" Total layers: {num_layers}")
print(f" Heads per layer: {num_heads}")
print(f" Sequence length: {len(tokens)}")
print(f" Avg attention weight: {np.mean(avg_attention_by_layer):.4f}")
print(f" Avg entropy: {np.mean(entropy_by_layer):.4f}")
print(f" Most attended token: '{tokens[np.argmax(attention_received)]}'")
print("="*70)
# Create summary for BERT
create_attention_summary(bert_attentions, bert_tokens, model_name='BERT')
print("Summary visualization complete")
You've unlocked the black box!
You can now visualize and understand what transformer models are actually learning. This is crucial for interpretability, debugging, and advancing AI research.
Q1: What do attention visualizations primarily show?
Q2: What library is commonly used for creating attention heatmaps?
Q3: In attention visualizations, what do darker colors typically represent?
Q4: What insight can you gain from visualizing attention patterns?
Q5: Why visualize attention patterns in multi-head attention?
You've completed the Transformers Architecture course! You now understand how these powerful models work from the ground up.