# MLflow
MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It tackles four key problems: experiment tracking, reproducible runs, model packaging, and model versioning. If you've ever trained a model, gotten great results, and then couldn't remember which hyperparameters you used — MLflow solves that.
## Why Experiment Tracking Matters
In a typical ML project, you run dozens or hundreds of experiments. Each experiment has:
- Parameters: Learning rate, batch size, number of layers, etc.
- Metrics: Accuracy, F1 score, loss, etc.
- Artifacts: Trained model files, plots, data snapshots
Without a tracking system, you end up with folders like model_v3_final_really_final/ and no idea which run produced the best results. MLflow brings discipline to this chaos.
## MLflow Tracking
MLflow Tracking lets you log parameters, metrics, and artifacts from each training run. Runs are organized into experiments, and you can compare them side-by-side in the MLflow UI.
### Basic Experiment Logging
```python
import mlflow
import mlflow.sklearn
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay
from sklearn.datasets import load_iris

# Set the experiment name (creates it if it doesn't exist)
mlflow.set_experiment("iris-classification")

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Start a run — everything logged here is grouped together
with mlflow.start_run(run_name="rf-default"):
    # Log parameters
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("max_depth", None)
    mlflow.log_param("random_state", 42)

    # Train the model
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate and log metrics
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="weighted")
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1)

    # Log the model as an artifact
    mlflow.sklearn.log_model(model, "model")

    # Log a custom artifact (e.g., a plot)
    fig, ax = plt.subplots(figsize=(6, 6))
    ConfusionMatrixDisplay.from_estimator(model, X_test, y_test, ax=ax)
    ax.set_title("Confusion Matrix")
    fig.savefig("confusion_matrix.png")
    plt.close(fig)
    mlflow.log_artifact("confusion_matrix.png")

    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
```
### Hyperparameter Sweep with Tracking
```python
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

# Reuses X_train, X_test, y_train, y_test from the split above
mlflow.set_experiment("iris-hp-sweep")

for n_estimators in [50, 100, 200]:
    for max_depth in [3, 5, None]:
        with mlflow.start_run(run_name=f"rf-ne{n_estimators}-md{max_depth}"):
            model = RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=max_depth,
                random_state=42,
            )
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average="weighted")

            mlflow.log_params({
                "n_estimators": n_estimators,
                "max_depth": max_depth,
            })
            mlflow.log_metrics({"accuracy": accuracy, "f1_score": f1})
            mlflow.sklearn.log_model(model, "model")
```
### Running the MLflow UI
```shell
# Start the tracking server locally
mlflow ui --port 5000

# Or with a backend store for persistence
mlflow server \
  --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root ./artifacts \
  --host 0.0.0.0 \
  --port 5000
```
Then open http://localhost:5000 to compare runs visually.
## MLflow Projects
MLflow Projects let you package your training code so that anyone can reproduce your run with a single command. A project is defined by an MLproject file:
```yaml
# MLproject
name: iris-classifier
conda_env: conda.yaml
entry_points:
  main:
    parameters:
      n_estimators: {type: int, default: 100}
      max_depth: {type: int, default: 5}
    command: "python train.py --n_estimators {n_estimators} --max_depth {max_depth}"
```

```yaml
# conda.yaml
name: iris-env
channels:
  - defaults
dependencies:
  - python=3.11
  - scikit-learn=1.3
  - mlflow=2.9
  - pip:
      - matplotlib
```
Run the project:
```shell
mlflow run . -P n_estimators=200 -P max_depth=5
```
## MLflow Models
MLflow Models provide a standard format for packaging ML models so they can be used with diverse downstream tools — from batch scoring to real-time serving.
### Model Flavors
```python
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

# Reuses X_train, y_train from the tracking examples above
with mlflow.start_run():
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # Log as the sklearn flavor; the generic pyfunc flavor is added
    # automatically, so the model works with multiple downstream tools
    mlflow.sklearn.log_model(
        model,
        "model",
        # Input example enables schema inference
        input_example=X_train[:3],
        # Signature validates inputs at load/serve time
        signature=mlflow.models.infer_signature(
            X_train, model.predict(X_train)
        ),
    )
```
### Serving a Model
```shell
# Serve a logged model locally
# (pass --env-manager local to reuse the current Python environment
# instead of recreating the model's recorded environment)
mlflow models serve -m "runs:/<run_id>/model" -p 5001

# Query the serving endpoint
curl -X POST http://localhost:5001/invocations \
  -H "Content-Type: application/json" \
  -d '{"dataframe_split": {"columns": ["sepal_length", "sepal_width", "petal_length", "petal_width"], "data": [[5.1, 3.5, 1.4, 0.2]]}}'
```
## Model Registry
The Model Registry provides a centralized store for managing the full lifecycle of a model — from development to staging to production.
### Registering and Transitioning Models
```python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register a model from a run
result = mlflow.register_model(
    "runs:/<run_id>/model",
    "iris-classifier",
)

# Transition to staging
client.transition_model_version_stage(
    name="iris-classifier",
    version=1,
    stage="Staging",
)

# After validation, promote to production
client.transition_model_version_stage(
    name="iris-classifier",
    version=1,
    stage="Production",
)

# List all versions of a model
versions = client.search_model_versions("name='iris-classifier'")
for v in versions:
    print(f"Version {v.version}: {v.current_stage} (run_id={v.run_id})")
```
### Loading a Production Model
```python
import mlflow.sklearn

# Load the current production model
model = mlflow.sklearn.load_model("models:/iris-classifier/Production")
prediction = model.predict([[5.1, 3.5, 1.4, 0.2]])
print(prediction)
```
## Querying Runs Programmatically
```python
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Search for the best runs by accuracy
runs = client.search_runs(
    experiment_ids=["1"],
    filter_string="metrics.accuracy > 0.95",
    order_by=["metrics.accuracy DESC"],
    max_results=5,
)

for run in runs:
    print(f"Run: {run.info.run_id}")
    print(f"  Accuracy: {run.data.metrics['accuracy']:.4f}")
    print(f"  Params: {run.data.params}")
```
## Production Deployment

For a production MLflow setup, you can use GCS as the artifact store and Cloud SQL as the backend store:
```shell
mlflow server \
  --backend-store-uri postgresql://user:pass@<cloud-sql-host>:5432/mlflow \
  --default-artifact-root gs://your-bucket/mlflow/artifacts \
  --host 0.0.0.0
```