MLflow

MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It tackles four key problems: experiment tracking, reproducible runs, model packaging, and model versioning. If you've ever trained a model, gotten great results, and then couldn't remember which hyperparameters you used — MLflow solves that.

Why Experiment Tracking Matters

In a typical ML project, you run dozens or hundreds of experiments. Each experiment has:

  • Parameters: Learning rate, batch size, number of layers, etc.
  • Metrics: Accuracy, F1 score, loss, etc.
  • Artifacts: Trained model files, plots, data snapshots

Without a tracking system, you end up with folders like model_v3_final_really_final/ and no idea which run produced the best results. MLflow brings discipline to this chaos.

MLflow Tracking

MLflow Tracking lets you log parameters, metrics, and artifacts from each training run. Runs are organized into experiments, and you can compare them side-by-side in the MLflow UI.

Basic Experiment Logging

```python
import mlflow
import mlflow.sklearn
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay
from sklearn.datasets import load_iris

# Set the experiment name (creates it if it doesn't exist)
mlflow.set_experiment("iris-classification")

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Start a run — everything logged here is grouped together
with mlflow.start_run(run_name="rf-default"):
    # Log parameters
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("max_depth", None)
    mlflow.log_param("random_state", 42)

    # Train the model
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate and log metrics
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="weighted")

    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1)

    # Log the model as an artifact
    mlflow.sklearn.log_model(model, "model")

    # Log a custom artifact (e.g., a plot)
    fig, ax = plt.subplots(figsize=(6, 6))
    ConfusionMatrixDisplay.from_estimator(model, X_test, y_test, ax=ax)
    ax.set_title("Confusion Matrix")
    fig.savefig("confusion_matrix.png")
    mlflow.log_artifact("confusion_matrix.png")

    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
```

Hyperparameter Sweep with Tracking

```python
import mlflow

# Reuses X_train/X_test, RandomForestClassifier, and the metric functions
# imported in the previous example
mlflow.set_experiment("iris-hp-sweep")

for n_estimators in [50, 100, 200]:
    for max_depth in [3, 5, None]:
        with mlflow.start_run(run_name=f"rf-ne{n_estimators}-md{max_depth}"):
            model = RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=max_depth,
                random_state=42,
            )
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)

            accuracy = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average="weighted")

            mlflow.log_params({
                "n_estimators": n_estimators,
                "max_depth": max_depth,
            })
            mlflow.log_metrics({"accuracy": accuracy, "f1_score": f1})
            mlflow.sklearn.log_model(model, "model")
```

Running the MLflow UI

```bash
# Start the tracking server locally
mlflow ui --port 5000

# Or with a backend store for persistence
mlflow server \
  --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root ./artifacts \
  --host 0.0.0.0 \
  --port 5000
```

Then open http://localhost:5000 to compare runs visually.

MLflow Projects

MLflow Projects let you package your training code so that anyone can reproduce your run with a single command. A project is defined by an MLproject file:

```yaml
# MLproject
name: iris-classifier

conda_env: conda.yaml

entry_points:
  main:
    parameters:
      n_estimators: {type: int, default: 100}
      max_depth: {type: int, default: 5}
    command: "python train.py --n_estimators {n_estimators} --max_depth {max_depth}"
```

```yaml
# conda.yaml
name: iris-env
channels:
  - defaults
dependencies:
  - python=3.11
  - scikit-learn=1.3
  - mlflow=2.9
  - pip:
      - matplotlib
```

Run the project:

```bash
mlflow run . -P n_estimators=200 -P max_depth=5
```

MLflow Models

MLflow Models provide a standard format for packaging ML models so they can be used with diverse downstream tools — from batch scoring to real-time serving.

Model Flavors

```python
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

# Reuses X_train/y_train from the tracking example above
with mlflow.start_run():
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Log as the sklearn flavor (a generic pyfunc flavor is added automatically)
    mlflow.sklearn.log_model(
        model,
        "model",
        # Add an input example for schema inference
        input_example=X_train[:3],
        # Add a signature so inputs can be validated downstream
        signature=mlflow.models.infer_signature(
            X_train, model.predict(X_train)
        ),
    )
```

Serving a Model

```bash
# Serve a logged model locally
mlflow models serve -m "runs:/<run_id>/model" -p 5001

# Query the serving endpoint
curl -X POST http://localhost:5001/invocations \
  -H "Content-Type: application/json" \
  -d '{"dataframe_split": {"columns": ["sepal_length", "sepal_width", "petal_length", "petal_width"], "data": [[5.1, 3.5, 1.4, 0.2]]}}'
```
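The same request can be issued from Python. This sketch builds the dataframe_split payload from a DataFrame; actually sending it assumes the server above is running, so that call is left commented out:

```python
import json

import pandas as pd

df = pd.DataFrame(
    [[5.1, 3.5, 1.4, 0.2]],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)

# "dataframe_split" is one of the payload shapes the /invocations endpoint accepts
payload = json.dumps(
    {"dataframe_split": {"columns": list(df.columns), "data": df.values.tolist()}}
)

# requests.post("http://localhost:5001/invocations", data=payload,
#               headers={"Content-Type": "application/json"})
```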

Model Registry

The Model Registry provides a centralized store for managing the full lifecycle of a model — from development to staging to production.

Registering and Transitioning Models

```python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register a model from a run (replace <run_id> with a real run ID)
result = mlflow.register_model(
    "runs:/<run_id>/model",
    "iris-classifier",
)

# Transition to staging
client.transition_model_version_stage(
    name="iris-classifier",
    version=1,
    stage="Staging",
)

# After validation, promote to production
client.transition_model_version_stage(
    name="iris-classifier",
    version=1,
    stage="Production",
)

# List all versions of a model
versions = client.search_model_versions("name='iris-classifier'")
for v in versions:
    print(f"Version {v.version}: {v.current_stage} (run_id={v.run_id})")
```

Loading a Production Model

```python
import mlflow.sklearn

# Load the current production model
model = mlflow.sklearn.load_model("models:/iris-classifier/Production")
prediction = model.predict([[5.1, 3.5, 1.4, 0.2]])
print(prediction)
```

Querying Runs Programmatically

```python
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Search for the best runs by accuracy ("1" is the experiment ID shown in the UI)
runs = client.search_runs(
    experiment_ids=["1"],
    filter_string="metrics.accuracy > 0.95",
    order_by=["metrics.accuracy DESC"],
    max_results=5,
)

for run in runs:
    print(f"Run: {run.info.run_id}")
    print(f"  Accuracy: {run.data.metrics['accuracy']:.4f}")
    print(f"  Params: {run.data.params}")
```

MLflow on GCP

You can use GCS as the artifact store and Cloud SQL as the backend store for a production MLflow setup:

```bash
mlflow server \
  --backend-store-uri postgresql://user:pass@<cloud-sql-host>:5432/mlflow \
  --default-artifact-root gs://your-bucket/mlflow/artifacts \
  --host 0.0.0.0
```
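Clients then need two pieces of configuration: the server address and credentials for the GCS bucket. A sketch, with the host placeholder left for you to fill in:

```bash
# Point local clients at the remote tracking server (hypothetical address)
export MLFLOW_TRACKING_URI="http://<server-host>:5000"

# Standard Google credential variable; both server and clients need
# read/write access to the artifact bucket
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service-account.json"
```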