Implement a complete encoder-decoder transformer in PyTorch and train it for machine translation
In this hands-on project, you'll build a complete transformer architecture from scratch using PyTorch. By the end, you'll have a working machine translation model that can translate English to French.
Start by installing PyTorch if you haven't already (pip install torch), then set up the imports, device, and hyperparameters.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.utils.data import Dataset, DataLoader
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Hyperparameters
d_model = 512 # Model dimension
num_heads = 8 # Number of attention heads
num_layers = 6 # Number of encoder/decoder layers
d_ff = 2048 # Feed-forward dimension
dropout = 0.1 # Dropout rate
max_seq_len = 100 # Maximum sequence length
vocab_size_src = 10000 # Source vocabulary size
vocab_size_tgt = 10000 # Target vocabulary size
print("β
Setup complete!")
First, implement positional encoding to give the model information about token positions.
class PositionalEncoding(nn.Module):
"""
Sinusoidal positional encoding from 'Attention Is All You Need'.
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
"""
def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create positional encoding matrix
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # Even dimensions
pe[:, 1::2] = torch.cos(position * div_term) # Odd dimensions
pe = pe.unsqueeze(0) # [1, max_seq_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model] with positional encoding added
"""
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# Test
pe = PositionalEncoding(d_model=512)
test_input = torch.randn(2, 10, 512)
output = pe(test_input)
print(f"β
Positional Encoding: {test_input.shape} β {output.shape}")
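As a quick sanity check (an addition to this walkthrough, not part of the original build), you can compare the registered buffer pe.pe directly against the closed-form formula for one position:
# Compare stored values against PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) for pos=3, i=1
pos, i = 3, 1
angle = pos / (10000 ** (2 * i / 512))
print(f"  PE({pos}, {2*i})   = {pe.pe[0, pos, 2*i].item():.6f}, formula: {math.sin(angle):.6f}")
print(f"  PE({pos}, {2*i+1}) = {pe.pe[0, pos, 2*i+1].item():.6f}, formula: {math.cos(angle):.6f}")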
Implement the core attention mechanism with multiple heads.
class MultiHeadAttention(nn.Module):
"""
Multi-head scaled dot-product attention.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
Compute scaled dot-product attention.
Args:
Q, K, V: [batch, num_heads, seq_len, d_k]
mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len]
Returns:
output: [batch, num_heads, seq_len, d_k]
attention_weights: [batch, num_heads, seq_len, seq_len]
"""
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# [batch, num_heads, seq_len_q, seq_len_k]
# Apply mask (if provided)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Compute attention weights
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Apply attention to values
output = torch.matmul(attention_weights, V)
# [batch, num_heads, seq_len_q, d_k]
return output, attention_weights
def forward(self, query, key, value, mask=None):
"""
Args:
query: [batch, seq_len_q, d_model]
key: [batch, seq_len_k, d_model]
value: [batch, seq_len_v, d_model]
mask: [batch, seq_len_q, seq_len_k], or already expanded to 4-D such as [batch, 1, 1, seq_len_k]
Returns:
output: [batch, seq_len_q, d_model]
attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
"""
batch_size = query.size(0)
# Linear projections and split into heads
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# [batch, num_heads, seq_len, d_k]
# Expand a 3-D mask over the head dimension; 4-D masks that already
# include a broadcast dimension are used as-is
if mask is not None and mask.dim() == 3:
    mask = mask.unsqueeze(1)  # [batch, 1, seq_len_q, seq_len_k]
# Apply attention
attn_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
) # [batch, seq_len_q, d_model]
# Final linear projection
output = self.W_o(attn_output)
return output, attention_weights
# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
test_q = torch.randn(2, 10, 512)
test_k = torch.randn(2, 10, 512)
test_v = torch.randn(2, 10, 512)
output, attn_weights = mha(test_q, test_k, test_v)
print(f"β
Multi-Head Attention: Q{test_q.shape} β Output{output.shape}")
print(f" Attention weights: {attn_weights.shape}")
class FeedForward(nn.Module):
"""
Position-wise feed-forward network.
FFN(x) = ReLU(xW1 + b1)W2 + b2
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ff, d_model)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
[batch, seq_len, d_model]
"""
return self.linear2(self.dropout(F.relu(self.linear1(x))))
# Test
ffn = FeedForward(d_model=512, d_ff=2048)
test_input = torch.randn(2, 10, 512)
output = ffn(test_input)
print(f"β
Feed-Forward Network: {test_input.shape} β {output.shape}")
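A brief aside (not in the original text): the d_ff = 2048 expansion is where much of each layer's capacity lives. Counting parameters of the modules created above shows the feed-forward block holds roughly twice as many weights as a multi-head attention block:
ffn_params = sum(p.numel() for p in ffn.parameters())
mha_params = sum(p.numel() for p in mha.parameters())
print(f"  FeedForward parameters:        {ffn_params:,}")
print(f"  MultiHeadAttention parameters: {mha_params:,}")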
class EncoderLayer(nn.Module):
"""
Single encoder layer: Self-Attention + Feed-Forward
with residual connections and layer normalization.
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: [batch, seq_len, d_model]
mask: [batch, 1, 1, seq_len] padding mask
Returns:
[batch, seq_len, d_model]
"""
# Self-attention + residual + norm
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
# Feed-forward + residual + norm
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout2(ff_output))
return x
# Test
encoder_layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048)
test_input = torch.randn(2, 10, 512)
output = encoder_layer(test_input)
print(f"β
Encoder Layer: {test_input.shape} β {output.shape}")
class Encoder(nn.Module):
"""
Complete transformer encoder: embedding + N encoder layers.
"""
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff,
max_seq_len, dropout=0.1):
super().__init__()
self.d_model = d_model
# Token + positional embedding
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
# Stack of encoder layers
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, src, src_mask=None):
"""
Args:
src: [batch, src_len] token indices
src_mask: [batch, 1, 1, src_len] padding mask
Returns:
[batch, src_len, d_model] encoded representations
"""
# Embed and scale
x = self.embedding(src) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
# Pass through encoder layers
for layer in self.layers:
x = layer(x, src_mask)
return x
# Test
encoder = Encoder(vocab_size=10000, d_model=512, num_layers=6,
num_heads=8, d_ff=2048, max_seq_len=100)
test_src = torch.randint(0, 10000, (2, 10))
output = encoder(test_src)
print(f"β
Complete Encoder: {test_src.shape} β {output.shape}")
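During training the encoder also receives a padding mask so attention skips padded positions (token id 0, the same id the loss will ignore later). As a brief illustration, the snippet below pads part of one sequence and builds the [batch, 1, 1, src_len] mask by hand, mirroring what make_src_mask does in the full model further down:
padded_src = test_src.clone()
padded_src[0, -3:] = 0                                   # treat the last three tokens as padding
pad_mask = (padded_src != 0).unsqueeze(1).unsqueeze(2)   # [batch, 1, 1, src_len]
masked_out = encoder(padded_src, pad_mask)
print(f"  Encoder with padding mask: {padded_src.shape} → {masked_out.shape}")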
class DecoderLayer(nn.Module):
"""
Single decoder layer: Masked Self-Attention + Cross-Attention + Feed-Forward
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: [batch, tgt_len, d_model] decoder input
encoder_output: [batch, src_len, d_model]
src_mask: [batch, 1, 1, src_len] encoder padding mask
tgt_mask: [batch, 1, tgt_len, tgt_len] combined causal and padding mask
Returns:
[batch, tgt_len, d_model]
"""
# Masked self-attention
attn_output, _ = self.self_attention(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout1(attn_output))
# Cross-attention to encoder
attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout2(attn_output))
# Feed-forward
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout3(ff_output))
return x
# Test
decoder_layer = DecoderLayer(d_model=512, num_heads=8, d_ff=2048)
test_x = torch.randn(2, 8, 512)
test_enc = torch.randn(2, 10, 512)
output = decoder_layer(test_x, test_enc)
print(f"β
Decoder Layer: {test_x.shape} β {output.shape}")
class Decoder(nn.Module):
"""
Complete transformer decoder: embedding + N decoder layers.
"""
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff,
max_seq_len, dropout=0.1):
super().__init__()
self.d_model = d_model
# Token + positional embedding
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
# Stack of decoder layers
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Output projection to vocabulary
self.fc_out = nn.Linear(d_model, vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
tgt: [batch, tgt_len] target token indices
encoder_output: [batch, src_len, d_model]
src_mask: [batch, 1, 1, src_len] padding mask
tgt_mask: [batch, 1, tgt_len, tgt_len] combined causal and padding mask
Returns:
[batch, tgt_len, vocab_size] logits
"""
# Embed and scale
x = self.embedding(tgt) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
# Pass through decoder layers
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
# Project to vocabulary
logits = self.fc_out(x)
return logits
# Test
decoder = Decoder(vocab_size=10000, d_model=512, num_layers=6,
num_heads=8, d_ff=2048, max_seq_len=100)
test_tgt = torch.randint(0, 10000, (2, 8))
test_enc = torch.randn(2, 10, 512)
output = decoder(test_tgt, test_enc)
print(f"β
Complete Decoder: {test_tgt.shape} β {output.shape}")
class Transformer(nn.Module):
"""
Complete transformer model: Encoder + Decoder
"""
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
num_layers=6, num_heads=8, d_ff=2048, max_seq_len=100, dropout=0.1):
super().__init__()
self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads,
d_ff, max_seq_len, dropout)
self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads,
d_ff, max_seq_len, dropout)
def make_src_mask(self, src):
"""Create padding mask for source."""
# src: [batch, src_len]
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
# [batch, 1, 1, src_len]
return src_mask
def make_tgt_mask(self, tgt):
"""Create causal mask for target."""
# tgt: [batch, tgt_len]
batch_size, tgt_len = tgt.shape
# Padding mask
tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
# [batch, 1, 1, tgt_len]
# Causal mask (lower triangular)
tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
# [tgt_len, tgt_len]
# Combine masks
tgt_mask = tgt_pad_mask & tgt_sub_mask
# [batch, 1, tgt_len, tgt_len]
return tgt_mask
def forward(self, src, tgt):
"""
Args:
src: [batch, src_len] source token indices
tgt: [batch, tgt_len] target token indices
Returns:
[batch, tgt_len, tgt_vocab_size] output logits
"""
src_mask = self.make_src_mask(src)
tgt_mask = self.make_tgt_mask(tgt)
encoder_output = self.encoder(src, src_mask)
decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
return decoder_output
# Create model
model = Transformer(
src_vocab_size=10000,
tgt_vocab_size=10000,
d_model=512,
num_layers=6,
num_heads=8,
d_ff=2048,
max_seq_len=100,
dropout=0.1
).to(device)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"β
Complete Transformer Model Created!")
print(f" Total parameters: {total_params:,}")
# Test forward pass
test_src = torch.randint(1, 10000, (2, 10)).to(device)
test_tgt = torch.randint(1, 10000, (2, 8)).to(device)
output = model(test_src, test_tgt)
print(f" Forward pass: src{test_src.shape} + tgt{test_tgt.shape} β output{output.shape}")
import torch.optim as optim
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# Loss function (ignore padding tokens)
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Learning rate scheduler (Noam scheduler from paper)
class NoamScheduler:
def __init__(self, optimizer, d_model, warmup_steps=4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = self.d_model ** (-0.5) * min(
self.step_num ** (-0.5),
self.step_num * self.warmup_steps ** (-1.5)
)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
scheduler = NoamScheduler(optimizer, d_model=512, warmup_steps=4000)
print("β
Training setup complete!")
print(f" Optimizer: Adam")
print(f" Loss: CrossEntropyLoss (ignore padding)")
print(f" Scheduler: Noam (warmup_steps=4000)")
def train_epoch(model, dataloader, optimizer, criterion, scheduler, device):
"""Train for one epoch."""
model.train()
total_loss = 0
for batch_idx, (src, tgt) in enumerate(dataloader):
src, tgt = src.to(device), tgt.to(device)
# Prepare decoder input (shift right)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
# Forward pass
output = model(src, tgt_input)
# Compute loss
output = output.reshape(-1, output.size(-1))
tgt_output = tgt_output.reshape(-1)
loss = criterion(output, tgt_output)
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f" Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
return total_loss / len(dataloader)
def evaluate(model, dataloader, criterion, device):
"""Evaluate model."""
model.eval()
total_loss = 0
with torch.no_grad():
for src, tgt in dataloader:
src, tgt = src.to(device), tgt.to(device)
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
output = model(src, tgt_input)
output = output.reshape(-1, output.size(-1))
tgt_output = tgt_output.reshape(-1)
loss = criterion(output, tgt_output)
total_loss += loss.item()
return total_loss / len(dataloader)
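The loop below expects train_loader and val_loader (and the translation demo later needs src_vocab and tgt_vocab), none of which are defined in this section; in the real project they would come from a tokenized English-French corpus. Purely as a stand-in so the code runs end to end, here is a minimal sketch using random synthetic data and placeholder vocabularies (ToyTranslationDataset and the tok entries are hypothetical, not from any library), so expect gibberish translations until you swap in real data:
class ToyTranslationDataset(Dataset):
    """Random (src, tgt) pairs of fixed length; id 0 = <pad>, 1 = <sos>, 2 = <eos>, 3 = <unk>."""
    def __init__(self, num_samples, src_len=10, tgt_len=12, vocab_size=10000):
        self.src = torch.randint(4, vocab_size, (num_samples, src_len))
        body = torch.randint(4, vocab_size, (num_samples, tgt_len - 2))
        sos = torch.full((num_samples, 1), 1, dtype=torch.long)
        eos = torch.full((num_samples, 1), 2, dtype=torch.long)
        self.tgt = torch.cat([sos, body, eos], dim=1)
    def __len__(self):
        return self.src.size(0)
    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]

train_loader = DataLoader(ToyTranslationDataset(1000), batch_size=32, shuffle=True)
val_loader = DataLoader(ToyTranslationDataset(200), batch_size=32)
# Placeholder vocabularies so the translation demo further below can run
specials = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
src_vocab = {**specials, **{f'src_tok_{i}': i for i in range(4, 10000)}}
tgt_vocab = {**specials, **{f'tgt_tok_{i}': i for i in range(4, 10000)}}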
# Training loop
num_epochs = 10
print("π Starting training...")
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
train_loss = train_epoch(model, train_loader, optimizer, criterion, scheduler, device)
val_loss = evaluate(model, val_loader, criterion, device)
print(f" Train Loss: {train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}")
# Save checkpoint
if (epoch + 1) % 5 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
}, f'transformer_checkpoint_epoch_{epoch+1}.pt')
print(f" β
Checkpoint saved!")
print("\nπ Training complete!")
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
"""
Translate a source sentence to target language.
Args:
model: Trained transformer
src_sentence: Source text string
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: torch device
max_len: Maximum output length
Returns:
Translated text string
"""
model.eval()
# Tokenize and add special tokens
src_tokens = ['<sos>'] + src_sentence.split() + ['<eos>']
src_indices = [src_vocab.get(token, src_vocab['<unk>']) for token in src_tokens]
src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
# Encode
with torch.no_grad():
src_mask = model.make_src_mask(src_tensor)
encoder_output = model.encoder(src_tensor, src_mask)
# Decode (greedy)
tgt_indices = [tgt_vocab['<sos>']]
for _ in range(max_len):
tgt_tensor = torch.LongTensor(tgt_indices).unsqueeze(0).to(device)
with torch.no_grad():
tgt_mask = model.make_tgt_mask(tgt_tensor)
output = model.decoder(tgt_tensor, encoder_output, src_mask, tgt_mask)
# Get next token
next_token_logits = output[0, -1, :]
next_token_id = torch.argmax(next_token_logits).item()
tgt_indices.append(next_token_id)
# Stop if EOS token
if next_token_id == tgt_vocab['<eos>']:
break
# Convert indices back to tokens (skip <sos>, drop the trailing <eos> if present)
id_to_token = {idx: token for token, idx in tgt_vocab.items()}
out_ids = tgt_indices[1:]
if out_ids and out_ids[-1] == tgt_vocab['<eos>']:
    out_ids = out_ids[:-1]
tgt_tokens = [id_to_token[idx] for idx in out_ids]
return ' '.join(tgt_tokens)
# Example translations
test_sentences = [
"Hello, how are you?",
"I love machine learning.",
"Transformers are amazing!",
]
print("\nπ Testing Translations:")
print("=" * 70)
for sentence in test_sentences:
translation = translate(model, sentence, src_vocab, tgt_vocab, device)
print(f"EN: {sentence}")
print(f"FR: {translation}")
print("-" * 70)
🎉 You've built a transformer from scratch!
These are the same building blocks that power GPT (a decoder-only variant), BERT (an encoder-only variant), and today's large language models. You now understand how these models work at a fundamental level.
Q1: What is the first component you implement in a Transformer?
Q2: What PyTorch module is commonly used for multi-head attention?
Q3: What must be added to token embeddings for position information?
Q4: What is the purpose of the feedforward network in each layer?
Q5: Why implement a Transformer from scratch?