🎯 From Model to API
You've trained a model, packaged it properly, and now you need to deploy it. But how do users actually access your model's predictions? How do you handle thousands of concurrent requests? How do you prevent abuse? How do you monitor what's happening in production?
The answer is building a production-ready API - a web service that wraps your ML model and makes it accessible to applications, users, and other services.
⚠️ Common Problems with Poor API Design:
- API crashes when users send malformed data
- Can't handle more than 10 requests per second
- No way to know if the service is healthy
- Users abuse the API with spam requests
- Can't debug production issues (no logging)
- Security vulnerabilities allow unauthorized access
In this tutorial, you'll learn how to build a robust ML API using FastAPI - the modern, high-performance Python framework designed for production APIs.
⚡ Why FastAPI for ML Models?
FastAPI has become the de facto standard for ML model serving in Python. Here's why:
Performance
One of the fastest Python frameworks available; built on Starlette and Pydantic, with throughput in the same range as Node.js and Go for I/O-bound workloads.
Type Safety
Automatic request validation using Python type hints. Catches errors before they reach your model.
Auto Documentation
Generates interactive API docs (Swagger UI) automatically from your code.
Async Support
Handle thousands of concurrent requests efficiently with async/await.
FastAPI vs Alternatives
| Feature | FastAPI | Flask | Django REST |
|---|---|---|---|
| Performance | ✅ Very Fast (ASGI) | Moderate (WSGI) | Moderate |
| Async Support | ✅ Native | Limited (Flask 2.0+) | Limited (Django 3.1+) |
| Request Validation | ✅ Automatic | Manual | Manual (serializers) |
| Auto Documentation | ✅ Built-in | Requires extensions | Requires setup |
| Type Hints | ✅ Core feature | Optional | Optional |
| Learning Curve | Easy | ✅ Easiest | Steep |
| Best For | ✅ ML APIs, Microservices | Quick prototypes | Full web apps |
🏗️ Building Your First ML API
Installation
# Install FastAPI and the Uvicorn ASGI server (quotes protect the [standard] extras in shells like zsh)
pip install fastapi "uvicorn[standard]"
# Install ML dependencies
pip install scikit-learn joblib pydantic numpy
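The examples below load a serialized model from iris_model.joblib, which the tutorial assumes you already have. If you don't, a minimal training script along these lines produces a compatible file (the random forest is an illustrative choice; any scikit-learn classifier with predict() and predict_proba() works):

"""
Train and serialize the iris model used by the API
File: train_model.py (illustrative)
"""
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib

# Load the classic iris dataset (150 samples, 4 features, 3 classes)
X, y = load_iris(return_X_y=True)

# Any classifier exposing predict() and predict_proba() fits the API below
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)

# Save next to app.py so joblib.load('iris_model.joblib') can find it
joblib.dump(model, 'iris_model.joblib')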
Minimal ML API
"""
Minimal FastAPI ML Prediction Service
File: app.py
"""
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
# Initialize FastAPI app
app = FastAPI(
title="Iris Classifier API",
description="ML model for predicting iris species",
version="1.0.0"
)
# Load model at startup
model = joblib.load('iris_model.joblib')
# Define request schema
class PredictionRequest(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float
class Config:
json_schema_extra = {
"example": {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
}
# Define response schema
class PredictionResponse(BaseModel):
prediction: int
species: str
confidence: float
# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
"""Make a prediction on iris features"""
# Convert to numpy array
features = np.array([[
request.sepal_length,
request.sepal_width,
request.petal_length,
request.petal_width
]])
# Get prediction and probability
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
confidence = float(max(probabilities))
# Map to species name
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
species = species_map[prediction]
return {
"prediction": int(prediction),
"species": species,
"confidence": confidence
}
# Health check endpoint
@app.get("/health")
def health():
"""Check if service is healthy"""
return {"status": "healthy"}
# Root endpoint
@app.get("/")
def root():
"""API information"""
return {
"name": "Iris Classifier API",
"version": "1.0.0",
"endpoints": ["/predict", "/health", "/docs"]
}
Running the API
# Start the development server (--reload auto-restarts on code changes; drop it in production)
uvicorn app:app --reload --host 0.0.0.0 --port 8000
# Server starts at http://localhost:8000
# Auto docs at http://localhost:8000/docs
# Alternative docs at http://localhost:8000/redoc
Testing Your API
# Test with curl
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}'
# Response:
# {
# "prediction": 0,
# "species": "setosa",
# "confidence": 0.99
# }
# Test with Python requests
import requests
response = requests.post(
"http://localhost:8000/predict",
json={
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
)
print(response.json())
# {'prediction': 0, 'species': 'setosa', 'confidence': 0.99}
✅ What You Get For Free:
- Interactive API documentation at /docs
- Request validation (rejects invalid data automatically)
- JSON schema for requests/responses
- Type safety with Python type hints
- Automatic error messages
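For example, a request with a non-numeric feature is rejected with a 422 response before it ever reaches the model. A quick check with the requests library (the exact error payload comes from Pydantic and varies slightly between versions):

import requests

# Invalid request: sepal_length is a string, not a float
response = requests.post(
    "http://localhost:8000/predict",
    json={
        "sepal_length": "not a number",
        "sepal_width": 3.5,
        "petal_length": 1.4,
        "petal_width": 0.2
    }
)
print(response.status_code)  # 422
print(response.json())       # validation details from Pydantic (exact format varies by version)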
✅ Advanced Request Validation
FastAPI's validation powered by Pydantic ensures data quality before it reaches your model.
Adding Constraints
from pydantic import BaseModel, Field, field_validator
from typing import List
class PredictionRequest(BaseModel):
sepal_length: float = Field(
...,
gt=0,
lt=10,
description="Sepal length in cm"
)
sepal_width: float = Field(
...,
gt=0,
lt=10,
description="Sepal width in cm"
)
petal_length: float = Field(
...,
gt=0,
lt=10,
description="Petal length in cm"
)
petal_width: float = Field(
...,
gt=0,
lt=10,
description="Petal width in cm"
)
    # Custom validator (field_validator replaces the deprecated Pydantic v1 validator)
    @field_validator('sepal_length', 'sepal_width', 'petal_length', 'petal_width')
    @classmethod
    def check_reasonable_values(cls, v):
if v > 15:
raise ValueError('Value seems unrealistically high')
return v
class Config:
json_schema_extra = {
"example": {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
}
# Batch prediction request
class BatchPredictionRequest(BaseModel):
instances: List[PredictionRequest] = Field(
...,
        max_length=100,  # Pydantic v2 uses max_length instead of max_items for lists
description="Maximum 100 predictions per request"
)
@app.post("/predict/batch")
def predict_batch(request: BatchPredictionRequest):
"""Handle multiple predictions in one request"""
predictions = []
for instance in request.instances:
features = np.array([[
instance.sepal_length,
instance.sepal_width,
instance.petal_length,
instance.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
predictions.append({
"prediction": int(prediction),
"confidence": float(max(probabilities))
})
return {"predictions": predictions, "count": len(predictions)}
Error Handling
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
# Custom exception handler
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
"""Custom validation error messages"""
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"error": "Invalid input",
"details": exc.errors(),
"message": "Please check your input values"
}
)
# Custom prediction error
class PredictionError(Exception):
pass
@app.exception_handler(PredictionError)
async def prediction_exception_handler(request, exc):
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"error": "Prediction failed",
"message": str(exc)
}
)
# Updated prediction endpoint with error handling
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
try:
features = np.array([[
request.sepal_length,
request.sepal_width,
request.petal_length,
request.petal_width
]])
# Check for NaN or Inf
if not np.isfinite(features).all():
raise PredictionError("Input contains invalid values (NaN or Inf)")
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
return {
"prediction": int(prediction),
"species": species_map[prediction],
"confidence": float(max(probabilities))
}
except Exception as e:
raise PredictionError(f"Model prediction failed: {str(e)}")
⚡ Async Predictions for Scale
Async endpoints allow your API to handle thousands of concurrent requests efficiently.
When to Use Async
- I/O-bound operations: Database queries, file operations, external API calls
- Multiple models: Loading different models or running ensemble predictions
- Pre/post-processing: Async image downloads, data fetching
- High concurrency: Serving many users simultaneously
💡 Note: Pure NumPy/scikit-learn computations don't benefit from async (they're CPU-bound). Use async for I/O operations around predictions.
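If a single prediction is heavy enough to block the event loop, one option is to push the CPU-bound call into a worker thread so other requests keep being served. Note that plain def endpoints are already run in a threadpool by FastAPI, so this mainly matters when you mix async I/O and heavy computation in the same async handler. A minimal sketch using Starlette's run_in_threadpool helper (the /predict/threaded path is illustrative):

from fastapi.concurrency import run_in_threadpool

@app.post("/predict/threaded")
async def predict_threaded(request: PredictionRequest):
    """Offload the CPU-bound model call so the event loop stays responsive"""
    features = np.array([[
        request.sepal_length,
        request.sepal_width,
        request.petal_length,
        request.petal_width
    ]])
    # model.predict and predict_proba run in a thread from the default executor
    prediction = await run_in_threadpool(model.predict, features)
    probabilities = await run_in_threadpool(model.predict_proba, features)
    species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
    return {
        "prediction": int(prediction[0]),
        "species": species_map[int(prediction[0])],
        "confidence": float(max(probabilities[0]))
    }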
Async Endpoint Example
import asyncio
from fastapi import BackgroundTasks
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Async prediction with preprocessing
@app.post("/predict/async")
async def predict_async(request: PredictionRequest, background_tasks: BackgroundTasks):
"""Async prediction with background logging"""
# Simulate async preprocessing (e.g., fetching data from DB)
await asyncio.sleep(0.01) # Replace with actual async I/O
# Synchronous model prediction (CPU-bound)
features = np.array([[
request.sepal_length,
request.sepal_width,
request.petal_length,
request.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
result = {
"prediction": int(prediction),
"species": species_map[prediction],
"confidence": float(max(probabilities))
}
# Log in background (doesn't block response)
background_tasks.add_task(
log_prediction,
request=request,
result=result
)
return result
def log_prediction(request, result):
"""Background task for logging"""
logger.info(f"Prediction: {result['species']} (confidence: {result['confidence']:.2f})")
# Async batch processing
@app.post("/predict/batch/async")
async def predict_batch_async(request: BatchPredictionRequest):
"""Process multiple predictions concurrently"""
async def predict_single(instance):
# Simulate async preprocessing
await asyncio.sleep(0.01)
features = np.array([[
instance.sepal_length,
instance.sepal_width,
instance.petal_length,
instance.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
return {
"prediction": int(prediction),
"confidence": float(max(probabilities))
}
# Process all predictions concurrently
tasks = [predict_single(instance) for instance in request.instances]
predictions = await asyncio.gather(*tasks)
return {"predictions": predictions, "count": len(predictions)}
Load Testing Async Performance
# Install load testing tool
pip install locust
# Run load test
locust -f load_test.py --host http://localhost:8000
# load_test.py
from locust import HttpUser, task, between
class MLAPIUser(HttpUser):
wait_time = between(1, 3)
@task
def predict(self):
self.client.post("/predict", json={
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
})
🏥 Health Checks & Monitoring
Production APIs need comprehensive health checks and monitoring to ensure reliability.
Comprehensive Health Checks
from datetime import datetime
from typing import Dict, Any
import psutil  # pip install psutil
import os
# Startup time
startup_time = datetime.now()
@app.get("/health")
def health_check():
"""Basic health check"""
return {"status": "healthy"}
@app.get("/health/detailed")
def detailed_health_check() -> Dict[str, Any]:
"""Detailed health check with system info"""
# Check model is loaded
model_loaded = model is not None
# System metrics
memory = psutil.virtual_memory()
    cpu_percent = psutil.cpu_percent(interval=None)  # non-blocking; interval=1 would stall every health check for a second
# Uptime
uptime = (datetime.now() - startup_time).total_seconds()
return {
"status": "healthy" if model_loaded else "unhealthy",
"timestamp": datetime.now().isoformat(),
"uptime_seconds": uptime,
"model": {
"loaded": model_loaded,
"type": type(model).__name__ if model_loaded else None
},
"system": {
"cpu_percent": cpu_percent,
"memory_percent": memory.percent,
"memory_available_mb": memory.available / (1024 * 1024)
},
"process": {
"pid": os.getpid(),
"threads": psutil.Process().num_threads()
}
}
@app.get("/health/ready")
def readiness_check():
"""Kubernetes readiness probe"""
# Check if service can handle requests
if model is None:
return JSONResponse(
status_code=503,
content={"status": "not ready", "reason": "model not loaded"}
)
return {"status": "ready"}
@app.get("/health/live")
def liveness_check():
"""Kubernetes liveness probe"""
# Check if service is alive (simple check)
return {"status": "alive"}
Request Metrics
from collections import defaultdict
from time import time
from functools import wraps
# Simple in-memory metrics (use Prometheus in production)
metrics = {
"total_requests": 0,
"total_predictions": 0,
"total_errors": 0,
"prediction_times": [],
"endpoint_counts": defaultdict(int)
}
# Middleware for tracking requests
@app.middleware("http")
async def track_requests(request, call_next):
start_time = time()
metrics["total_requests"] += 1
response = await call_next(request)
process_time = time() - start_time
metrics["endpoint_counts"][request.url.path] += 1
# Add custom header
response.headers["X-Process-Time"] = str(process_time)
return response
# Metrics endpoint
@app.get("/metrics")
def get_metrics():
"""Get API metrics"""
avg_prediction_time = (
sum(metrics["prediction_times"]) / len(metrics["prediction_times"])
if metrics["prediction_times"] else 0
)
return {
"total_requests": metrics["total_requests"],
"total_predictions": metrics["total_predictions"],
"total_errors": metrics["total_errors"],
"average_prediction_time_ms": avg_prediction_time * 1000,
"endpoints": dict(metrics["endpoint_counts"])
}
# Updated predict endpoint with metrics
@app.post("/predict", response_model=PredictionResponse)
def predict_with_metrics(request: PredictionRequest):
start_time = time()
try:
features = np.array([[
request.sepal_length,
request.sepal_width,
request.petal_length,
request.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
# Track metrics
metrics["total_predictions"] += 1
metrics["prediction_times"].append(time() - start_time)
# Keep only last 1000 times
if len(metrics["prediction_times"]) > 1000:
metrics["prediction_times"] = metrics["prediction_times"][-1000:]
return {
"prediction": int(prediction),
"species": species_map[prediction],
"confidence": float(max(probabilities))
}
except Exception as e:
metrics["total_errors"] += 1
raise HTTPException(status_code=500, detail=str(e))
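The in-memory counters above reset on every restart and aren't shared between workers. For production, the usual approach is to expose Prometheus metrics instead. A sketch using the official prometheus_client package (pip install prometheus-client; the metric names and the /metrics/prometheus mount path are illustrative choices, picked so they don't clash with the simple /metrics endpoint above):

from prometheus_client import Counter, Histogram, make_asgi_app

# Define metrics once at module level
PREDICTION_COUNT = Counter(
    "ml_api_predictions_total", "Total number of predictions served"
)
PREDICTION_LATENCY = Histogram(
    "ml_api_prediction_latency_seconds", "Prediction latency in seconds"
)

# Expose metrics in the Prometheus text format
app.mount("/metrics/prometheus", make_asgi_app())

@app.post("/predict/prom", response_model=PredictionResponse)
def predict_with_prometheus(request: PredictionRequest):
    with PREDICTION_LATENCY.time():
        features = np.array([[
            request.sepal_length,
            request.sepal_width,
            request.petal_length,
            request.petal_width
        ]])
        prediction = model.predict(features)[0]
        probabilities = model.predict_proba(features)[0]
    PREDICTION_COUNT.inc()
    species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
    return {
        "prediction": int(prediction),
        "species": species_map[prediction],
        "confidence": float(max(probabilities))
    }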
🔒 Authentication & Rate Limiting
API Key Authentication
from fastapi import Security, HTTPException, status
from fastapi.security import APIKeyHeader
from typing import Optional
# API key configuration
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
# In production, store in database or environment
VALID_API_KEYS = {
"dev_key_12345": {"name": "Development", "tier": "free"},
"prod_key_67890": {"name": "Production", "tier": "premium"}
}
async def get_api_key(api_key: Optional[str] = Security(api_key_header)) -> dict:
"""Validate API key"""
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API Key"
)
if api_key not in VALID_API_KEYS:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key"
)
return VALID_API_KEYS[api_key]
# Protected endpoint
@app.post("/predict/protected")
def predict_protected(
request: PredictionRequest,
api_key_info: dict = Security(get_api_key)
):
"""Prediction endpoint with authentication"""
# Different behavior based on tier
if api_key_info["tier"] == "free":
# Could add limits for free tier
pass
features = np.array([[
request.sepal_length,
request.sepal_width,
request.petal_length,
request.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
return {
"prediction": int(prediction),
"species": species_map[prediction],
"confidence": float(max(probabilities)),
"tier": api_key_info["tier"]
}
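Hard-coding keys works for a demo, but in practice you would load them from the environment or a secrets store. One possible sketch, where the ML_API_KEYS variable name and its comma-separated key:tier format are assumptions made for illustration:

import os

def load_api_keys() -> dict:
    """Read allowed keys from an environment variable, e.g.
    ML_API_KEYS="key1:free,key2:premium" (format chosen for this example)"""
    keys = {}
    raw = os.environ.get("ML_API_KEYS", "")
    for entry in raw.split(","):
        if ":" in entry:
            key, tier = entry.split(":", 1)
            keys[key.strip()] = {"name": "env", "tier": tier.strip()}
    return keys

VALID_API_KEYS = load_api_keys() or {
    "dev_key_12345": {"name": "Development", "tier": "free"}  # fallback for local development only
}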
Rate Limiting
# Requires: pip install slowapi
from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Rate limited endpoint
@app.post("/predict/limited")
@limiter.limit("10/minute") # 10 requests per minute
async def predict_rate_limited(
request: Request,
prediction_request: PredictionRequest
):
"""Prediction with rate limiting"""
features = np.array([[
prediction_request.sepal_length,
prediction_request.sepal_width,
prediction_request.petal_length,
prediction_request.petal_width
]])
prediction = model.predict(features)[0]
probabilities = model.predict_proba(features)[0]
species_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
return {
"prediction": int(prediction),
"species": species_map[prediction],
"confidence": float(max(probabilities))
}
# Tiered rate limiting
@app.post("/predict/tiered")
@limiter.limit("100/minute") # Global limit
async def predict_tiered(
request: Request,
prediction_request: PredictionRequest,
api_key_info: dict = Security(get_api_key)
):
"""Rate limiting based on API tier"""
# Premium users get higher limits (implemented in limiter config)
if api_key_info["tier"] == "premium":
# Premium logic
pass
# Prediction logic...
return {"status": "success"}
⚠️ Production Rate Limiting: For production, use Redis-backed rate limiting (slowapi with Redis) instead of in-memory limits, so rate limits persist across server restarts and work with multiple instances.
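A sketch of what that could look like with slowapi's storage_uri option, which hands storage off to the underlying limits library (the Redis URL is illustrative and assumes a running Redis instance plus pip install slowapi redis):

from slowapi import Limiter
from slowapi.util import get_remote_address

# Counters live in Redis, so limits survive restarts and are shared
# across every API replica pointing at the same Redis instance
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri="redis://localhost:6379/0",  # illustrative URL
    default_limits=["200/minute"]
)
app.state.limiter = limiter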
✅ Production Best Practices
1. Proper Project Structure
ml-api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI app
│ ├── models.py # Pydantic models
│ ├── api/
│ │ ├── __init__.py
│ │ ├── predictions.py # Prediction endpoints
│ │ ├── health.py # Health endpoints
│ │ └── auth.py # Authentication
│ ├── core/
│ │ ├── __init__.py
│ │ ├── config.py # Configuration
│ │ └── security.py # Security utilities
│ └── ml/
│ ├── __init__.py
│ ├── model.py # Model loading/inference
│ └── preprocessing.py # Data preprocessing
├── models/ # Serialized models
│ └── iris_model.joblib
├── tests/
│ ├── __init__.py
│ ├── test_api.py
│ └── test_predictions.py
├── requirements.txt
├── Dockerfile
└── README.md
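With this layout, each endpoint module defines its own APIRouter and main.py wires them together. A minimal sketch of that wiring (the module and router names follow the tree above but are otherwise assumptions):

# app/main.py (sketch)
from fastapi import FastAPI
from app.api import predictions, health

app = FastAPI(title="ML API", version="1.0.0")

# Each module exposes `router = APIRouter()` with its own endpoints
app.include_router(health.router, tags=["health"])
app.include_router(predictions.router, prefix="/api/v1", tags=["predictions"])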
2. Configuration Management
# app/core/config.py
from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache
class Settings(BaseSettings):
app_name: str = "ML API"
app_version: str = "1.0.0"
# Model settings
model_path: str = "models/iris_model.joblib"
max_batch_size: int = 100
# API settings
api_prefix: str = "/api/v1"
enable_auth: bool = True
# Rate limiting
rate_limit_per_minute: int = 60
# Monitoring
enable_metrics: bool = True
log_level: str = "INFO"
    model_config = SettingsConfigDict(env_file=".env")
@lru_cache()
def get_settings():
return Settings()
# Usage in main.py
from app.core.config import get_settings
settings = get_settings()
app = FastAPI(title=settings.app_name, version=settings.app_version)
3. Logging
import logging
from logging.handlers import RotatingFileHandler
import sys
import os
def setup_logging():
"""Configure application logging"""
# Create logger
logger = logging.getLogger("ml_api")
logger.setLevel(logging.INFO)
# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_format = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
console_handler.setFormatter(console_format)
    # File handler with rotation (create the log directory if it doesn't exist)
    os.makedirs('logs', exist_ok=True)
    file_handler = RotatingFileHandler(
        'logs/api.log',
maxBytes=10485760, # 10MB
backupCount=5
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(console_format)
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
logger = setup_logging()
# Usage in endpoints
@app.post("/predict")
def predict(request: PredictionRequest):
logger.info(f"Received prediction request: {request.dict()}")
# ... prediction logic
logger.info(f"Prediction completed: {result}")
return result
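In production, many teams also emit logs as JSON lines so log aggregators can parse them. A small stdlib-only sketch of such a formatter (the field names are illustrative), which you could attach to the handlers created in setup_logging():

import json
import logging

class JsonFormatter(logging.Formatter):
    """Render each log record as a single JSON line"""
    def format(self, record):
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        return json.dumps(payload)

# Example: swap the formatter on the handlers built in setup_logging()
# console_handler.setFormatter(JsonFormatter())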
4. Testing
# tests/test_api.py
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_health_check():
"""Test health endpoint"""
response = client.get("/health")
assert response.status_code == 200
assert response.json()["status"] == "healthy"
def test_prediction():
"""Test prediction endpoint"""
response = client.post(
"/predict",
json={
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
)
assert response.status_code == 200
data = response.json()
assert "prediction" in data
assert "species" in data
assert "confidence" in data
assert 0 <= data["confidence"] <= 1
def test_invalid_input():
"""Test validation"""
response = client.post(
"/predict",
json={
"sepal_length": -1, # Invalid: negative
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
)
assert response.status_code == 422 # Validation error
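If you wired in the API-key-protected endpoint from earlier, tests for it could look like this (the X-API-Key header name and dev_key_12345 value match the earlier example; adjust them to your own configuration):

def test_protected_requires_api_key():
    """Protected endpoint rejects requests without a key"""
    response = client.post(
        "/predict/protected",
        json={
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
    )
    assert response.status_code == 401

def test_protected_with_api_key():
    """Protected endpoint accepts a valid key"""
    response = client.post(
        "/predict/protected",
        headers={"X-API-Key": "dev_key_12345"},
        json={
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
    )
    assert response.status_code == 200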
5. CORS Configuration
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["https://yourfrontend.com"], # Specific origins in production
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
🎯 Summary
You've learned how to build production-ready ML APIs with FastAPI:
FastAPI Basics
High-performance framework with automatic validation and documentation
Request Validation
Type-safe requests with Pydantic models and custom validators
Async Support
Handle thousands of concurrent requests efficiently
Health Checks
Comprehensive monitoring and metrics for production
Security
API key authentication and rate limiting
Best Practices
Proper structure, logging, testing, and configuration
Key Takeaways
- FastAPI is one of the best choices for ML model serving in Python
- Use Pydantic models for automatic request validation
- Implement async endpoints for I/O-bound operations
- Add comprehensive health checks and monitoring
- Protect your API with authentication and rate limiting
- Follow production best practices (logging, testing, CORS)
- Always test your API before deploying to production
🚀 Next Steps:
Your API is ready, but how do you package it for deployment? The next tutorial covers containerization with Docker - creating consistent, portable environments for your ML services.
Test Your Knowledge
Q1: What is the main advantage of FastAPI over Flask for ML model serving?
Q2: What does Pydantic provide in FastAPI?
Q3: When should you use async endpoints in FastAPI?
Q4: What's the purpose of the /health/ready endpoint in Kubernetes?
Q5: Why is rate limiting important for ML APIs?