Model Retraining & Continuous Learning

Automate model retraining pipelines: implement trigger conditions, compare online and batch learning, test new models in shadow mode, and roll back safely

šŸ“… Tutorial 12 šŸ“Š Intermediate

šŸ”„ Why Retrain Models?

Your model was 95% accurate at launch. Six months later:

  • Accuracy dropped to 75% due to changing user behavior
  • New product categories appeared (not in training data)
  • Seasonal patterns shifted consumer preferences
  • Competitors changed market dynamics
  • Data drift detected across multiple features

Model retraining keeps models relevant by learning from new data. Unlike static models that degrade over time, continuously learning models adapt to changing patterns and maintain performance.

šŸ’” When to Retrain:

  • Performance Degradation: Accuracy drops below threshold
  • Data Drift Detected: Input distributions change significantly
  • New Data Available: Sufficient fresh labeled data accumulated
  • Business Events: Product launches, seasonal changes, market shifts
  • Scheduled Intervals: Weekly/monthly regular retraining
  • Concept Drift: Relationship between features and target changes

Retraining Strategies

Strategy      | When to Use                                   | Pros                               | Cons
Scheduled     | Stable environments, predictable patterns     | Simple, predictable                | May retrain unnecessarily or miss urgent issues
Trigger-Based | Dynamic environments with active monitoring   | Efficient, responds to real issues | Requires robust monitoring
Continuous    | Fast-changing environments, real-time systems | Always up to date                  | Complex, resource intensive
Hybrid        | Most production systems                       | Balanced approach                  | More complex implementation
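
In practice, the hybrid approach is often a scheduled job that only retrains when a trigger condition actually fires. Here is a minimal sketch using the third-party schedule library; the RetrainingTrigger class is defined in the next section, and trigger_retraining_pipeline (sketched later) is the assumed entry point that kicks off the actual pipeline. A production setup would more likely use the Airflow DAG shown further down.

"""
Hybrid strategy sketch: run on a schedule, retrain only when a trigger fires
"""
import schedule
import time

def scheduled_check():
    """Weekly check that retrains only if a trigger condition is met"""
    trigger = RetrainingTrigger()              # defined in the next section
    should_retrain, reasons = trigger.should_retrain()

    if should_retrain:
        trigger_retraining_pipeline(reasons)   # assumed pipeline entry point
    else:
        print("Scheduled check passed - no retraining needed")

# Check every Sunday at 2 AM
schedule.every().sunday.at("02:00").do(scheduled_check)

while True:
    schedule.run_pending()
    time.sleep(60)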

šŸŽÆ Retraining Trigger Conditions

Performance-Based Triggers

"""
Monitor performance and trigger retraining
"""
import pandas as pd
from datetime import datetime, timedelta

class RetrainingTrigger:
    def __init__(
        self,
        accuracy_threshold=0.85,
        drift_threshold=0.3,
        min_new_samples=1000,
        lookback_days=7
    ):
        self.accuracy_threshold = accuracy_threshold
        self.drift_threshold = drift_threshold
        self.min_new_samples = min_new_samples
        self.lookback_days = lookback_days
    
    def should_retrain(self):
        """Check if retraining should be triggered"""
        
        triggers = {}
        
        # Check 1: Performance degradation
        recent_accuracy = self.get_recent_accuracy()
        if recent_accuracy is not None and recent_accuracy < self.accuracy_threshold:
            triggers['performance'] = f"Accuracy {recent_accuracy:.2f} < {self.accuracy_threshold}"
        
        # Check 2: Data drift
        drift_score = self.calculate_drift()
        if drift_score > self.drift_threshold:
            triggers['drift'] = f"Drift score {drift_score:.2f} > {self.drift_threshold}"
        
        # Check 3: Sufficient new data
        new_samples = self.count_new_samples()
        if new_samples >= self.min_new_samples:
            triggers['new_data'] = f"{new_samples} new samples available"
        
        # Check 4: Prediction distribution shift
        prediction_shift = self.check_prediction_shift()
        if prediction_shift > 0.2:
            triggers['prediction_shift'] = f"Prediction distribution shifted by {prediction_shift:.2f}"
        
        # Trigger if any condition met
        if triggers:
            print(f"šŸ”„ Retraining triggered: {triggers}")
            return True, triggers
        
        return False, {}
    
    def get_recent_accuracy(self):
        """Calculate accuracy over recent predictions with ground truth"""
        cutoff = datetime.now() - timedelta(days=self.lookback_days)
        
        predictions = pd.read_sql(
            f"""
            SELECT predicted, actual
            FROM predictions
            WHERE timestamp > '{cutoff}' AND actual IS NOT NULL
            """,
            con=db_connection
        )
        
        if len(predictions) < 100:
            return None  # Insufficient data
        
        accuracy = (predictions['predicted'] == predictions['actual']).mean()
        return accuracy
    
    def calculate_drift(self):
        """Calculate data drift score"""
        from evidently.report import Report
        from evidently.metrics import DatasetDriftMetric
        
        reference_data = load_reference_data()
        current_data = load_recent_data(days=self.lookback_days)
        
        report = Report(metrics=[DatasetDriftMetric()])
        report.run(reference_data=reference_data, current_data=current_data)
        
        results = report.as_dict()
        drift_score = results['metrics'][0]['result']['share_of_drifted_columns']
        
        return drift_score
    
    def count_new_samples(self):
        """Count new labeled samples since last training"""
        last_training = get_last_training_date()
        
        count = pd.read_sql(
            f"""
            SELECT COUNT(*) as cnt
            FROM training_data
            WHERE timestamp > '{last_training}' AND label IS NOT NULL
            """,
            con=db_connection
        )
        
        return count['cnt'].iloc[0]
    
    def check_prediction_shift(self):
        """Check if prediction distribution has shifted"""
        from scipy.stats import wasserstein_distance
        
        reference_predictions = load_reference_predictions()
        current_predictions = load_recent_predictions(days=self.lookback_days)
        
        distance = wasserstein_distance(reference_predictions, current_predictions)
        return distance

# Usage
trigger = RetrainingTrigger()
should_retrain, reasons = trigger.should_retrain()

if should_retrain:
    trigger_retraining_pipeline(reasons)

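The trigger_retraining_pipeline call above is assumed to exist. One way to implement it is to trigger the Airflow DAG from the next section through Airflow's stable REST API; the host, credentials, and auth setup below are placeholders and assume the API is enabled with basic auth.

"""
Trigger the retraining DAG via Airflow's REST API (sketch)
"""
import requests

def trigger_retraining_pipeline(reasons):
    """Kick off the 'model_retraining' DAG, passing trigger reasons as run config"""
    response = requests.post(
        "http://airflow-webserver:8080/api/v1/dags/model_retraining/dagRuns",
        auth=("mlops", "CHANGE_ME"),          # placeholder credentials
        json={"conf": {"trigger_reasons": reasons}},
    )
    response.raise_for_status()
    print(f"Triggered retraining run: {response.json()['dag_run_id']}")
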
Scheduled Retraining with Airflow

"""
Airflow DAG for scheduled retraining
"""
from airflow import DAG
from airflow.operators.python import PythonOperator, ShortCircuitOperator
from airflow.operators.bash import BashOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'mlops',
    'depends_on_past': False,
    'email': ['team@company.com'],
    'email_on_failure': True,
    'email_on_retry': False,
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
}

dag = DAG(
    'model_retraining',
    default_args=default_args,
    description='Automated model retraining pipeline',
    schedule_interval='0 2 * * 0',  # Every Sunday at 2 AM
    start_date=datetime(2024, 1, 1),
    catchup=False,
)

def check_if_retraining_needed(**context):
    """Decide if retraining should proceed"""
    trigger = RetrainingTrigger()
    should_retrain, reasons = trigger.should_retrain()
    
    if not should_retrain:
        print("No retraining needed")
        # Returning False from a ShortCircuitOperator skips all downstream tasks
        return False
    
    print(f"Retraining needed: {reasons}")
    context['task_instance'].xcom_push(key='retrain_reasons', value=reasons)
    return True

check_trigger = ShortCircuitOperator(
    task_id='check_trigger',
    python_callable=check_if_retraining_needed,
    dag=dag,
)

fetch_data = BashOperator(
    task_id='fetch_training_data',
    bash_command='python scripts/fetch_training_data.py --days 90',
    dag=dag,
)

preprocess = BashOperator(
    task_id='preprocess_data',
    bash_command='python scripts/preprocess.py',
    dag=dag,
)

train_model = BashOperator(
    task_id='train_model',
    bash_command='python scripts/train.py --experiment retraining-{{ ds }}',
    dag=dag,
)

# evaluate_new_model and deploy_if_improved are project-specific callables
# (a sketch follows this DAG)
evaluate = PythonOperator(
    task_id='evaluate_model',
    python_callable=evaluate_new_model,
    dag=dag,
)

deploy = PythonOperator(
    task_id='deploy_if_better',
    python_callable=deploy_if_improved,
    dag=dag,
)

check_trigger >> fetch_data >> preprocess >> train_model >> evaluate >> deploy
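
The evaluate_new_model and deploy_if_improved callables referenced above are project-specific. A minimal sketch of what they might look like, assuming the training script logs val_accuracy to MLflow, with hypothetical helpers (get_latest_retraining_run_id, get_production_val_accuracy, register_and_promote) standing in for pieces not shown in this tutorial:

"""
Sketch of the evaluation and conditional-deployment callables
"""
import mlflow

def evaluate_new_model(**context):
    """Look up the candidate run's metrics and pass them downstream"""
    run_id = get_latest_retraining_run_id()          # hypothetical helper
    run = mlflow.get_run(run_id)
    val_acc = run.data.metrics["val_accuracy"]
    context['task_instance'].xcom_push(key='candidate_val_acc', value=val_acc)
    context['task_instance'].xcom_push(key='candidate_run_id', value=run_id)

def deploy_if_improved(**context):
    """Promote the candidate only if it beats the current production model"""
    ti = context['task_instance']
    candidate_acc = ti.xcom_pull(key='candidate_val_acc', task_ids='evaluate_model')
    production_acc = get_production_val_accuracy()   # hypothetical helper

    if candidate_acc > production_acc + 0.01:        # require a 1% improvement
        register_and_promote(                        # hypothetical helper
            ti.xcom_pull(key='candidate_run_id', task_ids='evaluate_model')
        )
        print(f"Deployed candidate ({candidate_acc:.4f} > {production_acc:.4f})")
    else:
        print("Candidate not better - keeping production model")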

šŸ“š Online vs Batch Learning

Batch Learning (Retraining)

Train on the entire dataset periodically. This is the most common approach for production ML.

"""
Batch retraining pipeline
"""
import mlflow
from sklearn.ensemble import RandomForestClassifier
import pandas as pd

def batch_retrain():
    """Retrain model on accumulated data"""
    
    # Fetch all training data (including new samples)
    training_data = pd.read_sql("""
        SELECT * FROM training_data
        WHERE timestamp > DATE_SUB(NOW(), INTERVAL 6 MONTH)
    """, con=db_connection)
    
    print(f"Training samples: {len(training_data)}")
    
    X = training_data.drop(['label', 'timestamp'], axis=1)
    y = training_data['label']
    
    # Split
    from sklearn.model_selection import train_test_split
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # Train
    with mlflow.start_run(run_name=f"retrain-{datetime.now().strftime('%Y%m%d')}"):
        model = RandomForestClassifier(n_estimators=200, max_depth=20)
        model.fit(X_train, y_train)
        
        # Evaluate
        train_acc = model.score(X_train, y_train)
        val_acc = model.score(X_val, y_val)
        
        print(f"Training accuracy: {train_acc:.4f}")
        print(f"Validation accuracy: {val_acc:.4f}")
        
        # Log metrics
        mlflow.log_metric("train_accuracy", train_acc)
        mlflow.log_metric("val_accuracy", val_acc)
        mlflow.log_param("n_samples", len(training_data))
        mlflow.log_param("retrain_date", datetime.now().isoformat())
        
        # Save model
        mlflow.sklearn.log_model(model, "model")
        
        return model, val_acc
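
Batch retraining treats every sample in the window equally; a common variant weights recent samples more heavily so the model adapts to drift without discarding older data. A minimal sketch follows; the 30-day half-life is an arbitrary assumption, and the commented lines show where it would plug into batch_retrain above.

"""
Variant: weight recent samples more heavily during batch retraining
"""
import numpy as np
import pandas as pd

def recency_weights(timestamps, half_life_days=30.0):
    """Exponential decay: a sample half_life_days old gets weight 0.5, newest ~1.0"""
    age_days = (pd.Timestamp.now() - pd.to_datetime(timestamps)).dt.days
    return np.exp(-np.log(2) * age_days / half_life_days)

# Inside batch_retrain(), the weights would be passed to fit(), e.g.:
#   weights = recency_weights(training_data['timestamp'])
#   model.fit(X_train, y_train, sample_weight=weights.loc[X_train.index])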

Online Learning (Incremental)

Update the model incrementally as new samples arrive. Suitable for streaming data and fast-changing environments.

"""
Online learning with incremental updates
"""
from sklearn.linear_model import SGDClassifier
import pickle

class OnlineLearner:
    def __init__(self, model_path='online_model.pkl'):
        self.model_path = model_path
        
        # Load or initialize model
        try:
            with open(model_path, 'rb') as f:
                self.model = pickle.load(f)
            print("Loaded existing model")
        except FileNotFoundError:
            self.model = SGDClassifier(
                loss='log_loss',          # logistic loss ('log' was renamed in scikit-learn 1.1)
                learning_rate='adaptive',
                eta0=0.01
            )
            print("Initialized new model")
    
    def partial_fit(self, X, y, classes=None):
        """Update model with new samples"""
        if classes is None:
            classes = [0, 1]  # Binary classification
        
        self.model.partial_fit(X, y, classes=classes)
        
        # Save updated model
        with open(self.model_path, 'wb') as f:
            pickle.dump(self.model, f)
    
    def predict(self, X):
        """Make predictions"""
        return self.model.predict(X)

# Usage in production (the FastAPI `app`, the FeedbackData model, the feature
# column list `features`, and the store_feedback / get_pending_feedback /
# mark_feedback_processed helpers are defined elsewhere)
learner = OnlineLearner()

# When new labeled data arrives
@app.post("/feedback")
async def receive_feedback(feedback: FeedbackData):
    """Receive feedback and update model"""
    """Receive feedback and update model"""
    
    # Store feedback
    store_feedback(feedback)
    
    # Accumulate batch
    feedback_batch = get_pending_feedback(min_batch_size=50)
    
    if len(feedback_batch) >= 50:
        # Update model
        X = feedback_batch[features]
        y = feedback_batch['label']
        
        learner.partial_fit(X, y)
        
        print(f"Model updated with {len(feedback_batch)} samples")
        mark_feedback_processed(feedback_batch)

Comparison: Batch vs Online

Aspect             | Batch Learning                           | Online Learning
Update Frequency   | Periodic (daily/weekly/monthly)          | Continuous (per sample or mini-batch)
Data Requirements  | Full dataset each time                   | New samples only
Algorithms         | Any (tree-based, neural networks, etc.)  | SGD-based (linear models, neural networks)
Computational Cost | High per training, low frequency         | Low per update, high frequency
Stability          | More stable                              | Can be unstable, needs careful tuning
Best For           | Most production systems                  | Real-time systems, concept drift
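
Because online updates can destabilize a model (the Stability row above), a common safeguard is to validate every update against a fixed holdout set and revert it if quality drops. Here is a minimal sketch built on the OnlineLearner above; it assumes the learner has already been fitted at least once and that X_holdout / y_holdout are a labeled holdout set you maintain, with the 2% drop threshold as an arbitrary assumption.

"""
Guard online updates with a fixed holdout set; revert if accuracy drops
"""
import copy
import pickle

def safe_partial_fit(learner, X_new, y_new, X_holdout, y_holdout, max_drop=0.02):
    """Apply an incremental update, but roll it back if holdout accuracy
    falls by more than max_drop."""
    baseline = (learner.predict(X_holdout) == y_holdout).mean()

    # Keep a copy of the current model so the update can be undone
    backup = copy.deepcopy(learner.model)

    learner.partial_fit(X_new, y_new)
    updated = (learner.predict(X_holdout) == y_holdout).mean()

    if updated < baseline - max_drop:
        print(f"Update rejected: holdout accuracy {baseline:.3f} -> {updated:.3f}")
        learner.model = backup
        with open(learner.model_path, 'wb') as f:   # restore the saved model too
            pickle.dump(backup, f)
        return False

    print(f"Update accepted: holdout accuracy {baseline:.3f} -> {updated:.3f}")
    return True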

šŸ•¶ļø Shadow Mode Testing

Shadow mode runs the new model alongside the production model without affecting users. Comparing the two models' predictions lets you validate the new model before full deployment.

Shadow Mode Implementation

"""
Shadow mode deployment
"""
import asyncio
from datetime import datetime
from typing import Dict

class ShadowModePredictor:
    def __init__(self, production_model, shadow_model):
        self.production_model = production_model
        self.shadow_model = shadow_model
        self.comparison_log = []
    
    async def predict(self, features):
        """Make predictions with both models"""
        
        # Production prediction (synchronous)
        prod_prediction = self.production_model.predict(features)
        
        # Shadow prediction (async, non-blocking)
        shadow_task = asyncio.create_task(
            self.shadow_predict(features, prod_prediction)
        )
        
        # Return production result immediately
        return prod_prediction
    
    async def shadow_predict(self, features, prod_prediction):
        """Run shadow model and compare"""
        try:
            # Shadow prediction
            shadow_prediction = self.shadow_model.predict(features)
            
            # Log comparison
            comparison = {
                'timestamp': datetime.now(),
                'features': features,
                'production': prod_prediction,
                'shadow': shadow_prediction,
                'match': prod_prediction == shadow_prediction
            }
            
            # Store for analysis
            await self.log_comparison(comparison)
            
        except Exception as e:
            print(f"Shadow prediction failed: {e}")
            # Don't affect production
    
    async def log_comparison(self, comparison):
        """Store comparison for analysis"""
        # Write to database ('db' is an async database client defined elsewhere)
        await db.execute("""
            INSERT INTO shadow_predictions
            (timestamp, production_pred, shadow_pred, match)
            VALUES (%s, %s, %s, %s)
        """, (
            comparison['timestamp'],
            comparison['production'],
            comparison['shadow'],
            comparison['match']
        ))

# FastAPI endpoint with shadow mode ('app', 'PredictRequest', and 'load_model'
# are defined elsewhere)
@app.post("/predict")
async def predict(request: PredictRequest):
    """Prediction endpoint with shadow model"""
    
    # Load models (in practice, load both once at startup rather than per request)
    production_model = load_model("production")
    shadow_model = load_model("shadow")  # New model candidate
    
    predictor = ShadowModePredictor(production_model, shadow_model)
    
    # Get prediction (shadow runs asynchronously)
    prediction = await predictor.predict(request.features)
    
    return {"prediction": prediction}

Shadow Mode Analysis

"""
Analyze shadow mode results
"""
import pandas as pd
import matplotlib.pyplot as plt

def analyze_shadow_mode(days=7):
    """Analyze shadow model performance"""
    
    # Fetch comparisons
    comparisons = pd.read_sql(f"""
        SELECT *
        FROM shadow_predictions
        WHERE timestamp > DATE_SUB(NOW(), INTERVAL {days} DAY)
    """, con=db_connection)
    
    print(f"\nšŸ“Š Shadow Mode Analysis ({len(comparisons)} predictions)")
    print("=" * 50)
    
    # Agreement rate
    agreement = comparisons['match'].mean()
    print(f"Agreement: {agreement:.2%}")
    
    # When ground truth is available
    labeled = comparisons[comparisons['actual'].notna()]
    
    if len(labeled) > 0:
        prod_acc = (labeled['production_pred'] == labeled['actual']).mean()
        shadow_acc = (labeled['shadow_pred'] == labeled['actual']).mean()
        
        print(f"\nAccuracy (on {len(labeled)} labeled samples):")
        print(f"  Production: {prod_acc:.2%}")
        print(f"  Shadow:     {shadow_acc:.2%}")
        print(f"  Improvement: {(shadow_acc - prod_acc):.2%}")
        
        # Decision
        if shadow_acc > prod_acc + 0.02:  # 2% improvement
            print("\nāœ… Shadow model performs better - Recommend promotion")
            return True
        elif shadow_acc < prod_acc - 0.02:
            print("\nāŒ Shadow model performs worse - Keep production model")
            return False
        else:
            print("\nāš ļø Performance similar - Need more data")
            return None
    
    return None

# Check shadow model weekly
if __name__ == '__main__':
    promote = analyze_shadow_mode(days=7)
    
    if promote:
        promote_shadow_to_production()
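
The promote_shadow_to_production call above is assumed; a minimal sketch using the same MLflow registry calls that the ModelVersionManager in the next section wraps. It assumes the shadow candidate is registered under the same model name in the "Staging" stage, which is a convention, not something this tutorial has set up.

"""
Promote the shadow candidate in the MLflow Model Registry (sketch)
"""
from mlflow.tracking import MlflowClient

def promote_shadow_to_production(model_name="fraud_detection"):
    """Move the version currently in 'Staging' (the shadow candidate)
    to 'Production', archiving the old production version."""
    client = MlflowClient()
    versions = client.search_model_versions(f"name='{model_name}'")

    shadow = next((v for v in versions if v.current_stage == "Staging"), None)
    if shadow is None:
        print("No shadow candidate registered in Staging")
        return

    client.transition_model_version_stage(
        name=model_name,
        version=shadow.version,
        stage="Production",
        archive_existing_versions=True   # archives the old production version
    )
    print(f"Promoted shadow candidate v{shadow.version} to Production")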

ā†©ļø Safe Rollback Strategies

Model Versioning for Rollback

"""
Model version management with rollback
"""
import mlflow
from mlflow.tracking import MlflowClient

class ModelVersionManager:
    def __init__(self, model_name):
        self.model_name = model_name
        self.client = MlflowClient()
    
    def promote_to_production(self, version):
        """Promote model version to production"""
        
        # Archive current production
        current = self.get_production_version()
        if current:
            self.client.transition_model_version_stage(
                name=self.model_name,
                version=current,
                stage="Archived"
            )
        
        # Promote new version
        self.client.transition_model_version_stage(
            name=self.model_name,
            version=version,
            stage="Production"
        )
        
        print(f"āœ… Promoted v{version} to production")
        
        # Log deployment
        self.log_deployment(version)
    
    def rollback_to_previous(self):
        """Rollback to previous production version"""
        
        # Get all versions for this model, keep the archived ones, newest first
        versions = self.client.search_model_versions(
            f"name='{self.model_name}'"
        )
        
        archived_sorted = sorted(
            [v for v in versions if v.current_stage == "Archived"],
            key=lambda x: x.creation_timestamp,
            reverse=True
        )
        
        if not archived_sorted:
            print("āŒ No previous version to rollback to")
            return False
        
        previous = archived_sorted[0]
        
        # Rollback
        self.promote_to_production(previous.version)
        
        print(f"ā†©ļø Rolled back to v{previous.version}")
        return True
    
    def get_production_version(self):
        """Get current production version"""
        versions = self.client.search_model_versions(
            f"name='{self.model_name}'"
        )
        
        for v in versions:
            if v.current_stage == "Production":
                return v.version
        
        return None
    
    def log_deployment(self, version):
        """Log deployment event"""
        import requests
        
        # Log to monitoring system
        requests.post('http://monitoring/api/deployments', json={
            'model': self.model_name,
            'version': version,
            'timestamp': datetime.now().isoformat(),
            'deployed_by': 'mlops-pipeline'
        })

# Usage
manager = ModelVersionManager("fraud_detection")

# Deploy new version
manager.promote_to_production(version=15)

# If issues are detected (e.g., by the health checks in the next section), roll back
if detect_issues():  # detect_issues is a placeholder for your own health check
    manager.rollback_to_previous()

Automated Rollback on Failure

"""
Automatic rollback based on health checks
"""
import time
from datetime import datetime, timedelta

class AutoRollback:
    def __init__(
        self,
        error_rate_threshold=0.1,
        latency_threshold=1.0,
        check_duration_minutes=30
    ):
        self.error_rate_threshold = error_rate_threshold
        self.latency_threshold = latency_threshold
        self.check_duration = timedelta(minutes=check_duration_minutes)
        self.deployment_time = None
    
    def monitor_deployment(self, version):
        """Monitor new deployment and rollback if unhealthy"""
        
        self.deployment_time = datetime.now()
        print(f"šŸ” Monitoring deployment of v{version}...")
        
        while datetime.now() - self.deployment_time < self.check_duration:
            # Check health metrics
            metrics = self.get_current_metrics()
            
            # Check error rate
            if metrics['error_rate'] > self.error_rate_threshold:
                print(f"āŒ Error rate {metrics['error_rate']:.2%} exceeds threshold")
                self.trigger_rollback(version, "high_error_rate")
                return False
            
            # Check latency
            if metrics['p95_latency'] > self.latency_threshold:
                print(f"āŒ P95 latency {metrics['p95_latency']:.2f}s exceeds threshold")
                self.trigger_rollback(version, "high_latency")
                return False
            
            # Check prediction quality
            if metrics.get('accuracy') and metrics['accuracy'] < 0.8:
                print(f"āŒ Accuracy {metrics['accuracy']:.2%} too low")
                self.trigger_rollback(version, "low_accuracy")
                return False
            
            time.sleep(60)  # Check every minute
        
        print("āœ… Deployment healthy - monitoring complete")
        return True
    
    def get_current_metrics(self):
        """Fetch current production metrics from Prometheus"""
        # Query Prometheus (assumes both queries return at least one series)
        import requests
        
        response = requests.get('http://prometheus:9090/api/v1/query', params={
            'query': 'rate(ml_prediction_requests_total{status="error"}[5m])'
        })
        error_rate = float(response.json()['data']['result'][0]['value'][1])
        
        response = requests.get('http://prometheus:9090/api/v1/query', params={
            'query': 'histogram_quantile(0.95, rate(ml_prediction_latency_seconds_bucket[5m]))'
        })
        p95_latency = float(response.json()['data']['result'][0]['value'][1])
        
        return {
            'error_rate': error_rate,
            'p95_latency': p95_latency
        }
    
    def trigger_rollback(self, version, reason):
        """Trigger automatic rollback"""
        print(f"ā†©ļø Triggering rollback from v{version}: {reason}")
        
        manager = ModelVersionManager("fraud_detection")
        manager.rollback_to_previous()
        
        # Alert the team (a send_alert sketch follows this section)
        send_alert(
            subject=f"🚨 Auto-rollback triggered: {reason}",
            message=f"Version {version} rolled back due to {reason}"
        )

# Use in deployment pipeline
def deploy_with_monitoring(new_version):
    """Deploy new version with automatic rollback"""
    
    manager = ModelVersionManager("fraud_detection")
    manager.promote_to_production(new_version)
    
    # Monitor for 30 minutes
    rollback = AutoRollback(check_duration_minutes=30)
    success = rollback.monitor_deployment(new_version)
    
    if not success:
        print("Deployment failed - rolled back")
    else:
        print("Deployment successful")
    
    return success
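
The send_alert helper used above is assumed; a minimal sketch that posts to a Slack incoming webhook (the webhook URL is a placeholder, and email or an on-call pager would work just as well):

"""
Minimal alerting helper (sketch) - posts to a Slack incoming webhook
"""
import requests

SLACK_WEBHOOK_URL = "https://hooks.slack.com/services/XXX/YYY/ZZZ"  # placeholder

def send_alert(subject, message):
    """Send a deployment/rollback alert to the team channel"""
    response = requests.post(
        SLACK_WEBHOOK_URL,
        json={"text": f"*{subject}*\n{message}"},
        timeout=5
    )
    response.raise_for_status()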

šŸŽÆ Summary

You've mastered model retraining and continuous learning:

  • šŸŽÆ Trigger Conditions: Automate retraining based on performance, drift, and data availability
  • šŸ“š Batch Learning: Periodic retraining on accumulated data
  • ⚔ Online Learning: Continuous incremental updates with new samples
  • šŸ•¶ļø Shadow Mode: Test new models alongside production safely
  • ā†©ļø Safe Rollback: Version management and automatic rollback on failure
  • šŸ”„ Automation: End-to-end automated retraining pipelines

Key Takeaways

  1. Define clear trigger conditions for retraining (performance, drift, data)
  2. Choose batch vs online learning based on your use case
  3. Always test new models in shadow mode before promotion
  4. Implement automated rollback for failed deployments
  5. Monitor new deployments closely during initial period
  6. Maintain model version history for safe rollbacks
  7. Automate the entire retraining pipeline with Airflow

šŸš€ Next Steps:

Your models can now improve continuously. Next, you'll learn production best practices: feature stores, model governance, compliance, and cost optimization for enterprise ML systems.

Test Your Knowledge

Q1: When should you trigger model retraining?

Every hour automatically
When performance degrades, data drift detected, or sufficient new labeled data accumulated
Only when users complain
Never, models don't need retraining

Q2: What is online learning?

Training models on the internet
Training while users are online
Incremental model updates with new samples continuously, without retraining from scratch
Cloud-based training

Q3: What is shadow mode deployment?

Running new model alongside production model to compare predictions without affecting users
Deploying at night
Using dark theme
Hiding models from users

Q4: Why implement automatic rollback?

To save disk space
To train faster
To reduce costs
To quickly revert to previous version if new deployment shows high errors, latency, or accuracy issues

Q5: Batch learning vs online learning - which is more common?

Online learning for all production systems
Batch learning is more common, online learning for fast-changing environments
Both equally common
Neither, models never update