What Are Model Registry and Feature Flags?
A Model Registry is a centralized system for storing and managing machine learning models. It stores model artifacts (weights, configs), tracks model versions and metadata, manages the model lifecycle (staging, production, archived), and serves as the single source of truth for every model in the organization.
Feature Flags (also called Feature Toggles) are a technique for turning features on and off without deploying new code. They control which groups of users see which features, allow instant rollback by simply disabling a flag, and support gradual rollouts (for example, enabling a feature for 10% of users first).
Combining a Model Registry with Feature Flags enables progressive rollouts of ML models: release a new model to 5% of traffic first, and if the metrics look good, increase to 25%, 50%, then 100%. If anything goes wrong, you can roll back immediately by switching the flag back to the previous model.
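Under the hood, a gradual rollout is usually implemented with deterministic bucketing: a stable identifier such as the user id is hashed into a bucket from 0-99 and compared against the rollout percentage stored in the flag. A minimal sketch of the idea (the function name is illustrative and not tied to any particular SDK):
import hashlib

def route_to_new_model(user_id: str, rollout_percentage: int) -> bool:
    """Return True if this user should be served the new model.

    The same user always lands in the same bucket, so raising the
    rollout percentage only adds users and never flips existing ones.
    """
    bucket = int(hashlib.md5(user_id.encode()).hexdigest(), 16) % 100
    return bucket < rollout_percentage

# Example: roll the new model out to 10% of users
# route_to_new_model("user-42", 10)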
Popular tools include MLflow for the Model Registry, LaunchDarkly or Unleash for Feature Flags, Weights and Biases for experiment tracking, and Seldon Core or BentoML for model serving.
Designing a Model Registry That Supports Feature Flags
System Architecture
# === Architecture ===
#
# ┌──────────────┐      ┌──────────────┐
# │ ML Training  │      │ Feature Flag │
# │   Pipeline   │      │   Service    │
# │  (Airflow)   │      │  (Unleash)   │
# └──────┬───────┘      └──────┬───────┘
#        │                     │
#   ┌────▼─────────────────────▼────┐
#   │        Model Registry         │
#   │  ┌─────────┐  ┌────────────┐  │
#   │  │ MLflow  │  │  Metadata  │  │
#   │  │ Models  │  │ Store (DB) │  │
#   │  └─────────┘  └────────────┘  │
#   └──────────────┬────────────────┘
#                  │
#   ┌──────────────▼────────────────┐
#   │      Model Serving Layer      │
#   │  ┌────────┐   ┌────────────┐  │
#   │  │Model A │   │  Model B   │  │
#   │  │ (prod) │   │  (canary)  │  │
#   │  └────────┘   └────────────┘  │
#   └──────────────┬────────────────┘
#                  │
#   ┌──────────────▼────────────────┐
#   │        Traffic Router         │
#   │   (Based on Feature Flags)    │
#   │   95% -> Model A (prod)       │
#   │    5% -> Model B (canary)     │
#   └───────────────────────────────┘
# === MLflow Model Registry Setup ===
# Install MLflow
# pip install mlflow[extras] boto3 psycopg2-binary
# Run the MLflow server
# mlflow server \
# --backend-store-uri postgresql://user:pass@localhost:5432/mlflow \
# --default-artifact-root s3://mlflow-artifacts/ \
# --host 0.0.0.0 --port 5000
# Docker Compose
# services:
#   mlflow:
#     image: ghcr.io/mlflow/mlflow:latest
#     ports: ["5000:5000"]
#     environment:
#       MLFLOW_BACKEND_STORE_URI: postgresql://user:pass@postgres:5432/mlflow
#       MLFLOW_DEFAULT_ARTIFACT_ROOT: s3://mlflow-artifacts/
#       AWS_ACCESS_KEY_ID:
#       AWS_SECRET_ACCESS_KEY:
#     command: >
#       mlflow server --host 0.0.0.0 --port 5000
#
#   postgres:
#     image: postgres:16
#     environment:
#       POSTGRES_DB: mlflow
#       POSTGRES_USER: user
#       POSTGRES_PASSWORD: pass
#     volumes: ["pgdata:/var/lib/postgresql/data"]
#
#   unleash:
#     image: unleashorg/unleash-server:latest
#     ports: ["4242:4242"]
#     environment:
#       DATABASE_URL: postgresql://user:pass@postgres:5432/unleash
#       INIT_ADMIN_API_TOKENS: "*:*.unleash-api-token"
#
# volumes:
#   pgdata:
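Before the registry service below can manage anything, the training pipeline has to log a model run to MLflow first. A minimal sketch using scikit-learn (the synthetic dataset and RandomForest model are placeholders; swap in your own training code):
#!/usr/bin/env python3
# train_and_log.py - log a training run and model to MLflow (illustrative sketch)
import mlflow
import mlflow.sklearn
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("recommendation_model")

X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

with mlflow.start_run() as run:
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    mlflow.log_metric("accuracy", accuracy_score(y_test, preds))
    mlflow.log_metric("f1_score", f1_score(y_test, preds))
    mlflow.sklearn.log_model(model, artifact_path="model")
    print(f"run_id: {run.info.run_id}")  # pass this run_id to register_model() below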
Building a Model Registry Service in Python
A Python service for managing models
#!/usr/bin/env python3
# model_registry.py — Model Registry with Feature Flag Integration
import mlflow
from mlflow.tracking import MlflowClient
import json
import logging
from datetime import datetime
from typing import Optional, Dict, List
from dataclasses import dataclass, asdict
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_registry")
@dataclass
class ModelVersion:
    name: str
    version: int
    stage: str
    run_id: str
    metrics: Dict[str, float]
    feature_flag: str
    rollout_percentage: int = 0
    created_at: str = ""
class ModelRegistryService:
    def __init__(self, mlflow_uri="http://localhost:5000"):
        mlflow.set_tracking_uri(mlflow_uri)
        self.client = MlflowClient(mlflow_uri)

    def register_model(self, model_name, run_id, metrics, tags=None):
        result = mlflow.register_model(
            f"runs:/{run_id}/model",
            model_name,
        )
        version = result.version
        # Set tags
        if tags:
            for key, value in tags.items():
                self.client.set_model_version_tag(model_name, version, key, value)
        # Store metrics as tags
        for metric_name, metric_value in metrics.items():
            self.client.set_model_version_tag(
                model_name, version, f"metric_{metric_name}", str(metric_value)
            )
        logger.info(f"Registered {model_name} v{version} (run: {run_id})")
        return ModelVersion(
            name=model_name, version=int(version),
            stage="None", run_id=run_id, metrics=metrics,
            feature_flag=f"model_{model_name}_v{version}",
            created_at=datetime.utcnow().isoformat(),
        )

    def promote_model(self, model_name, version, stage):
        valid_stages = ["Staging", "Production", "Archived"]
        if stage not in valid_stages:
            raise ValueError(f"Invalid stage: {stage}. Must be one of {valid_stages}")
        self.client.transition_model_version_stage(
            name=model_name, version=str(version), stage=stage,
        )
        logger.info(f"Promoted {model_name} v{version} to {stage}")

    def get_production_model(self, model_name):
        versions = self.client.get_latest_versions(model_name, stages=["Production"])
        if not versions:
            return None
        return versions[0]
    def get_model_history(self, model_name) -> List[Dict]:
        versions = self.client.search_model_versions(f"name='{model_name}'")
        history = []
        for v in versions:
            # ModelVersion.tags is a plain dict of tag name -> value
            tags = dict(v.tags) if getattr(v, "tags", None) else {}
            history.append({
                "version": v.version,
                "stage": v.current_stage,
                "run_id": v.run_id,
                "created": v.creation_timestamp,
                "tags": tags,
            })
        # Versions come back as strings, so sort numerically
        return sorted(history, key=lambda x: int(x["version"]), reverse=True)
    def compare_models(self, model_name, version_a, version_b):
        va = self.client.get_model_version(model_name, str(version_a))
        vb = self.client.get_model_version(model_name, str(version_b))
        run_a = self.client.get_run(va.run_id)
        run_b = self.client.get_run(vb.run_id)
        metrics_a = run_a.data.metrics
        metrics_b = run_b.data.metrics
        comparison = {}
        all_metrics = set(list(metrics_a.keys()) + list(metrics_b.keys()))
        for metric in all_metrics:
            val_a = metrics_a.get(metric, None)
            val_b = metrics_b.get(metric, None)
            if val_a is not None and val_b is not None:
                diff = val_b - val_a
                pct = (diff / val_a * 100) if val_a != 0 else 0
                comparison[metric] = {
                    "version_a": val_a,
                    "version_b": val_b,
                    "diff": round(diff, 4),
                    "pct_change": round(pct, 2),
                }
        return comparison
# Usage
registry = ModelRegistryService()
# Register a new model
# model_info = registry.register_model(
# model_name="recommendation_model",
# run_id="abc123",
# metrics={"accuracy": 0.92, "f1_score": 0.89, "latency_ms": 45},
# tags={"team": "ml-platform", "framework": "pytorch"},
# )
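After registration, the same service handles comparison and lifecycle promotion. A usage sketch in the same style (the version numbers are examples):
# Compare a candidate against the current production version
# comparison = registry.compare_models("recommendation_model", version_a=3, version_b=4)
# print(json.dumps(comparison, indent=2))
#
# If the candidate looks better, walk it through the lifecycle
# registry.promote_model("recommendation_model", version=4, stage="Staging")
# registry.promote_model("recommendation_model", version=4, stage="Production")
#
# Full version history, newest first
# for entry in registry.get_model_history("recommendation_model"):
#     print(entry["version"], entry["stage"], entry["tags"])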
Feature Flag Integration for ML Models
Integrating Feature Flags with model serving
#!/usr/bin/env python3
# feature_flag_models.py — Feature Flag Integration for ML Models
import requests
import random
import hashlib
import json
import logging
from typing import Optional, Dict, Any
from functools import lru_cache
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ff_models")
class FeatureFlagClient:
    def __init__(self, unleash_url="http://localhost:4242/api", api_token=""):
        self.base_url = unleash_url
        self.headers = {
            "Authorization": api_token,
            "Content-Type": "application/json",
        }
        self._cache = {}

    def is_enabled(self, flag_name, user_id=None, default=False):
        try:
            context = {}
            if user_id:
                context["userId"] = str(user_id)
            resp = requests.post(
                f"{self.base_url}/client/features/{flag_name}",
                headers=self.headers,
                json={"context": context},
                timeout=2,
            )
            if resp.status_code == 200:
                return resp.json().get("enabled", default)
            return default
        except Exception as e:
            logger.warning(f"Feature flag check failed: {e}")
            return default
    def get_variant(self, flag_name, user_id=None):
        try:
            resp = requests.get(
                f"{self.base_url}/client/features/{flag_name}",
                headers=self.headers,
                timeout=2,
            )
            if resp.status_code == 200:
                data = resp.json()
                variants = data.get("variants", [])
                if not variants:
                    return None
                total_weight = sum(v.get("weight", 0) for v in variants)
                if user_id and total_weight > 0:
                    # Deterministic assignment: hash the user id onto the weight range
                    hash_val = int(hashlib.md5(str(user_id).encode()).hexdigest(), 16)
                    selected = hash_val % total_weight
                    cumulative = 0
                    for v in variants:
                        cumulative += v.get("weight", 0)
                        if selected < cumulative:
                            return v.get("name")
                return variants[0].get("name")
            return None
        except Exception:
            return None
class ModelRouter:
    def __init__(self, registry_service, ff_client):
        self.registry = registry_service
        self.ff = ff_client
        self.models = {}

    def load_model(self, model_name, version):
        key = f"{model_name}_v{version}"
        if key not in self.models:
            import mlflow
            model_uri = f"models:/{model_name}/{version}"
            self.models[key] = mlflow.pyfunc.load_model(model_uri)
            logger.info(f"Loaded model: {key}")
        return self.models[key]

    def predict(self, model_name, input_data, user_id=None):
        # Check feature flag for model version
        flag_name = f"model_{model_name}"
        variant = self.ff.get_variant(flag_name, user_id)
        if variant:
            version = variant.replace("v", "")
        else:
            prod_model = self.registry.get_production_model(model_name)
            version = prod_model.version if prod_model else "1"
        model = self.load_model(model_name, version)
        prediction = model.predict(input_data)
        logger.info(f"Predicted with {model_name} v{version} for user {user_id}")
        return {
            "prediction": prediction,
            "model_name": model_name,
            "model_version": version,
            "user_id": user_id,
        }
# === Feature Flag Configuration (Unleash API) ===
def create_model_feature_flag(unleash_url, api_token, model_name, variants):
"""
สร้าง feature flag สำหรับ model routing
variants = [
{"name": "v3", "weight": 950}, # 95% traffic
{"name": "v4", "weight": 50}, # 5% traffic (canary)
]
"""
resp = requests.post(
f"{unleash_url}/api/admin/projects/default/features",
headers={"Authorization": api_token, "Content-Type": "application/json"},
json={
"name": f"model_{model_name}",
"type": "experiment",
"description": f"Model routing for {model_name}",
"enabled": True,
"strategies": [{
"name": "flexibleRollout",
"parameters": {
"rollout": "100",
"stickiness": "userId",
},
}],
"variants": variants,
},
)
return resp.json()
# Example usage
# create_model_feature_flag(
# "http://localhost:4242",
# "api-token",
# "recommendation",
# [{"name": "v3", "weight": 950}, {"name": "v4", "weight": 50}]
# )
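To serve predictions, the ModelRouter and FeatureFlagClient can sit behind a small HTTP endpoint. A hedged sketch using FastAPI (the framework choice and the file name serving_api.py are assumptions, not part of the stack above):
#!/usr/bin/env python3
# serving_api.py - illustrative HTTP endpoint wiring the router and flag client together
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel

from model_registry import ModelRegistryService
from feature_flag_models import FeatureFlagClient, ModelRouter

app = FastAPI()
registry = ModelRegistryService("http://localhost:5000")
ff = FeatureFlagClient("http://localhost:4242/api", api_token="*:*.unleash-api-token")
model_router = ModelRouter(registry, ff)

class PredictRequest(BaseModel):
    user_id: str
    features: List[float]

@app.post("/predict/{model_name}")
def predict(model_name: str, req: PredictRequest):
    # The feature flag variant decides which model version serves this user;
    # if no variant is configured, the Production version in the registry is used.
    result = model_router.predict(model_name, [req.features], user_id=req.user_id)
    # Convert the prediction to plain Python types so it serializes cleanly
    result["prediction"] = [float(p) for p in result["prediction"]]
    return result

# Run with: uvicorn serving_api:app --port 8000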
Canary Deployment and A/B Testing
A canary deployment system for ML models
#!/usr/bin/env python3
# canary_deployment.py — Canary Deployment for ML Models
import time
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, List
from dataclasses import dataclass
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("canary")
@dataclass
class CanaryConfig:
    model_name: str
    current_version: str
    canary_version: str
    stages: List[Dict]  # [{"percentage": 5, "duration_min": 30}, ...]
    success_criteria: Dict[str, float]  # {"accuracy": 0.9, "latency_p95_ms": 100}
    auto_rollback: bool = True
class CanaryDeployment:
    def __init__(self, ff_client, metrics_client=None):
        self.ff = ff_client
        self.metrics = metrics_client
        self.deployments = {}

    def start_canary(self, config: CanaryConfig):
        deployment_id = f"{config.model_name}_{datetime.utcnow().strftime('%Y%m%d%H%M')}"
        self.deployments[deployment_id] = {
            "config": config,
            "current_stage": 0,
            "status": "in_progress",
            "started_at": datetime.utcnow().isoformat(),
            "metrics_history": [],
        }
        logger.info(f"Started canary: {deployment_id}")
        logger.info(f"  Current: {config.current_version}, Canary: {config.canary_version}")
        logger.info(f"  Stages: {config.stages}")
        self._execute_canary(deployment_id)
        return deployment_id
    def _execute_canary(self, deployment_id):
        deployment = self.deployments[deployment_id]
        config = deployment["config"]
        for i, stage in enumerate(config.stages):
            percentage = stage["percentage"]
            duration = stage["duration_min"]
            logger.info(f"Stage {i+1}: {percentage}% canary for {duration} minutes")
            # Update feature flag
            self._update_traffic_split(config.model_name, config.current_version,
                                       config.canary_version, percentage)
            deployment["current_stage"] = i
            # Monitor for duration
            end_time = datetime.utcnow() + timedelta(minutes=duration)
            while datetime.utcnow() < end_time:
                metrics = self._collect_metrics(config.model_name, config.canary_version)
                deployment["metrics_history"].append({
                    "stage": i, "timestamp": datetime.utcnow().isoformat(),
                    "metrics": metrics,
                })
                # Check success criteria
                if config.auto_rollback and not self._check_criteria(metrics, config.success_criteria):
                    logger.warning(f"Canary failed criteria at stage {i+1}!")
                    self._rollback(deployment_id)
                    return
                time.sleep(60)  # Check every minute
        # All stages passed - promote canary to production
        logger.info("All canary stages passed! Promoting to 100%")
        self._promote(deployment_id)
    def _update_traffic_split(self, model_name, current_ver, canary_ver, canary_pct):
        # Unleash variant weights are out of 1000
        current_weight = 1000 - (canary_pct * 10)
        canary_weight = canary_pct * 10
        logger.info(f"Traffic split: {current_ver}={100-canary_pct}%, {canary_ver}={canary_pct}%")
        # Push current_weight / canary_weight to the Unleash feature flag variants here
        # (see the Admin API sketch after this example)

    def _collect_metrics(self, model_name, version):
        # Collect from Prometheus/Grafana (placeholder values shown here)
        return {
            "accuracy": 0.93,
            "latency_p95_ms": 85,
            "error_rate": 0.01,
            "throughput_rps": 150,
        }
    def _check_criteria(self, metrics, criteria):
        for metric, threshold in criteria.items():
            actual = metrics.get(metric)
            if actual is None:
                continue
            # Lower is better for latency and error rate; higher is better otherwise
            if metric in ["latency_p95_ms", "error_rate"]:
                if actual > threshold:
                    logger.warning(f"Criteria failed: {metric}={actual} > {threshold}")
                    return False
            else:
                if actual < threshold:
                    logger.warning(f"Criteria failed: {metric}={actual} < {threshold}")
                    return False
        return True
    def _rollback(self, deployment_id):
        deployment = self.deployments[deployment_id]
        config = deployment["config"]
        logger.info(f"Rolling back: {config.canary_version} -> {config.current_version}")
        self._update_traffic_split(config.model_name, config.current_version,
                                   config.canary_version, 0)
        deployment["status"] = "rolled_back"
        deployment["ended_at"] = datetime.utcnow().isoformat()

    def _promote(self, deployment_id):
        deployment = self.deployments[deployment_id]
        config = deployment["config"]
        # Send 100% of traffic to the canary version
        self._update_traffic_split(config.model_name, config.current_version,
                                   config.canary_version, 100)
        deployment["status"] = "promoted"
        deployment["ended_at"] = datetime.utcnow().isoformat()
        logger.info(f"Canary promoted: {config.canary_version} is now production")
# Usage
# canary = CanaryDeployment(ff_client)
# config = CanaryConfig(
# model_name="recommendation",
# current_version="v3",
# canary_version="v4",
# stages=[
# {"percentage": 5, "duration_min": 30},
# {"percentage": 25, "duration_min": 60},
# {"percentage": 50, "duration_min": 120},
# ],
# success_criteria={"accuracy": 0.9, "latency_p95_ms": 100, "error_rate": 0.05},
# )
# canary.start_canary(config)
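The _update_traffic_split method above only logs the split; in a real deployment it also has to push new variant weights to the flag service. A sketch of what that might look like against the Unleash Admin API (the variants endpoint and payload shape should be verified against the Unleash version you run):
#!/usr/bin/env python3
# traffic_split.py - push canary weights to an Unleash feature flag (illustrative sketch)
import requests

def update_traffic_split(unleash_url, api_token, model_name,
                         current_ver, canary_ver, canary_pct):
    """Overwrite the flag's variants so canary_pct% of traffic hits canary_ver.

    Unleash variant weights are expressed out of 1000 (permille).
    """
    canary_weight = canary_pct * 10
    variants = [
        {"name": current_ver, "weight": 1000 - canary_weight,
         "weightType": "fix", "stickiness": "userId"},
        {"name": canary_ver, "weight": canary_weight,
         "weightType": "fix", "stickiness": "userId"},
    ]
    resp = requests.put(
        f"{unleash_url}/api/admin/projects/default/features/model_{model_name}/variants",
        headers={"Authorization": api_token, "Content-Type": "application/json"},
        json=variants,
        timeout=5,
    )
    resp.raise_for_status()
    return resp.json()

# Example: shift 25% of recommendation traffic to v4
# update_traffic_split("http://localhost:4242", "api-token", "recommendation", "v3", "v4", 25)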
Monitoring and Automated Rollback
A monitoring system for model performance
#!/usr/bin/env python3
# model_monitor.py — ML Model Performance Monitor
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time
import logging
import json
from datetime import datetime
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_monitor")
# Prometheus Metrics
prediction_count = Counter("model_predictions_total", "Total predictions", ["model", "version"])
prediction_latency = Histogram("model_prediction_latency_seconds", "Prediction latency",
                               ["model", "version"],
                               buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0])
prediction_errors = Counter("model_prediction_errors_total", "Prediction errors", ["model", "version"])
model_accuracy = Gauge("model_accuracy", "Model accuracy (rolling)", ["model", "version"])
data_drift_score = Gauge("model_data_drift", "Data drift score", ["model", "version"])
class ModelMonitor:
    def __init__(self, alert_callback=None):
        self.alert_callback = alert_callback
        self.rolling_window = {}

    def record_prediction(self, model_name, version, latency_s, correct=None, error=None):
        prediction_count.labels(model=model_name, version=version).inc()
        prediction_latency.labels(model=model_name, version=version).observe(latency_s)
        if error:
            prediction_errors.labels(model=model_name, version=version).inc()
        key = f"{model_name}_{version}"
        if key not in self.rolling_window:
            self.rolling_window[key] = {"correct": 0, "total": 0}
        if correct is not None:
            self.rolling_window[key]["total"] += 1
            if correct:
                self.rolling_window[key]["correct"] += 1
        window = self.rolling_window[key]
        if window["total"] > 0:
            accuracy = window["correct"] / window["total"]
            model_accuracy.labels(model=model_name, version=version).set(accuracy)
    def check_alerts(self, model_name, version, thresholds):
        key = f"{model_name}_{version}"
        window = self.rolling_window.get(key, {})
        alerts = []
        # Check accuracy once the window is large enough to be meaningful
        if window.get("total", 0) > 100:
            accuracy = window["correct"] / window["total"]
            min_accuracy = thresholds.get("min_accuracy", 0.8)
            if accuracy < min_accuracy:
                alerts.append({
                    "type": "accuracy_drop",
                    "model": model_name,
                    "version": version,
                    "value": round(accuracy, 4),
                    "threshold": min_accuracy,
                })
        if alerts and self.alert_callback:
            for alert in alerts:
                self.alert_callback(alert)
        return alerts

    def reset_window(self, model_name, version):
        key = f"{model_name}_{version}"
        self.rolling_window[key] = {"correct": 0, "total": 0}
class AutoRollbackController:
    def __init__(self, monitor, ff_client, registry):
        self.monitor = monitor
        self.ff = ff_client
        self.registry = registry
        self.rollback_history = []

    def evaluate_and_rollback(self, model_name, canary_version, stable_version, thresholds):
        alerts = self.monitor.check_alerts(model_name, canary_version, thresholds)
        if alerts:
            logger.warning(f"Auto-rollback triggered for {model_name} v{canary_version}")
            # Switch all traffic to stable version
            # self.ff.update_variant(f"model_{model_name}", stable_version, 100)
            # Archive canary in registry
            # self.registry.promote_model(model_name, canary_version, "Archived")
            self.rollback_history.append({
                "model": model_name,
                "rolled_back_version": canary_version,
                "stable_version": stable_version,
                "reason": alerts,
                "timestamp": datetime.utcnow().isoformat(),
            })
            logger.info(f"Rollback complete: {model_name} v{canary_version} -> v{stable_version}")
            return True
        return False
# Start Prometheus metrics server
# start_http_server(9090)
# Grafana Dashboard Queries:
# - Prediction rate: rate(model_predictions_total[5m])
# - Error rate: rate(model_prediction_errors_total[5m]) / rate(model_predictions_total[5m])
# - P95 latency: histogram_quantile(0.95, rate(model_prediction_latency_seconds_bucket[5m]))
# - Accuracy: model_accuracy
# - Drift: model_data_drift
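A usage sketch tying the monitor and rollback controller into a prediction path (model, input_data, expected_label, ff_client, and registry are assumed to come from the earlier examples):
# monitor = ModelMonitor(alert_callback=lambda a: logger.warning(f"ALERT: {a}"))
# controller = AutoRollbackController(monitor, ff_client, registry)
#
# start = time.time()
# result = model.predict(input_data)
# monitor.record_prediction("recommendation", "v4",
#                           latency_s=time.time() - start,
#                           correct=(result == expected_label))
#
# Run from a scheduler (e.g. every few minutes) during the canary window:
# controller.evaluate_and_rollback(
#     model_name="recommendation", canary_version="v4", stable_version="v3",
#     thresholds={"min_accuracy": 0.9},
# )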
FAQ: Frequently Asked Questions
Q: What is the difference between a Model Registry and a Model Store?
A: A Model Store only keeps model artifacts (files, weights), much like plain file storage. A Model Registry keeps both the artifacts and their metadata (version, stage, metrics, tags, lineage), manages the lifecycle (staging, production, archived), provides an API for promoting and demoting models, and integrates with CI/CD. MLflow Model Registry is an example that combines both.
Q: Are Feature Flags a good fit for every kind of ML model?
A: They fit best with online serving models that need A/B tests or canary deploys, such as recommendation, ranking, and fraud detection models. Batch prediction models that run offline have less use for feature flags, though they are still handy for choosing which model version a batch job should use. They are not a good fit for models that require strict consistency, such as financial models that must return the same result every time.
Q: How long does a canary deployment take?
A: It depends on your risk tolerance and traffic volume. A typical setup uses 3-5 stages, for example 5% (30 minutes) -> 25% (1 hour) -> 50% (2 hours) -> 100%, roughly 4-6 hours in total. High-risk models may take several days, while low-risk updates can move faster. The key is having sufficient traffic at each stage so the metrics are statistically significant.
Q: How does automated rollback work?
A: The monitoring system collects the canary model's metrics in real time (accuracy, latency, error rate) and compares them against predefined thresholds. If a metric crosses its threshold, the system immediately switches traffic back to the stable model via a feature flag update, with no redeploy, archives the canary model in the registry, and alerts the team to investigate. The whole process completes within seconds and requires no human intervention.
