MLflow & Machine Learning on Databricks


What is MLflow on Databricks?

Why MLflow Matters

The Problem: Data science teams struggle with experiment reproducibility, model versioning, and the gap between training a model in a notebook and deploying it to production.

The Solution: MLflow is an open-source platform for the complete ML lifecycle: experiment tracking, model packaging, registry, and serving. Databricks provides a fully managed, integrated MLflow experience.

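Real Impact: Every run records its parameters, metrics, and artifacts, so reproducing an experiment means looking up a run rather than reconstructing the setup by hand. The path from notebook to production endpoint also stays on one platform instead of crossing tool boundaries.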

Real-World Analogy

Think of MLflow like a professional kitchen management system:

  • Experiment Tracking = Recipe notebook where you log every ingredient, quantity, and outcome
  • Model Registry = The approved recipe book -- only tested recipes make it in
  • Model Serving = The actual restaurant kitchen serving dishes to customers
  • Feature Store = Pre-prepped ingredients ready to use in any recipe
MLflow integrates natively with Databricks, providing a unified platform for experiment tracking, model registry, and deployment
MLflow Architecture on Databricks
Flow: Experiment Tracking --register--> Model Registry --deploy--> Model Serving, with AutoML and the Feature Store alongside.
  • Experiment Tracking: parameters and hyperparameters, metrics (accuracy, loss, F1), artifacts (model files, plots)
  • Model Registry: version management, stage transitions, approval workflows
  • Model Serving: REST API endpoints, auto-scaling, A/B testing
  • AutoML: automated model selection, hyperparameter tuning, feature importance
  • Feature Store: centralized features, point-in-time lookups, online and offline serving
  • Registry lifecycle: None --> Staging --> Production --> Archived (stages of the workspace registry; models in Unity Catalog use aliases such as @staging and @production instead)
Key Takeaway: MLflow on Databricks provides a unified ML lifecycle: track experiments, register models, and deploy to serving endpoints -- all within the same platform, with automatic integration into Unity Catalog for governance.

Experiment Tracking

MLflow Tracking lets you log parameters, metrics, and artifacts for every training run. On Databricks, each notebook has a default experiment, and Databricks Autologging automatically records training runs from supported libraries such as scikit-learn into it.

The MLflow Tracking UI displays all experiment runs with their parameters, metrics, and logged artifacts for easy comparison
The Traces tab in MLflow provides detailed execution traces for monitoring model inference calls, latency, and token usage
PySpark/Python - MLflow Experiment Tracking
import mlflow
import mlflow.sklearn
from mlflow.models import infer_signature
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Set the experiment path (replace <your-email> with your workspace user folder)
mlflow.set_experiment("/Users/<your-email>/churn-prediction")

# Load data from the Silver layer (`spark` is predefined in Databricks notebooks)
df = spark.table("silver.customer_features").toPandas()
X = df.drop("churned", axis=1)
y = df["churned"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Enable auto-logging: estimator params, training metrics, and a model artifact
mlflow.sklearn.autolog()

with mlflow.start_run(run_name="rf-baseline") as run:
    # Log custom parameters that autolog cannot infer
    mlflow.log_param("dataset_version", "2024-03-15")
    mlflow.log_param("feature_count", X.shape[1])

    # Train model
    model = RandomForestClassifier(
        n_estimators=100, max_depth=10, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate on the held-out set and log test metrics
    predictions = model.predict(X_test)
    accuracy = accuracy_score(y_test, predictions)
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1_score(y_test, predictions))
    mlflow.log_metric("precision", precision_score(y_test, predictions))

    # Log the model with a signature (required to register it in Unity Catalog)
    signature = infer_signature(X_test, predictions)
    mlflow.sklearn.log_model(model, "random_forest_model", signature=signature)

    print(f"Run ID: {run.info.run_id}")
    print(f"Accuracy: {accuracy:.4f}")
Output
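Run ID: <run_id>
Accuracy: <test-set accuracy>
(Illustrative: the run ID and metric values depend on your data. The run records:)
  Parameters: n_estimators=100, max_depth=10, dataset_version=2024-03-15,
              feature_count=<number of feature columns>, plus other estimator params from autolog
  Metrics:    accuracy, f1_score, precision, plus training metrics from autolog
  Artifacts:  random_forest_model/, plus model/ from autolog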
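To make the tracking semantics concrete, here is a toy, in-memory run record. It is a hypothetical sketch (the `ToyRun` class is made up, not MLflow's implementation) that mirrors two behaviors of the tracking API: a parameter keeps one value per run, while a metric can be logged repeatedly at different steps and the latest value is the one at the highest step.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRun:
    """Hypothetical in-memory run record; mirrors MLflow semantics, not its code."""

    run_name: str
    params: dict = field(default_factory=dict)
    metrics: dict = field(default_factory=dict)  # name -> list of (step, value)
    artifacts: list = field(default_factory=list)

    def log_param(self, key, value):
        # Params are stored as strings and are write-once: re-logging the same
        # value is harmless, but changing it is rejected.
        value = str(value)
        if key in self.params and self.params[key] != value:
            raise ValueError(f"param {key!r} already set to {self.params[key]!r}")
        self.params[key] = value

    def log_metric(self, key, value, step=0):
        # Metrics are histories: each call appends a (step, value) point.
        self.metrics.setdefault(key, []).append((step, float(value)))

    def log_artifact(self, path):
        # Artifacts are files (models, plots) attached to the run.
        self.artifacts.append(path)

    def latest(self, key):
        # The latest value is the one recorded at the highest step.
        return max(self.metrics[key], key=lambda point: point[0])[1]


toy_run = ToyRun("rf-baseline")
toy_run.log_param("max_depth", 10)
toy_run.log_param("max_depth", 10)  # same value again: accepted
for step, loss in enumerate([0.9, 0.6, 0.4]):
    toy_run.log_metric("loss", loss, step=step)
toy_run.log_artifact("feature_importance.png")
print(toy_run.latest("loss"))  # 0.4
```

Calling `toy_run.log_param("max_depth", 12)` at this point raises, which is the same reason a real run cannot silently change its recorded configuration.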

Model Registry

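The Model Registry provides a centralized store for managing model versions, promotion, and approval workflows. In Unity Catalog, models are governed just like tables: they use three-level names (catalog.schema.model) and the same access grants, and instead of stages, each version can carry aliases such as @staging or @production.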

The Model Registry version page shows model details, the source experiment run, stage transitions, and deployment status
Python - Model Registry Operations
import mlflow
from mlflow import MlflowClient

# Use the Unity Catalog model registry (three-level model names)
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

model_name = "prod_catalog.ml_models.churn_predictor"  # UC model path

# Register the model logged by the tracking run above
# (`run` comes from the experiment-tracking example)
model_uri = f"runs:/{run.info.run_id}/random_forest_model"
model_details = mlflow.register_model(model_uri, model_name)

# Add model description
client.update_registered_model(
    name=model_name,
    description="Predicts customer churn based on usage features"
)

# Point the "staging" alias at the new version (aliases replace stages in UC)
client.set_registered_model_alias(
    name=model_name,
    alias="staging",
    version=model_details.version
)

# After validation, promote by pointing the "production" alias at it
client.set_registered_model_alias(
    name=model_name,
    alias="production",
    version=model_details.version
)

# Load a model by alias for inference
# (`new_data` is a pandas DataFrame with the training feature columns)
model = mlflow.pyfunc.load_model(f"models:/{model_name}@production")
predictions = model.predict(new_data)
Output
Registered model: prod_catalog.ml_models.churn_predictor
  Version <n> (source: runs:/<run_id>/random_forest_model)
    Aliases: @staging, @production
(Illustrative: <n> is the next version number for this model.)
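Aliases are what make promotion and rollback cheap: a registered version never changes, and an alias is just a named pointer to one version. The toy registry below is a hypothetical sketch of that idea (the class and the run IDs are made up, not MLflow's API); in MLflow the corresponding calls are `register_model` and `set_registered_model_alias`.

```python
class ToyRegistry:
    """Hypothetical in-memory registry: immutable versions plus movable aliases."""

    def __init__(self):
        self.versions = {}  # model name -> list of source URIs (version = index + 1)
        self.aliases = {}   # (model name, alias) -> version number

    def register(self, name, source_uri):
        # Every registration creates the next version; old versions are untouched.
        self.versions.setdefault(name, []).append(source_uri)
        return len(self.versions[name])

    def set_alias(self, name, alias, version):
        if not 1 <= version <= len(self.versions.get(name, [])):
            raise KeyError(f"{name} has no version {version}")
        # Reassigning an alias simply moves the pointer.
        self.aliases[(name, alias)] = version

    def resolve(self, uri):
        # Resolve "models:/<name>@<alias>" to (version, source URI).
        name, alias = uri[len("models:/"):].split("@")
        version = self.aliases[(name, alias)]
        return version, self.versions[name][version - 1]


registry = ToyRegistry()
name = "prod_catalog.ml_models.churn_predictor"
v1 = registry.register(name, "runs:/run-a/random_forest_model")  # hypothetical run IDs
v2 = registry.register(name, "runs:/run-b/random_forest_model")
registry.set_alias(name, "production", v1)
registry.set_alias(name, "production", v2)  # promote v2
registry.set_alias(name, "production", v1)  # roll back; v2 is still registered
print(registry.resolve(f"models:/{name}@production"))
# (1, 'runs:/run-a/random_forest_model')
```

Rollback is a single pointer move, and consumers that load `models:/...@production` pick up the change without code edits.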

Common Mistake

Wrong: Promoting models directly from experiment to Production stage without validation

Why it fails: Untested models in production cause prediction errors, data drift goes undetected, and rollback becomes difficult.

Instead: Validate in a staging step first (a Staging stage in the workspace registry, or a @staging alias in Unity Catalog), including A/B tests. Automate promotion with CI/CD checks on accuracy thresholds, data drift metrics, and latency benchmarks.
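A CI/CD job can encode these checks as a plain function that runs before the production alias is moved (for example with `set_registered_model_alias`). The sketch below assumes the job has already computed accuracy, p95 latency, and a drift score for the candidate; the function, metric names, and thresholds are all hypothetical.

```python
def should_promote(candidate, champion=None, min_accuracy=0.85,
                   max_p95_latency_ms=100.0, max_drift=0.2):
    """Hypothetical CI/CD gate: return (approved, reasons for rejection)."""
    reasons = []
    if candidate["accuracy"] < min_accuracy:
        reasons.append(f"accuracy {candidate['accuracy']:.3f} below {min_accuracy}")
    if candidate["p95_latency_ms"] > max_p95_latency_ms:
        reasons.append(f"p95 latency {candidate['p95_latency_ms']} ms above {max_p95_latency_ms} ms")
    if candidate["drift_score"] > max_drift:
        reasons.append(f"drift score {candidate['drift_score']} above {max_drift}")
    # Never replace the current production model with a less accurate one
    if champion is not None and candidate["accuracy"] < champion["accuracy"]:
        reasons.append("accuracy lower than the current production model")
    return not reasons, reasons


champion = {"accuracy": 0.90, "p95_latency_ms": 40.0, "drift_score": 0.05}
candidate = {"accuracy": 0.92, "p95_latency_ms": 55.0, "drift_score": 0.08}
approved, reasons = should_promote(candidate, champion)
print(approved, reasons)  # True []

slow = dict(candidate, p95_latency_ms=150.0)
print(should_promote(slow, champion))  # (False, ['p95 latency 150.0 ms above 100.0 ms'])
```

Returning the reasons, not just a boolean, gives the pipeline something useful to post when it blocks a promotion.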

AutoML

Databricks AutoML automatically trains and tunes models, generating a leaderboard of results. It creates reproducible notebooks for each trial so you can inspect and customize the best approach.

Python - Databricks AutoML
import mlflow
from databricks import automl

# Run AutoML classification; timeout_minutes bounds the search
# (max_trials is deprecated in recent Databricks Runtime ML versions)
summary = automl.classify(
    dataset=spark.table("silver.customer_features"),
    target_col="churned",
    primary_metric="f1",
    timeout_minutes=30
)

# View results
print(f"Best trial: {summary.best_trial}")
print(f"Best F1: {summary.best_trial.metrics['test_f1_score']:.4f}")

# The best model is automatically logged to MLflow
best_model = mlflow.pyfunc.load_model(
    f"runs:/{summary.best_trial.mlflow_run_id}/model"
)

# AutoML also supports regression and forecasting
reg_summary = automl.regress(
    dataset=spark.table("silver.house_features"),
    target_col="price",
    primary_metric="rmse",
    timeout_minutes=20
)
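The leaderboard is ordered by the primary metric, and the sort direction depends on the metric: F1 should be maximized, RMSE minimized. A minimal sketch of that ranking step follows; the trial records and scores are made up, and this is not the AutoML library's code.

```python
# Whether a larger value is better, for a few common primary metrics
HIGHER_IS_BETTER = {"f1": True, "accuracy": True, "roc_auc": True,
                    "rmse": False, "mae": False, "log_loss": False}


def leaderboard(trials, primary_metric):
    """Return trials sorted best-first by their primary metric."""
    higher = HIGHER_IS_BETTER[primary_metric]
    return sorted(trials, key=lambda t: t["metrics"][primary_metric], reverse=higher)


# Made-up trial results, for illustration only
clf_trials = [
    {"trial": "logistic_regression", "metrics": {"f1": 0.74}},
    {"trial": "lightgbm", "metrics": {"f1": 0.81}},
    {"trial": "random_forest", "metrics": {"f1": 0.79}},
]
reg_trials = [
    {"trial": "linear_regression", "metrics": {"rmse": 35.0}},
    {"trial": "xgboost", "metrics": {"rmse": 21.4}},
    {"trial": "random_forest", "metrics": {"rmse": 24.9}},
]
print([t["trial"] for t in leaderboard(clf_trials, "f1")])
# ['lightgbm', 'random_forest', 'logistic_regression']
print(leaderboard(reg_trials, "rmse")[0]["trial"])  # xgboost
```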

Feature Store

The Feature Store provides a centralized repository for ML features, ensuring consistency between training and serving. Features are stored as Delta tables governed by Unity Catalog.

Python - Feature Store Operations
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup

fe = FeatureEngineeringClient()

# Create a feature table
customer_features = spark.sql("""
    SELECT
        user_id,
        COUNT(*) as total_orders,
        AVG(amount) as avg_order_value,
        MAX(order_date) as last_order_date,
        DATEDIFF(current_date(), MAX(order_date)) as days_since_last_order
    FROM silver.orders
    GROUP BY user_id
""")

fe.create_table(
    name="prod_catalog.ml_features.customer_features",
    primary_keys=["user_id"],
    df=customer_features,
    description="Customer behavior features for churn prediction"
)

# Train a model with feature lookups
training_set = fe.create_training_set(
    df=spark.table("silver.churn_labels"),  # user_id + label
    feature_lookups=[
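        FeatureLookup(
            table_name="prod_catalog.ml_features.customer_features",
            lookup_key="user_id",
            # Only the numeric features the model and the serving payload use
            feature_names=["total_orders", "avg_order_value",
                           "days_since_last_order"]
        )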
    ],
    label="churned"
)

# Convert to pandas for sklearn training
training_df = training_set.load_df().toPandas()
Deep Dive: Feature Store Online vs Offline

The Databricks Feature Store can keep two copies of a feature table: an offline store (Delta tables used for batch training) and, once the table is published, an online store (a low-latency key-value store for real-time serving). During training, features are read from Delta. For models logged with the feature engineering client (fe.log_model), Model Serving fetches features from the online store by primary key at inference time. This ensures training-serving consistency: the online store is synced from the same Delta table that training reads, which eliminates the common "training-serving skew" problem where features are computed differently in production.
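The consistency argument can be sketched in plain Python: one feature function feeds the batch table, and the online store is a keyed copy of that table rather than a second implementation. Everything here is hypothetical: the function names and the in-memory dicts stand in for Delta and the online store. The feature definitions mirror the SQL in the feature table example.

```python
from datetime import date


def customer_features(orders, today):
    """Compute one user's features; orders is a list of (order_date, amount)."""
    amounts = [amount for _, amount in orders]
    last_order = max(order_date for order_date, _ in orders)
    return {
        "total_orders": len(orders),
        "avg_order_value": sum(amounts) / len(amounts),
        "days_since_last_order": (today - last_order).days,
    }


def build_offline_table(orders_by_user, today):
    # Batch path: materialize one feature row per user_id (stands in for Delta).
    return {uid: customer_features(rows, today) for uid, rows in orders_by_user.items()}


def publish_online(offline_table):
    # Sync path: the online store copies the offline rows; it does not recompute.
    return dict(offline_table)


def online_lookup(online_store, user_id):
    # Serving path: fetch precomputed features by primary key.
    return online_store[user_id]


orders = {
    "u1": [(date(2024, 3, 1), 50.0), (date(2024, 3, 10), 115.0)],
    "u2": [(date(2024, 1, 5), 20.0)],
}
offline = build_offline_table(orders, today=date(2024, 3, 15))
online = publish_online(offline)
print(online_lookup(online, "u1"))
# {'total_orders': 2, 'avg_order_value': 82.5, 'days_since_last_order': 5}
```

Because serving only reads rows that training also read, there is no second code path where a feature could drift.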

Model Serving

Databricks Model Serving deploys MLflow models as auto-scaling REST API endpoints. It handles infrastructure, scaling, and monitoring automatically.
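Endpoints serving MLflow models accept a JSON body in one of several input formats. `dataframe_records` (used in the query example in this section) and `dataframe_split` are the two pandas-oriented ones. This sketch builds either body from a list of row dicts with only the standard library; the helper names are hypothetical.

```python
import json


def to_dataframe_records(rows):
    # One JSON object per row: {"dataframe_records": [{col: value, ...}, ...]}
    return {"dataframe_records": rows}


def to_dataframe_split(rows):
    # Column names once, then row values in that order:
    # {"dataframe_split": {"columns": [...], "data": [[...], ...]}}
    columns = list(rows[0].keys())
    return {"dataframe_split": {
        "columns": columns,
        "data": [[row[col] for col in columns] for row in rows],
    }}


rows = [
    {"total_orders": 15, "avg_order_value": 82.50, "days_since_last_order": 45},
    {"total_orders": 2, "avg_order_value": 19.99, "days_since_last_order": 120},
]
body = to_dataframe_split(rows)
print(json.dumps(body))
```

The split form repeats column names only once, which keeps large batch requests smaller than the records form.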

Python - Deploy and Query a Model Endpoint
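import os
import requests

# Model Serving is configured via the UI or REST API.
# After deployment, query the endpoint (replace the workspace URL with your own):

endpoint_url = (
    "https://my-workspace.databricks.com"
    "/serving-endpoints/churn-predictor/invocations"
)

# Read a personal access token from the environment instead of hard-coding it
token = os.environ["DATABRICKS_TOKEN"]
headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}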

# Single prediction
payload = {
    "dataframe_records": [{
        "total_orders": 15,
        "avg_order_value": 82.50,
        "days_since_last_order": 45
    }]
}

response = requests.post(endpoint_url, headers=headers, json=payload)
response.raise_for_status()  # raise on 4xx/5xx instead of printing an error body
print(response.json())
# {"predictions": [0]}  (0 = not churning, 1 = churning)

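# Batch prediction using the registered model directly
import mlflow

mlflow.set_registry_uri("databricks-uc")
model = mlflow.pyfunc.load_model(
    "models:/prod_catalog.ml_models.churn_predictor@production"
)
# Drop the label column so the input matches the training features;
# for large tables, mlflow.pyfunc.spark_udf scores in parallel on the cluster
batch_predictions = model.predict(
    spark.table("silver.customer_features").toPandas().drop(columns=["churned"])
)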
Key Takeaway: Model Serving endpoints support traffic splitting for A/B tests, auto-scaling based on request volume, and GPU inference for large models. Use the serving endpoints REST API (/api/2.0/serving-endpoints) for programmatic deployment in CI/CD pipelines.
Output
$ curl -X POST https://workspace.cloud.databricks.com/serving-endpoints/fraud-model/invocations \
    -H "Authorization: Bearer $TOKEN" \
    -H "Content-Type: application/json" \
    -d '{"instances": [{"amount": 499.99, "merchant": "electronics", "hour": 2}]}'

{"predictions": [{"fraud_probability": 0.87, "label": "FRAUD"}]}
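Traffic splitting assigns each request to one served model in proportion to the percentages in the endpoint configuration. The sketch below shows one way to do that deterministically by hashing a request key. This hashing scheme is an assumption for illustration, not a description of how Model Serving routes internally, and the served-model names are hypothetical.

```python
import hashlib
from collections import Counter


def route(request_key, traffic):
    """Assign a request to a served model; traffic maps model name -> percent."""
    if sum(traffic.values()) != 100:
        raise ValueError("traffic percentages must sum to 100")
    # Hash a stable key (e.g. a user ID) into one of 100 buckets, so the
    # same user always lands on the same model during the experiment.
    bucket = int(hashlib.sha256(request_key.encode("utf-8")).hexdigest(), 16) % 100
    cumulative = 0
    for name, percent in sorted(traffic.items()):
        cumulative += percent
        if bucket < cumulative:
            return name


traffic = {"fraud-gbt": 90, "fraud-dnn": 10}  # hypothetical served-model names
counts = Counter(route(f"user-{i}", traffic) for i in range(10_000))
print(counts)  # roughly 9,000 vs 1,000 requests
```

Keeping each user on one variant avoids mixing both models' behavior in a single user's experience, which makes per-variant metrics easier to compare.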

Practice Problems

Problem 1: Design an ML Pipeline

Medium

Design an end-to-end ML pipeline for predicting delivery delays. Data comes from silver.orders and silver.logistics tables. The model needs to be retrained weekly and served via REST API.

Problem 2: A/B Testing Models

Hard

You have two candidate models for fraud detection: a gradient boosted tree (fast, 92% recall) and a deep learning model (slower, 97% recall). Design an A/B testing strategy using Model Serving to determine which performs better in production.

Problem 3: Feature Store Design

Medium

Design a Feature Store schema for a recommendation engine. Users browse products, add to cart, and make purchases. The model needs real-time features (current session) and historical features (past 30 days of behavior).