Build a GPT-style decoder-only transformer and train it to generate text
In this project, you'll build a character-level language model using a decoder-only transformer architecture (like GPT). You'll train it on Shakespeare's text and watch it learn to generate similar writing!
# Install dependencies if needed: pip install torch requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import requests
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Download Shakespeare dataset
print("š„ Downloading Shakespeare dataset...")
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
text = requests.get(url).text
print(f"ā
Downloaded {len(text):,} characters")
# Preview the data
print("\nš First 500 characters:")
print(text[:500])
print("\n" + "="*70)
# Character-level tokenization
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"\nš Vocabulary: {vocab_size} unique characters")
print(f"Characters: {''.join(chars)}")
# Create mappings
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
# Encode/decode functions
def encode(text):
return [char_to_idx[ch] for ch in text]
def decode(indices):
return ''.join([idx_to_char[i] for i in indices])
# Test encoding/decoding
test_text = "Hello, World!"
encoded = encode(test_text)
decoded = decode(encoded)
print(f"\nš¤ Test encoding: '{test_text}' ā {encoded}")
print(f"š¤ Test decoding: {encoded} ā '{decoded}'")
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
"""Character-level text dataset."""
def __init__(self, text, block_size):
self.data = torch.tensor(encode(text), dtype=torch.long)
self.block_size = block_size
def __len__(self):
return len(self.data) - self.block_size
def __getitem__(self, idx):
# Get chunk of text
chunk = self.data[idx:idx + self.block_size + 1]
x = chunk[:-1] # Input
y = chunk[1:] # Target (shifted by 1)
return x, y
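# Quick sanity check (a sketch using the encode/decode helpers above): the target
# sequence is simply the input shifted one character to the right, which is the
# next-character prediction objective the model will be trained on.
toy_dataset = TextDataset("hello world", block_size=4)
toy_x, toy_y = toy_dataset[0]
print(f"Toy input:  '{decode(toy_x.tolist())}'")  # 'hell'
print(f"Toy target: '{decode(toy_y.tolist())}'")  # 'ello'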
# Hyperparameters
block_size = 128 # Context length
batch_size = 64 # Batch size
train_split = 0.9 # Train/val split
# Split data
n = int(train_split * len(text))
train_text = text[:n]
val_text = text[n:]
# Create datasets
train_dataset = TextDataset(train_text, block_size)
val_dataset = TextDataset(val_text, block_size)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
print(f"\nš Dataset Statistics:")
print(f" Training samples: {len(train_dataset):,}")
print(f" Validation samples: {len(val_dataset):,}")
print(f" Batches per epoch: {len(train_loader):,}")
# Test batch
x_batch, y_batch = next(iter(train_loader))
print(f"\nā
Batch shapes: x={x_batch.shape}, y={y_batch.shape}")
print(f" Example input: '{decode(x_batch[0].tolist())}'")
print(f" Example target: '{decode(y_batch[0].tolist())}')")
class CausalSelfAttention(nn.Module):
"""
Multi-head masked self-attention for decoder.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Key, Query, Value projections for all heads (batched)
self.c_attn = nn.Linear(d_model, 3 * d_model)
# Output projection
self.c_proj = nn.Linear(d_model, d_model)
# Regularization
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
output: [batch, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.shape
# Calculate Q, K, V for all heads in batch
q, k, v = self.c_attn(x).split(self.d_model, dim=2)
# Reshape for multi-head attention
q = q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# [batch, num_heads, seq_len, d_k]
# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
# [batch, num_heads, seq_len, seq_len]
# Apply causal mask (prevent attending to future tokens)
mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax and dropout
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, v)
# [batch, num_heads, seq_len, d_k]
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
# Output projection
output = self.resid_dropout(self.c_proj(attn_output))
return output
# Test
attn = CausalSelfAttention(d_model=256, num_heads=8).to(device)
test_input = torch.randn(2, 10, 256).to(device)
output = attn(test_input)
print(f"ā
Causal Self-Attention: {test_input.shape} ā {output.shape}")
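# Optional aside (a sketch, assuming PyTorch >= 2.0): F.scaled_dot_product_attention
# fuses the scaling, causal masking, softmax, and value weighting done manually in
# CausalSelfAttention. A quick check that the fused call matches the manual math:
q_t = torch.randn(2, 8, 10, 32, device=device)
k_t = torch.randn(2, 8, 10, 32, device=device)
v_t = torch.randn(2, 8, 10, 32, device=device)
manual_scores = torch.matmul(q_t, k_t.transpose(-2, -1)) / math.sqrt(32)
causal_mask = torch.tril(torch.ones(10, 10, device=device))
manual_scores = manual_scores.masked_fill(causal_mask == 0, float('-inf'))
manual_out = torch.matmul(F.softmax(manual_scores, dim=-1), v_t)
fused_out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True)
print(f"Fused vs manual attention match: {torch.allclose(manual_out, fused_out, atol=1e-4)}")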
class FeedForward(nn.Module):
"""
Position-wise feed-forward network with GELU activation.
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(), # GELU is used in GPT
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# Test
ffn = FeedForward(d_model=256, d_ff=1024).to(device)
test_input = torch.randn(2, 10, 256).to(device)
output = ffn(test_input)
print(f"ā
Feed-Forward Network: {test_input.shape} ā {output.shape}")
class TransformerBlock(nn.Module):
"""
Single transformer decoder block: Attention + Feed-Forward
with layer normalization and residual connections.
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = CausalSelfAttention(d_model, num_heads, dropout)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = FeedForward(d_model, d_ff, dropout)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model]
"""
# Pre-norm architecture (used in GPT-2 and later models)
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
# Test
block = TransformerBlock(d_model=256, num_heads=8, d_ff=1024).to(device)
test_input = torch.randn(2, 10, 256).to(device)
output = block(test_input)
print(f"ā
Transformer Block: {test_input.shape} ā {output.shape}")
class GPT(nn.Module):
"""
Simple GPT-style language model (decoder-only transformer).
"""
def __init__(self, vocab_size, block_size, d_model=256, num_layers=6,
num_heads=8, d_ff=1024, dropout=0.1):
super().__init__()
self.block_size = block_size
# Token + position embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(block_size, d_model)
self.dropout = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Output layer
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# Initialize weights
self.apply(self._init_weights)
print(f"ā
GPT Model created with {sum(p.numel() for p in self.parameters()):,} parameters")
def _init_weights(self, module):
"""Initialize weights."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
"""
Args:
idx: [batch, seq_len] input token indices
targets: [batch, seq_len] target token indices (optional)
Returns:
logits: [batch, seq_len, vocab_size]
loss: scalar (if targets provided)
"""
batch_size, seq_len = idx.shape
# Token embeddings + positional embeddings
pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)
tok_emb = self.token_embedding(idx) # [batch, seq_len, d_model]
pos_emb = self.position_embedding(pos) # [1, seq_len, d_model]
x = self.dropout(tok_emb + pos_emb)
# Pass through transformer blocks
for block in self.blocks:
x = block(x)
# Final layer norm and projection to vocabulary
x = self.ln_f(x)
logits = self.lm_head(x) # [batch, seq_len, vocab_size]
# Compute loss if targets provided
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
"""
Generate new tokens autoregressively.
Args:
idx: [batch, seq_len] starting token indices
max_new_tokens: number of tokens to generate
temperature: sampling temperature (higher = more random)
top_k: keep only top k tokens by probability (optional)
top_p: nucleus sampling - keep the smallest set of top tokens whose cumulative probability exceeds p (optional)
Returns:
[batch, seq_len + max_new_tokens] generated token indices
"""
for _ in range(max_new_tokens):
# Crop to block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
# Forward pass
logits, _ = self(idx_cond)
# Get logits for last token
logits = logits[:, -1, :] / temperature # [batch, vocab_size]
# Top-k sampling
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# Top-p (nucleus) sampling
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative prob > top_p
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')
# Sample from distribution
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# Append to sequence
idx = torch.cat((idx, idx_next), dim=1)
return idx
# Create model
model = GPT(
vocab_size=vocab_size,
block_size=block_size,
d_model=256,
num_layers=6,
num_heads=8,
d_ff=1024,
dropout=0.1
).to(device)
# Test forward pass
test_idx = torch.randint(0, vocab_size, (2, 10)).to(device)
test_targets = torch.randint(0, vocab_size, (2, 10)).to(device)
logits, loss = model(test_idx, test_targets)
print(f"\nā
Forward pass: idx{test_idx.shape} ā logits{logits.shape}, loss={loss.item():.4f}")
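# Sanity check (a sketch): with freshly initialized weights the model should be close
# to a uniform distribution over the vocabulary, so the initial loss should be near
# log(vocab_size). If it is far higher, something is wrong with the setup.
print(f"Expected loss at initialization: ~{math.log(vocab_size):.4f} (log of vocab size)")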
import torch.optim as optim
# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
# Learning rate scheduler
def get_lr(iteration, warmup_iters=100, lr_decay_iters=5000, min_lr=1e-5):
"""Cosine learning rate schedule with warmup."""
learning_rate = 3e-4
# Linear warmup
if iteration < warmup_iters:
return learning_rate * iteration / warmup_iters
# Cosine decay after warmup
if iteration > lr_decay_iters:
return min_lr
decay_ratio = (iteration - warmup_iters) / (lr_decay_iters - warmup_iters)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (learning_rate - min_lr)
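# Quick sanity check of the schedule (a sketch): the learning rate rises linearly
# during warmup, then follows a cosine decay down to min_lr.
for it in [0, 50, 100, 1000, 2500, 5000, 6000]:
    print(f"  iter {it:>5}: lr = {get_lr(it):.6f}")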
def train_epoch(model, train_loader, optimizer, device, epoch):
"""Train for one epoch."""
model.train()
total_loss = 0
for batch_idx, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
# Update learning rate
lr = get_lr(epoch * len(train_loader) + batch_idx)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Forward pass
logits, loss = model(x, y)
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f" Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, LR: {lr:.6f}")
return total_loss / len(train_loader)
@torch.no_grad()
def evaluate(model, val_loader, device):
"""Evaluate model."""
model.eval()
total_loss = 0
for x, y in val_loader:
x, y = x.to(device), y.to(device)
logits, loss = model(x, y)
total_loss += loss.item()
return total_loss / len(val_loader)
# Training loop
num_epochs = 10
print("\nš Starting training...")
print("=" * 70)
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
val_loss = evaluate(model, val_loader, device)
print(f" Train Loss: {train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}")
# Generate sample text
if (epoch + 1) % 2 == 0:
print("\n š Sample generation:")
context = torch.tensor([[char_to_idx['\n']]], dtype=torch.long, device=device)
generated = model.generate(context, max_new_tokens=200, temperature=0.8, top_k=50)
generated_text = decode(generated[0].tolist())
print(f" {generated_text}")
print()
# Save checkpoint
if (epoch + 1) % 5 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
}, f'gpt_checkpoint_epoch_{epoch+1}.pt')
print(f" ā
Checkpoint saved!")
print("\nš Training complete!")
@torch.no_grad()
def complete_text(model, prompt, max_tokens=200, temperature=0.8, top_k=50, top_p=0.9):
"""
Complete a text prompt.
Args:
model: Trained GPT model
prompt: Starting text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature (0.1-2.0)
top_k: Top-k sampling
top_p: Nucleus sampling
Returns:
Completed text
"""
model.eval()
# Encode prompt
encoded = encode(prompt)
idx = torch.tensor([encoded], dtype=torch.long, device=device)
# Generate
generated = model.generate(idx, max_new_tokens=max_tokens,
temperature=temperature, top_k=top_k, top_p=top_p)
# Decode
output_text = decode(generated[0].tolist())
return output_text
# Interactive demo
print("\nš Shakespeare Text Generator")
print("=" * 70)
prompts = [
"ROMEO:",
"To be or not to be,",
"O Romeo, Romeo!",
"First Citizen:",
]
for prompt in prompts:
print(f"\nš Prompt: '{prompt}'")
print("-" * 70)
# Low temperature (more deterministic)
print("\nšµ Temperature 0.5 (Conservative):")
output = complete_text(model, prompt, max_tokens=150, temperature=0.5, top_k=50)
print(output)
# High temperature (more creative)
print("\nš“ Temperature 1.2 (Creative):")
output = complete_text(model, prompt, max_tokens=150, temperature=1.2, top_k=50)
print(output)
print("\n" + "=" * 70)
import matplotlib.pyplot as plt
# Compare sampling strategies
@torch.no_grad()
def compare_sampling_strategies(model, prompt, max_tokens=100):
"""Compare different sampling approaches."""
model.eval()
encoded = encode(prompt)
idx = torch.tensor([encoded], dtype=torch.long, device=device)
strategies = {
'Greedy (argmax)': {'temperature': 1.0, 'top_k': 1, 'top_p': None},
'Low temp (0.3)': {'temperature': 0.3, 'top_k': None, 'top_p': None},
'High temp (1.5)': {'temperature': 1.5, 'top_k': None, 'top_p': None},
'Top-k (k=10)': {'temperature': 1.0, 'top_k': 10, 'top_p': None},
'Top-p (p=0.9)': {'temperature': 1.0, 'top_k': None, 'top_p': 0.9},
}
print(f"\nš¬ Comparing Sampling Strategies")
print(f"Prompt: '{prompt}'")
print("=" * 70)
for name, params in strategies.items():
generated = model.generate(idx.clone(), max_new_tokens=max_tokens, **params)
output = decode(generated[0].tolist())
print(f"\n{name}:")
print(output)
print("-" * 70)
# Test
compare_sampling_strategies(model, "ROMEO:", max_tokens=100)
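# Toy illustration (a sketch) of what temperature and top-k do to the sampling
# distribution: higher temperature flattens the softmax, while top-k removes all
# but the k most likely tokens before renormalizing.
toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
for t in [0.5, 1.0, 2.0]:
    toy_probs = F.softmax(toy_logits / t, dim=-1)
    print(f"  temperature {t}: " + ", ".join(f"{p:.3f}" for p in toy_probs.tolist()))
topk_logits = toy_logits.clone()
v_top, _ = torch.topk(topk_logits, 2)
topk_logits[topk_logits < v_top[-1]] = float('-inf')
topk_probs = F.softmax(topk_logits, dim=-1)
print("  top-k (k=2):     " + ", ".join(f"{p:.3f}" for p in topk_probs.tolist()))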
# Analyze model perplexity
@torch.no_grad()
def compute_perplexity(model, data_loader, device):
"""Compute perplexity on dataset."""
model.eval()
total_loss = 0
total_tokens = 0
for x, y in data_loader:
x, y = x.to(device), y.to(device)
logits, loss = model(x, y)
total_loss += loss.item() * y.numel()
total_tokens += y.numel()
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
return perplexity
train_ppl = compute_perplexity(model, train_loader, device)
val_ppl = compute_perplexity(model, val_loader, device)
print(f"\nš Model Performance:")
print(f" Training Perplexity: {train_ppl:.2f}")
print(f" Validation Perplexity: {val_ppl:.2f}")
print(f" Lower is better (random baseline: {vocab_size:.2f})")
You've built a language model from scratch!
This is the same architecture powering GPT-3, ChatGPT, and other modern LLMs. You now understand the fundamentals of how these systems work!
Q1: What is the main task of a language model?
Q2: What loss function is typically used for language modeling?
Q3: What metric measures how well a language model predicts text?
Q4: During text generation, how is the next token typically selected?
Q5: What technique helps control randomness in text generation?