🔍 Project 3: Visualize Attention Patterns

📚 Hands-On Project 🔴 Advanced ⏱️ 1-2 hours

Analyze and visualize what transformer attention heads learn using pretrained models

🎯 Project Overview

In this project, you'll peek inside pretrained transformer models (BERT and GPT-2) to visualize and understand what different attention heads learn. You'll discover that different heads specialize in different linguistic patterns!

What You'll Build

  • ✅ Load pretrained BERT and GPT-2 models
  • ✅ Extract attention weights from all layers
  • ✅ Visualize attention patterns with heatmaps
  • ✅ Analyze what different heads attend to
  • ✅ Interactive attention explorer
  • ✅ Compare attention across layers
  • ✅ Discover syntactic vs semantic patterns

📋 Prerequisites

  • Python 3.7+
  • Transformers library: pip install transformers
  • Visualization: pip install matplotlib seaborn
  • PyTorch (all code below uses torch and the PyTorch model classes)
  • 2-4 GB RAM for model loading

Step 1: Setup and Load Models

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load BERT
print("\n📥 Loading BERT-base-uncased...")
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
bert_model.to(device)
bert_model.eval()
print(f"✅ BERT loaded: {bert_model.config.num_hidden_layers} layers, "
      f"{bert_model.config.num_attention_heads} heads per layer")

# Load GPT-2
print("\n📥 Loading GPT-2...")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', output_attentions=True)
gpt2_model.to(device)
gpt2_model.eval()
print(f"✅ GPT-2 loaded: {gpt2_model.config.n_layer} layers, "
      f"{gpt2_model.config.n_head} heads per layer")

print("\n🎉 Models ready for analysis!")

Step 2: Extract Attention Weights

def get_bert_attention(text, tokenizer, model, device):
    """
    Extract attention weights from BERT.
    
    Args:
        text: Input text string
        tokenizer: BERT tokenizer
        model: BERT model
        device: torch device
    
    Returns:
        tokens: List of tokens
        attentions: Tuple of attention tensors (one per layer)
                   Each tensor: [batch=1, num_heads, seq_len, seq_len]
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids'].to(device)
    
    # Get tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids)
        attentions = outputs.attentions  # Tuple of [1, num_heads, seq_len, seq_len]
    
    return tokens, attentions


def get_gpt2_attention(text, tokenizer, model, device):
    """
    Extract attention weights from GPT-2.
    
    Args:
        text: Input text string
        tokenizer: GPT-2 tokenizer
        model: GPT-2 model
        device: torch device
    
    Returns:
        tokens: List of tokens
        attentions: Tuple of attention tensors (one per layer)
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    
    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids)
        attentions = outputs.attentions
    
    return tokens, attentions


# Test with example sentence
example_text = "The cat sat on the mat and looked at the bird."

print("\n🔍 Extracting attention from BERT...")
bert_tokens, bert_attentions = get_bert_attention(example_text, bert_tokenizer, bert_model, device)
print(f"✅ Tokens: {bert_tokens}")
print(f"✅ Attention shape: {len(bert_attentions)} layers, "
      f"each {bert_attentions[0].shape}")

print("\n🔍 Extracting attention from GPT-2...")
gpt2_tokens, gpt2_attentions = get_gpt2_attention(example_text, gpt2_tokenizer, gpt2_model, device)
print(f"✅ Tokens: {gpt2_tokens}")
print(f"✅ Attention shape: {len(gpt2_attentions)} layers, "
      f"each {gpt2_attentions[0].shape}")
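
Each attention tensor holds softmax-normalized weights, so every query row should sum to 1. The sketch below illustrates the expected shape and the row-sum check using synthetic scores; the random scores and the `softmax` helper are stand-ins for illustration, not model output.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Synthetic stand-in for one layer: [batch=1, num_heads=12, seq_len=5, seq_len=5]
rng = np.random.default_rng(0)
toy_attention = softmax(rng.normal(size=(1, 12, 5, 5)))

row_sums = toy_attention.sum(axis=-1)   # one sum per (batch, head, query)
print(toy_attention.shape)              # (1, 12, 5, 5)
print(np.allclose(row_sums, 1.0))       # True
```

The same check can be run on a real layer, for example `bert_attentions[0].sum(dim=-1)`, which should be close to 1 everywhere.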

Step 3: Visualize Single Head

def plot_attention_head(attention, tokens, layer, head, ax=None):
    """
    Plot attention weights for a single head.
    
    Args:
        attention: [seq_len, seq_len] attention matrix
        tokens: List of token strings
        layer: Layer number (for title)
        head: Head number (for title)
        ax: Matplotlib axis (optional)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot heatmap
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens, 
                cmap='YlOrRd', vmin=0, vmax=1, 
                square=True, linewidths=0.5, cbar_kws={'label': 'Attention Weight'},
                ax=ax)
    
    ax.set_title(f'Layer {layer}, Head {head}', fontsize=14, fontweight='bold')
    ax.set_xlabel('Key Tokens', fontsize=12)
    ax.set_ylabel('Query Tokens', fontsize=12)
    
    # Rotate labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    
    plt.tight_layout()


# Visualize specific attention head
layer_idx = 5  # Sixth of 12 layers (zero-indexed), roughly mid-depth
head_idx = 3   # Arbitrarily chosen head

# Get attention for this layer and head
attention = bert_attentions[layer_idx][0, head_idx].cpu().numpy()

# Plot
fig, ax = plt.subplots(figsize=(12, 10))
plot_attention_head(attention, bert_tokens, layer_idx, head_idx, ax)
plt.savefig(f'bert_attention_layer{layer_idx}_head{head_idx}.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Visualized BERT Layer {layer_idx}, Head {head_idx}")
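
Heatmap tick labels are raw tokenizer tokens. GPT-2's byte-level BPE marks a leading space with `Ġ`, while BERT's WordPiece marks word continuations with `##`. The helper below is a hypothetical convenience (the name `clean_token_labels` and the `_` replacement are assumptions, not part of the project code) that makes GPT-2 labels easier to read and leaves BERT tokens unchanged. The token lists are hand-written examples rather than real tokenizer output.

```python
def clean_token_labels(tokens):
    """Make raw tokenizer tokens easier to read on heatmap axes."""
    cleaned = []
    for tok in tokens:
        if tok.startswith("Ġ"):
            # GPT-2 byte-level BPE: 'Ġ' encodes a leading space
            cleaned.append("_" + tok[1:])
        else:
            # BERT tokens, including '##' word pieces, pass through unchanged
            cleaned.append(tok)
    return cleaned

print(clean_token_labels(["The", "Ġcat", "Ġsat"]))
# ['The', '_cat', '_sat']
print(clean_token_labels(["[CLS]", "sea", "##shell", "##s", "[SEP]"]))
# ['[CLS]', 'sea', '##shell', '##s', '[SEP]']
```

The cleaned list can be passed as the `tokens` argument of `plot_attention_head` in place of the raw tokens.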

Step 4: Compare All Heads in a Layer

def plot_all_heads_in_layer(attentions, tokens, layer_idx, model_name='BERT'):
    """
    Plot all attention heads in a specific layer.
    
    Args:
        attentions: Tuple of attention tensors
        tokens: List of tokens
        layer_idx: Which layer to visualize
        model_name: Model name for title
    """
    # Get attention for this layer and drop the batch dim: [num_heads, seq_len, seq_len]
    layer_attention = attentions[layer_idx][0].cpu().numpy()
    num_heads = layer_attention.shape[0]
    
    # Create grid of subplots
    cols = 4
    rows = (num_heads + cols - 1) // cols
    
    fig, axes = plt.subplots(rows, cols, figsize=(20, rows * 4))
    axes = axes.flatten()
    
    for head_idx in range(num_heads):
        ax = axes[head_idx]
        attention = layer_attention[head_idx]
        
        sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens,
                   cmap='YlOrRd', vmin=0, vmax=1, square=True,
                   linewidths=0.2, cbar=False, ax=ax)
        
        ax.set_title(f'Head {head_idx}', fontsize=10, fontweight='bold')
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.tick_params(labelsize=7)
        
        # Only show labels on outer edges
        if head_idx % cols != 0:
            ax.set_yticklabels([])
        if head_idx < num_heads - cols:
            ax.set_xticklabels([])
        else:
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    
    # Hide extra subplots
    for idx in range(num_heads, len(axes)):
        axes[idx].axis('off')
    
    plt.suptitle(f'{model_name} - Layer {layer_idx} - All Attention Heads', 
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_layer{layer_idx}_all_heads.png', dpi=150, bbox_inches='tight')
    plt.show()


# Visualize all heads in layer 5
plot_all_heads_in_layer(bert_attentions, bert_tokens, layer_idx=5, model_name='BERT')
print("✅ Visualized all heads in BERT layer 5")

Step 5: Analyze Attention Patterns

def analyze_attention_patterns(attentions, tokens):
    """
    Analyze what patterns different heads learn.
    
    Args:
        attentions: Tuple of attention tensors
        tokens: List of tokens
    
    Returns:
        Dictionary with pattern analysis
    """
    patterns = {
        'local_attention': [],      # Attends to nearby tokens
        'global_attention': [],     # Attends to all tokens uniformly
        'positional_attention': [], # Attends to specific positions (reserved; not filled below)
        'special_token_attention': [] # Attends to special tokens ([CLS], [SEP])
    }
    
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            attention = attentions[layer_idx][0, head_idx].cpu().numpy()
            seq_len = attention.shape[0]
            
            # Measure locality (how much attention goes to nearby tokens)
            locality_score = 0
            for i in range(seq_len):
                for j in range(max(0, i-2), min(seq_len, i+3)):
                    locality_score += attention[i, j]
            locality_score /= seq_len
            
            # Measure entropy (uniform distribution has high entropy)
            entropy = -np.sum(attention * np.log(attention + 1e-10), axis=-1).mean()
            
            # Measure attention to first token ([CLS])
            cls_attention = attention[:, 0].mean()
            
            # Categorize patterns
            if locality_score > 0.7:
                patterns['local_attention'].append((layer_idx, head_idx, locality_score))
            
            # Max entropy is ln(seq_len), so this fixed cutoff depends on sentence length
            if entropy > 2.0:
                patterns['global_attention'].append((layer_idx, head_idx, entropy))
            
            if cls_attention > 0.3:
                patterns['special_token_attention'].append((layer_idx, head_idx, cls_attention))
    
    return patterns


# Analyze BERT patterns
print("\n🔬 Analyzing BERT attention patterns...")
bert_patterns = analyze_attention_patterns(bert_attentions, bert_tokens)

print(f"\n📊 Pattern Analysis Results:")
print(f"   Local attention heads: {len(bert_patterns['local_attention'])}")
for layer, head, score in bert_patterns['local_attention'][:5]:
    print(f"      Layer {layer}, Head {head}: {score:.3f}")

print(f"\n   Global attention heads: {len(bert_patterns['global_attention'])}")
for layer, head, entropy in bert_patterns['global_attention'][:5]:
    print(f"      Layer {layer}, Head {head}: entropy={entropy:.3f}")

print(f"\n   Special token attention heads: {len(bert_patterns['special_token_attention'])}")
for layer, head, score in bert_patterns['special_token_attention'][:5]:
    print(f"      Layer {layer}, Head {head}: [CLS] attention={score:.3f}")
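
The entropy cutoff in Step 5 is an absolute value, but the largest possible entropy of a row over n keys is ln(n), so the same head can fall on either side of the cutoff depending on sentence length. One way to make scores comparable across inputs is to divide by ln(seq_len). The sketch below assumes a single `[seq_len, seq_len]` matrix whose rows sum to 1; `normalized_entropy` is a hypothetical helper name.

```python
import numpy as np

def normalized_entropy(attention, eps=1e-10):
    """Mean row entropy of a [seq_len, seq_len] matrix, scaled to [0, 1]."""
    seq_len = attention.shape[-1]
    if seq_len < 2:
        return 0.0
    row_entropy = -np.sum(attention * np.log(attention + eps), axis=-1)
    # Clamp at 0 because eps can push the entropy of a one-hot row slightly negative
    return float(max(0.0, row_entropy.mean() / np.log(seq_len)))

uniform = np.full((6, 6), 1.0 / 6)   # every query spreads attention evenly
one_hot = np.eye(6)                  # every query attends to exactly one token

print(round(normalized_entropy(uniform), 3))  # 1.0
print(round(normalized_entropy(one_hot), 3))  # 0.0
```

For a causal model such as GPT-2, early rows have fewer keys available, so a per-row normalizer of ln(i + 1) may be a better fit than ln(seq_len).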

Step 6: Layer-wise Attention Flow

def plot_attention_flow(attentions, tokens, token_idx):
    """
    Show how a specific token's attention changes across layers.
    
    Args:
        attentions: Tuple of attention tensors
        tokens: List of tokens
        token_idx: Index of token to track
    """
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    seq_len = len(tokens)
    
    # Extract attention from all layers for this token
    attention_by_layer = np.zeros((num_layers, seq_len))
    
    for layer_idx in range(num_layers):
        # Average across heads
        layer_attn = attentions[layer_idx][0, :, token_idx, :].cpu().numpy()
        attention_by_layer[layer_idx] = layer_attn.mean(axis=0)
    
    # Plot
    fig, ax = plt.subplots(figsize=(14, 8))
    
    im = ax.imshow(attention_by_layer.T, aspect='auto', cmap='YlOrRd', 
                   vmin=0, vmax=attention_by_layer.max())
    
    ax.set_xlabel('Layer', fontsize=14)
    ax.set_ylabel('Token', fontsize=14)
    ax.set_title(f'Attention Flow for Token: "{tokens[token_idx]}"', 
                fontsize=16, fontweight='bold', pad=20)
    
    ax.set_xticks(range(num_layers))
    ax.set_xticklabels(range(num_layers))
    ax.set_yticks(range(seq_len))
    ax.set_yticklabels(tokens)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Attention Weight', fontsize=12)
    
    plt.tight_layout()
    plt.savefig(f'attention_flow_token_{token_idx}.png', dpi=150, bbox_inches='tight')
    plt.show()


# Visualize attention flow for the word "cat"
cat_idx = bert_tokens.index('cat') if 'cat' in bert_tokens else 2
plot_attention_flow(bert_attentions, bert_tokens, cat_idx)
print(f"✅ Visualized attention flow for token '{bert_tokens[cat_idx]}'")

Step 7: Compare BERT vs GPT-2

def compare_models_attention(bert_attentions, gpt2_attentions, 
                             bert_tokens, gpt2_tokens, layer_idx=5):
    """
    Compare attention patterns between BERT and GPT-2.
    
    Args:
        bert_attentions: BERT attention tensors
        gpt2_attentions: GPT-2 attention tensors
        bert_tokens: BERT tokens
        gpt2_tokens: GPT-2 tokens
        layer_idx: Which layer to compare
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    
    # BERT (encoder, bidirectional)
    bert_attn = bert_attentions[layer_idx][0, 0].cpu().numpy()  # Head 0
    sns.heatmap(bert_attn, xticklabels=bert_tokens, yticklabels=bert_tokens,
               cmap='Blues', square=True, linewidths=0.5, ax=axes[0])
    axes[0].set_title(f'BERT (Bidirectional Encoder)\nLayer {layer_idx}, Head 0', 
                     fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Key Tokens')
    axes[0].set_ylabel('Query Tokens')
    
    # GPT-2 (decoder, causal)
    gpt2_attn = gpt2_attentions[layer_idx][0, 0].cpu().numpy()  # Head 0
    sns.heatmap(gpt2_attn, xticklabels=gpt2_tokens, yticklabels=gpt2_tokens,
               cmap='Reds', square=True, linewidths=0.5, ax=axes[1])
    axes[1].set_title(f'GPT-2 (Causal Decoder)\nLayer {layer_idx}, Head 0', 
                     fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Key Tokens')
    axes[1].set_ylabel('Query Tokens')
    
    plt.suptitle('BERT vs GPT-2: Attention Patterns', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig('bert_vs_gpt2_attention.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\n🔍 Key Differences:")
    print("   BERT (Blue):")
    print("      - Bidirectional: can attend to all tokens (past and future)")
    print("      - Useful for understanding context from both directions")
    print("      - Applications: classification, named entity recognition")
    print()
    print("   GPT-2 (Red):")
    print("      - Causal (triangular pattern): attends only to the current and earlier tokens")
    print("      - Prevents information leakage during generation")
    print("      - Applications: text generation, autocompletion")


# Compare models
compare_models_attention(bert_attentions, gpt2_attentions, 
                        bert_tokens, gpt2_tokens, layer_idx=5)
print("✅ Model comparison complete")
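
The triangular pattern in the GPT-2 heatmap can also be checked numerically: under causal masking, the strict upper triangle of every attention matrix should be zero. The sketch below uses synthetic scores; `is_causal` and `masked_softmax` are hypothetical helper names, and the masking shown is a plain NumPy illustration rather than the model's internal implementation.

```python
import numpy as np

def is_causal(attention, tol=1e-6):
    """True if no query puts weight on a later key (strict upper triangle ~ 0)."""
    future_weight = np.triu(attention, k=1)
    return bool(np.all(future_weight <= tol))

def masked_softmax(scores, causal):
    """Row-wise softmax; with causal=True, future positions get zero weight."""
    scores = scores.astype(float)
    if causal:
        future = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(5, 5))

print(is_causal(masked_softmax(scores, causal=True)))   # True
print(is_causal(masked_softmax(scores, causal=False)))  # False
```

On the real tensors, `is_causal(gpt2_attentions[5][0, 0].cpu().numpy())` is expected to return True, and the same call on a BERT head is expected to return False.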

Step 8: Interactive Attention Explorer

def interactive_attention_explorer(text, model_type='bert'):
    """
    Interactive tool to explore attention for any text.
    
    Args:
        text: Input text
        model_type: 'bert' or 'gpt2'
    """
    # Get model and tokenizer
    if model_type == 'bert':
        tokenizer = bert_tokenizer
        model = bert_model
        tokens, attentions = get_bert_attention(text, tokenizer, model, device)
    else:
        tokenizer = gpt2_tokenizer
        model = gpt2_model
        tokens, attentions = get_gpt2_attention(text, tokenizer, model, device)
    
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    
    print(f"\n{'='*70}")
    print(f"🔍 Attention Explorer - {model_type.upper()}")
    print(f"{'='*70}")
    print(f"Text: {text}")
    print(f"Tokens: {tokens}")
    print(f"Layers: {num_layers}, Heads per layer: {num_heads}")
    print(f"{'='*70}\n")
    
    # Find interesting heads
    print("🎯 Finding interesting attention patterns...\n")
    
    for layer_idx in range(min(3, num_layers)):  # Check first 3 layers
        for head_idx in range(min(4, num_heads)):  # Check first 4 heads
            attention = attentions[layer_idx][0, head_idx].cpu().numpy()
            
            # Check if this head has interesting patterns
            # Pattern 1: Diagonal (each token attending to itself),
            # averaged so the score does not grow with sentence length
            diagonal_mean = np.diag(attention).mean()
            
            # Pattern 2: First column (attention to [CLS] for BERT, first token for GPT-2)
            first_col_sum = attention[:, 0].sum()
            
            # Pattern 3: Last row (last token attending to all)
            last_row_entropy = -np.sum(attention[-1] * np.log(attention[-1] + 1e-10))
            
            if diagonal_mean > 0.5:
                print(f"   Layer {layer_idx}, Head {head_idx}: LOCAL pattern "
                      f"(diagonal={diagonal_mean:.2f})")
            
            if first_col_sum > len(tokens) * 0.3:
                print(f"   Layer {layer_idx}, Head {head_idx}: CLS/START pattern "
                      f"(first_col={first_col_sum:.2f})")
            
            if last_row_entropy > 2.0:
                print(f"   Layer {layer_idx}, Head {head_idx}: GLOBAL pattern "
                      f"(entropy={last_row_entropy:.2f})")
    
    print(f"\n{'='*70}")
    
    # Visualize a fixed example head (layer 2, head 0)
    layer_idx = 2
    head_idx = 0
    attention = attentions[layer_idx][0, head_idx].cpu().numpy()
    
    fig, ax = plt.subplots(figsize=(12, 10))
    plot_attention_head(attention, tokens, layer_idx, head_idx, ax)
    plt.savefig(f'{model_type}_attention_explorer.png', dpi=150, bbox_inches='tight')
    plt.show()


# Try different sentences
test_sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "She sells seashells by the seashore.",
    "To be or not to be, that is the question.",
    "The cat and the dog are friends, but the bird is scared.",
]

print("\n🌟 Exploring Attention for Different Sentences")
print("="*70)

for sentence in test_sentences[:2]:  # Analyze first 2
    interactive_attention_explorer(sentence, model_type='bert')
    print("\n")

Step 9: Head Specialization Analysis

def analyze_head_specialization(attentions, tokens):
    """
    Discover what linguistic properties different heads capture.
    
    Args:
        attentions: Attention tensors
        tokens: Token list
    """
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    
    specializations = {
        'positional': [],    # Attends based on position
        'syntactic': [],     # Attends to grammatical relationships
        'semantic': [],      # Attends to meaning-related tokens
        'delimiter': [],     # Attends to punctuation/boundaries
    }
    
    print("\n🧠 Head Specialization Analysis")
    print("="*70)
    
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            attention = attentions[layer_idx][0, head_idx].cpu().numpy()
            
            # Positional pattern: attention depends mainly on distance.
            # Weight each entry by 1 / (distance + 1), then average over query rows.
            # Rows sum to 1, so the score lies in (0, 1] and reaches 1 when every
            # token attends only to itself. It is a weighted average, not a
            # statistical correlation.
            distance_correlation = 0
            for i in range(len(tokens)):
                for j in range(len(tokens)):
                    distance = abs(i - j)
                    distance_correlation += attention[i, j] * (1.0 / (distance + 1))
            distance_correlation /= len(tokens)
            
            # Punctuation attention
            punct_attention = 0
            punct_tokens = ['.', ',', '!', '?', ';', ':']
            for i, token in enumerate(tokens):
                if token in punct_tokens:
                    punct_attention += attention[:, i].mean()
            punct_attention /= max(sum(1 for t in tokens if t in punct_tokens), 1)
            
            # Store findings. Both cutoffs are rough heuristics; uniform attention
            # scores roughly 0.3 on a 14-token input, so 0.5 indicates a clear
            # preference for nearby tokens.
            if distance_correlation > 0.5:
                specializations['positional'].append((layer_idx, head_idx, distance_correlation))
            
            if punct_attention > 0.2:
                specializations['delimiter'].append((layer_idx, head_idx, punct_attention))
            
            # Label heads by depth only: first third "syntactic", last third "semantic".
            # This is a heuristic; nothing about the attention weights is measured here.
            if layer_idx < num_layers // 3:
                specializations['syntactic'].append((layer_idx, head_idx))
            elif layer_idx >= 2 * num_layers // 3:
                specializations['semantic'].append((layer_idx, head_idx))
    
    # Print findings
    print("\n📍 POSITIONAL Heads (attend based on token distance):")
    for layer, head, score in sorted(specializations['positional'], 
                                     key=lambda x: x[2], reverse=True)[:5]:
        print(f"   Layer {layer:2d}, Head {head}: score={score:.3f}")
    
    print("\n📍 DELIMITER Heads (attend to punctuation):")
    for layer, head, score in sorted(specializations['delimiter'], 
                                     key=lambda x: x[2], reverse=True)[:5]:
        print(f"   Layer {layer:2d}, Head {head}: score={score:.3f}")
    
    print("\n🔤 SYNTACTIC Heads (early layers, labeled by depth heuristic):")
    for layer, head in specializations['syntactic'][:5]:
        print(f"   Layer {layer:2d}, Head {head}")
    
    print("\n💡 SEMANTIC Heads (late layers, labeled by depth heuristic):")
    for layer, head in specializations['semantic'][:5]:
        print(f"   Layer {layer:2d}, Head {head}")
    
    print("\n" + "="*70)
    
    return specializations


# Analyze BERT head specialization
bert_specializations = analyze_head_specialization(bert_attentions, bert_tokens)
print("✅ Specialization analysis complete")
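
A more direct test for a positional head is to measure how much weight each query places on the key at a fixed relative offset, such as the previous token (offset -1) or the next token (offset +1). The sketch below is a minimal illustration on toy matrices; `offset_attention_score` is a hypothetical helper, not part of the project code.

```python
import numpy as np

def offset_attention_score(attention, offset=-1):
    """Average weight each query i puts on the key at position i + offset."""
    # np.diagonal with a negative offset reads attention[i, i + offset]
    weights = np.diagonal(attention, offset=offset)
    return float(weights.mean()) if weights.size else 0.0

# Toy head that always attends to the previous token
# (the first row has no previous token, so it attends to itself)
prev_head = np.eye(5, k=-1)
prev_head[0, 0] = 1.0
uniform_head = np.full((5, 5), 0.2)

print(offset_attention_score(prev_head, offset=-1))     # 1.0
print(offset_attention_score(uniform_head, offset=-1))  # approximately 0.2
```

Applied to every `attentions[layer_idx][0, head_idx]` matrix, a score near 1 would mark a strong previous-token head, while a score near 1 / seq_len would indicate no preference for that offset.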

Step 10: Attention Statistics Summary

def create_attention_summary(attentions, tokens, model_name='BERT'):
    """
    Create comprehensive summary visualization.
    
    Args:
        attentions: Attention tensors
        tokens: Token list
        model_name: Model name
    """
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    
    # Compute statistics
    avg_attention_by_layer = []
    entropy_by_layer = []
    
    for layer_idx in range(num_layers):
        layer_attn = attentions[layer_idx][0].cpu().numpy()  # [num_heads, seq_len, seq_len]
        
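        # Average attention weight
        # Note: every row is a softmax and sums to 1, so this mean is always
        # 1 / seq_len; the plot is a sanity check rather than a per-layer signal
        avg_attn = layer_attn.mean()
        avg_attention_by_layer.append(avg_attn)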
        
        # Average entropy (measure of attention dispersion)
        entropy = -np.sum(layer_attn * np.log(layer_attn + 1e-10), axis=-1).mean()
        entropy_by_layer.append(entropy)
    
    # Create summary plot
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # Plot 1: Average attention by layer
    axes[0, 0].plot(range(num_layers), avg_attention_by_layer, 
                    marker='o', linewidth=2, markersize=8)
    axes[0, 0].set_xlabel('Layer', fontsize=12)
    axes[0, 0].set_ylabel('Average Attention Weight', fontsize=12)
    axes[0, 0].set_title('Average Attention Across Layers', fontsize=14, fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Plot 2: Entropy by layer
    axes[0, 1].plot(range(num_layers), entropy_by_layer, 
                    marker='s', linewidth=2, markersize=8, color='orange')
    axes[0, 1].set_xlabel('Layer', fontsize=12)
    axes[0, 1].set_ylabel('Attention Entropy', fontsize=12)
    axes[0, 1].set_title('Attention Entropy Across Layers', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Plot 3: Attention heatmap (average across all heads and layers)
    avg_attention = torch.stack([a[0] for a in attentions]).mean(dim=(0, 1)).cpu().numpy()
    sns.heatmap(avg_attention, xticklabels=tokens, yticklabels=tokens,
               cmap='YlOrRd', square=True, linewidths=0.5, ax=axes[1, 0])
    axes[1, 0].set_title('Average Attention (All Layers & Heads)', 
                        fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Key Tokens')
    axes[1, 0].set_ylabel('Query Tokens')
    
    # Plot 4: Token-wise attention received
    attention_received = avg_attention.sum(axis=0)
    axes[1, 1].bar(range(len(tokens)), attention_received, color='steelblue', alpha=0.7)
    axes[1, 1].set_xlabel('Token', fontsize=12)
    axes[1, 1].set_ylabel('Total Attention Received', fontsize=12)
    axes[1, 1].set_title('Attention Received by Each Token', fontsize=14, fontweight='bold')
    axes[1, 1].set_xticks(range(len(tokens)))
    axes[1, 1].set_xticklabels(tokens, rotation=45, ha='right')
    axes[1, 1].grid(True, alpha=0.3, axis='y')
    
    plt.suptitle(f'{model_name} Attention Analysis Summary', 
                fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_attention_summary.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print summary
    print(f"\n📊 {model_name} Attention Summary")
    print("="*70)
    print(f"   Total layers: {num_layers}")
    print(f"   Heads per layer: {num_heads}")
    print(f"   Sequence length: {len(tokens)}")
    print(f"   Avg attention weight: {np.mean(avg_attention_by_layer):.4f}")
    print(f"   Avg entropy: {np.mean(entropy_by_layer):.4f}")
    print(f"   Most attended token: '{tokens[np.argmax(attention_received)]}'")
    print("="*70)


# Create summary for BERT
create_attention_summary(bert_attentions, bert_tokens, model_name='BERT')
print("✅ Summary visualization complete")
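
Because every attention row sums to 1, the mean weight of a layer equals 1 / seq_len no matter what the heads do. A per-layer number that does vary is the mean of each query's largest weight, which rises as attention becomes more focused. The sketch below compares the two on synthetic data; `mean_peak_attention` is a hypothetical helper name.

```python
import numpy as np

def mean_peak_attention(layer_attn):
    """Mean over heads and queries of the largest weight in each row.

    layer_attn: [num_heads, seq_len, seq_len], rows summing to 1.
    """
    return float(layer_attn.max(axis=-1).mean())

seq_len = 8
diffuse = np.full((2, seq_len, seq_len), 1.0 / seq_len)   # flat rows
focused = np.stack([np.eye(seq_len), np.eye(seq_len)])    # one-hot rows

for name, attn in [("diffuse", diffuse), ("focused", focused)]:
    print(name, "mean weight:", attn.mean(), "mean peak:", mean_peak_attention(attn))
# Both mean weights are 0.125; the mean peaks are 0.125 and 1.0
```

In `create_attention_summary`, this value could be collected per layer alongside the entropy to show where attention sharpens or spreads out.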

🎉 Congratulations!

What You've Accomplished

  • ✅ Extracted and visualized attention from BERT and GPT-2
  • ✅ Analyzed attention patterns across layers and heads
  • ✅ Discovered head specialization (positional, syntactic, semantic)
  • ✅ Compared bidirectional vs causal attention
  • ✅ Created interactive attention exploration tools
  • ✅ Built comprehensive visualization dashboards

Key Discoveries

🔬 What We Learned About Attention

  • Layer Patterns: Early layers are often associated with syntax and late layers with semantics; the depth-based labels in Step 9 are a heuristic, not a measurement
  • Head Specialization: Different heads show different attention patterns (local, global, special-token)
  • Positional Heads: Some heads mainly attend based on token distance
  • Delimiter Heads: Some heads focus on punctuation and boundaries
  • BERT vs GPT-2: Bidirectional attention can fill the whole matrix, while causal masking leaves only the lower triangle
  • Attention Flow: Tracking one token across layers shows how its attention targets shift with depth

Research Insights

📚 Papers on Attention Analysis

  • "What Does BERT Look At?" (Clark et al., 2019)
    Found specific heads attend to syntactic relations like dependency arcs
  • "Are Sixteen Heads Really Better than One?" (Michel et al., 2019)
    Showed many heads can be pruned without hurting performance
  • "Analyzing Multi-Head Self-Attention" (Voita et al., 2019)
    Identified positional, syntactic, and rare word heads

Next Steps

🔬 Experiment

  • Try different text types (code, math, poetry)
  • Analyze other models with downloadable weights (RoBERTa, DistilBERT, larger GPT-2 checkpoints)
  • Compare multilingual attention patterns

🚀 Advanced

  • Implement attention rollout/flow
  • Head pruning experiments
  • Attention-based interpretability
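
For the attention rollout item above, a minimal sketch follows, based on the idea described by Abnar and Zuidema (2020): average the heads in each layer, mix in the identity matrix to account for the residual connection, re-normalize the rows, and multiply the per-layer matrices from the first layer upward. The equal 0.5 weighting and the function name `attention_rollout` are assumptions of this sketch, and the input is synthetic rather than model output.

```python
import numpy as np

def attention_rollout(layer_attentions):
    """Attention rollout over a list of [num_heads, seq_len, seq_len] arrays."""
    seq_len = layer_attentions[0].shape[-1]
    identity = np.eye(seq_len)
    rollout = identity.copy()
    for layer_attn in layer_attentions:
        avg = layer_attn.mean(axis=0)                       # average over heads
        mixed = 0.5 * avg + 0.5 * identity                  # model the residual path
        mixed = mixed / mixed.sum(axis=-1, keepdims=True)   # keep rows summing to 1
        rollout = mixed @ rollout                           # compose with earlier layers
    return rollout

# Synthetic stand-in: 3 layers, 2 heads, 4 tokens, rows normalized to sum to 1
rng = np.random.default_rng(0)
raw = rng.random(size=(3, 2, 4, 4))
toy_layers = [layer / layer.sum(axis=-1, keepdims=True) for layer in raw]

rollout = attention_rollout(toy_layers)
print(rollout.shape)                            # (4, 4)
print(np.allclose(rollout.sum(axis=-1), 1.0))   # True
```

With the tensors from Step 2, the input list could be built as `[a[0].cpu().numpy() for a in bert_attentions]`, and the result plotted with `plot_attention_head`.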

📊 Research

  • Test linguistic hypotheses
  • Correlate attention with task performance
  • Study emergent linguistic properties

🏆 You've unlocked the black box!

You can now visualize and inspect what the attention heads of a transformer model attend to. This is a useful starting point for interpretability work and debugging, although attention weights alone do not fully explain a model's predictions.

Test Your Knowledge

Q1: What do attention visualizations primarily show?

Training loss
Model size
Which tokens the model focuses on
Execution speed

Q2: What library is commonly used for creating attention heatmaps?

NumPy
Matplotlib or Seaborn
Pandas
Flask

Q3: In attention visualizations, what do darker colors typically represent?

Errors
Slower processing
Lower attention
Higher attention weights

Q4: What insight can you gain from visualizing attention patterns?

Which input tokens influence predictions
Exact training time
Memory usage
Disk space requirements

Q5: Why visualize attention patterns in multi-head attention?

To reduce model size
To speed up training
To understand what different heads learn
To eliminate bugs

🎓 Congratulations!

You've completed the Transformers Architecture course! You now understand how these powerful models work from the ground up.
