46 changes: 8 additions & 38 deletions examples/iris_inference_server/__init__.py
@@ -1,57 +1,27 @@
 from examples.utils.decorators import mlflow_tracking_uri
+from examples.iris_inference_server.routers import inference
+from examples.iris_inference_server.load_model import ModelLoader
 from fastapi import FastAPI
-from fastapi import Request
-from typing import List
 from contextlib import asynccontextmanager
-import pandas as pd

-ml_models = {}
+API_VERSION = "v1"


 @mlflow_tracking_uri
 @asynccontextmanager
 async def load_ml_model(app: FastAPI):
     """
     Context manager to load the ML model.
     This is a placeholder for actual model loading logic.
     """
     try:
-        import mlflow
-
-        # Load your ML model here
-        print("Loading ML model...")
-        registered_model_name = "Iris_Classifier_Model"
-        model_uri = f"models:/{registered_model_name}@Production"
-        model = mlflow.sklearn.load_model(model_uri=model_uri)
-        ml_models[registered_model_name] = model
+        ml_loader = ModelLoader()
+        ml_loader.load_model("Iris_Classifier_Model", "Production")
         yield  # This is where the model would be used
-        ml_models.clear()  # Clear the model after use
+        ml_loader.clear_models()  # Clear the model after use
     finally:
         print("Model loaded successfully.")


 app = FastAPI(title="Inference Server", lifespan=load_ml_model)


-@app.post("/predict")
-async def root(request: Request):
-    """
-    Root endpoint for the inference server.
-    This endpoint accepts a POST request with a JSON body containing
-    the features for prediction.
-    It returns the prediction made by the ML model.
-    """
-
-    body = await request.json()
-    print("Body received:", body)
-    features = body.get("features", None)
-    columns = body.get("columns", None)
-    if not features or not isinstance(features, List):
-        return {"error": "Invalid input. 'features' must be a list."}
-    model = ml_models.get("Iris_Classifier_Model", None)
-    if model:
-        # Assuming the model has a predict method
-        features = pd.DataFrame([features], columns=columns)
-        prediction = model.predict(features)
-        return {"prediction": prediction.tolist()}
-    else:
-        return {"error": "Model not found."}
+app.include_router(inference, prefix=f"/{API_VERSION}", tags=["inference"])
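
For reference, a minimal smoke test of the new wiring, written as a sketch rather than part of the PR. It assumes the package is importable and that an MLflow registry holding the aliased model is reachable; FastAPI's TestClient runs the lifespan handler on context entry, so it exercises ModelLoader as well.

# Hypothetical smoke test; assumes a reachable MLflow tracking server with an
# "Iris_Classifier_Model" registered under the "Production" alias.
from fastapi.testclient import TestClient

from examples.iris_inference_server import app

# Entering the client context runs load_ml_model, so ModelLoader populates
# its cache before the first request is served.
with TestClient(app) as client:
    assert client.get("/v1/health").json() == {"status": "healthy"}
    assert client.get("/v1/ping").json() == {"status": "pong"}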
52 changes: 52 additions & 0 deletions examples/iris_inference_server/load_model.py
@@ -0,0 +1,52 @@
+from examples.utils.decorators import mlflow_tracking_uri
+import mlflow
+
+
+class ModelLoader:
+
+    # To validate mlflow aliases
+    __ALLOWED_ENVIRONMENTS = ["Production", "Staging", "Development"]
+    __ML_MODELS = {}
+
+    @mlflow_tracking_uri
+    def load_model(self, model_name: str, environment: str = "Production") -> None:
+        """
+        Load a model from MLflow. Use this function to load a model
+        before starting to serve the endpoints.
+
+        :param model_name: Name of the model to load.
+        :param environment: Environment from which to load the model.
+        """
+        if environment not in self.__ALLOWED_ENVIRONMENTS:
+            raise ValueError(
+                f"Invalid environment: {environment}. Allowed values are: {self.__ALLOWED_ENVIRONMENTS}"
+            )
+
+        registered_models = mlflow.MlflowClient().get_model_version_by_alias(
+            name=model_name, alias=environment
+        )
+
+        if not registered_models:
+            raise ValueError(
+                f"No registered model found for name: {model_name} and environment: {environment}"
+            )
+
+        model_uri = f"models:/{model_name}@{environment}"
+        model = mlflow.sklearn.load_model(model_uri=model_uri)
+        self.__ML_MODELS[model_name] = model
+
+    @classmethod
+    def get_model(cls, model_name: str):
+        """
+        Get a loaded model by name.
+        :param model_name: Name of the model to retrieve.
+        :return: The loaded model if found, None otherwise.
+        """
+        return cls.__ML_MODELS.get(model_name)
+
+    @classmethod
+    def clear_models(cls):
+        """
+        Clear all loaded models.
+        """
+        cls.__ML_MODELS.clear()
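
Because __ML_MODELS is a class-level dict, anything loaded through one ModelLoader instance is visible process-wide via the get_model classmethod, which is what lets the router look models up without holding a loader reference. A hedged usage sketch, assuming the registry actually contains the aliased model:

# Hypothetical usage; the model name and alias must exist in the registry.
from examples.iris_inference_server.load_model import ModelLoader

loader = ModelLoader()
loader.load_model("Iris_Classifier_Model", environment="Production")

# The cache is shared at class level, so no instance is needed to read it.
model = ModelLoader.get_model("Iris_Classifier_Model")
print(model.predict([[5.1, 3.5, 1.4, 0.2]]))

# Alias validation fails fast, before any registry call is made.
try:
    loader.load_model("Iris_Classifier_Model", environment="prod")
except ValueError as err:
    print(err)

ModelLoader.clear_models()  # drop all cached models when done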
39 changes: 39 additions & 0 deletions examples/iris_inference_server/routers.py
@@ -0,0 +1,39 @@
+from fastapi import APIRouter
+from fastapi.responses import JSONResponse
+import mlflow
+from examples.iris_inference_server.schemas import IrisRequest
+from examples.iris_inference_server.schemas import IrisResponse
+from examples.iris_inference_server.load_model import ModelLoader
+
+inference = APIRouter()
+
+
+@inference.get("/health")
+def health_check():
+    return {"status": "healthy"}
+
+
+@inference.get("/ping")
+def ping():
+    return {"status": "pong"}
+
+
+@inference.get("/version")
+def version():
+    return {"version": mlflow.__version__}
+
+
+@inference.post("/invocations")
+def invocations(iris_request: IrisRequest) -> IrisResponse:
+    features = iris_request.model_dump()
+    model = ModelLoader.get_model("Iris_Classifier_Model")
+    if model:
+        # Assuming the model has a predict method
+        features = iris_request.get_feature_values()
+        prediction = model.predict([features])
+        proba = model.predict_proba([features])
+        print("prediction:", prediction)
+        print("probability:", proba)
+        return IrisResponse(species=prediction[0], confidence=proba[0].max())
+    else:
+        return JSONResponse(content={"error": "Model not found."}, status_code=404)
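
Against a locally running server, a call to the versioned endpoint could look like the sketch below; requests is used purely for illustration and is not a dependency added by this PR. Thanks to populate_by_name=True on IrisRequest, the payload may use either the snake_case field names or the sklearn-style aliases.

# Hypothetical client call; assumes uvicorn is serving the app on port 8000
# and the model was loaded during startup.
import requests

payload = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2,
}
response = requests.post("http://localhost:8000/v1/invocations", json=payload)
print(response.status_code, response.json())
# e.g. 200 {"species": "setosa", "confidence": 0.98}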
49 changes: 49 additions & 0 deletions examples/iris_inference_server/schemas.py
@@ -0,0 +1,49 @@
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import ConfigDict
+from pydantic import field_validator
+from typing import List
+
+
+class IrisRequest(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+    sepal_length: float = Field(alias="sepal length (cm)")
+    sepal_width: float = Field(alias="sepal width (cm)")
+    petal_length: float = Field(alias="petal length (cm)")
+    petal_width: float = Field(alias="petal width (cm)")
+
+    def get_feature_values(self) -> List[float]:
+        """
+        Get the feature values as a list.
+        """
+        return [
+            self.sepal_length,
+            self.sepal_width,
+            self.petal_length,
+            self.petal_width,
+        ]
+
+    def get_feature_names(self) -> List[str]:
+        """
+        Get the feature names as a list.
+        """
+        return [
+            "sepal length (cm)",
+            "sepal width (cm)",
+            "petal length (cm)",
+            "petal width (cm)",
+        ]
+
+
+class IrisResponse(BaseModel):
+    species: int = Field(description="Predicted species of the iris flower")
+    confidence: float = Field(description="Confidence score of the prediction")
+
+    @field_validator("species")
+    @classmethod
+    def map_int_to_species(cls, species_id: int) -> str:
+        species_map = {0: "setosa", 1: "versicolor", 2: "virginica"}
+
+        if species_id not in species_map:
+            raise ValueError(f"Invalid species_id: {species_id}. Must be 0, 1, or 2.")
+        return species_map.get(species_id)
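
The schemas can be checked without a server or a model. A small sketch showing alias handling and the species mapping; note that the after-validator replaces the integer class id with a string, so the stored value no longer matches the int annotation, which is arguably worth tightening to str in a follow-up.

# Standalone check of the request/response schemas (pure pydantic, no MLflow).
from examples.iris_inference_server.schemas import IrisRequest, IrisResponse

# Aliases and plain field names are both accepted (populate_by_name=True).
req = IrisRequest.model_validate(
    {
        "sepal length (cm)": 5.1,
        "sepal width (cm)": 3.5,
        "petal length (cm)": 1.4,
        "petal width (cm)": 0.2,
    }
)
print(req.get_feature_values())  # [5.1, 3.5, 1.4, 0.2]

# The field_validator runs after int validation and swaps in the name.
resp = IrisResponse(species=0, confidence=0.98)
print(resp.species)  # setosa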
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mlflow_for_ml_dev"
-version = "1.7.2"
+version = "1.8.0"
 description = "Code examples for the youtube playlist 'MLflow for Machine Learning Development' by Manuel Gil"
 authors = ["Manuel Gil <[email protected]>"]
 readme = "README.md"
@@ -28,6 +28,7 @@ ucimlrepo = "^0.0.7"
 keras = "^3.8.0"
 torch = "^2.6.0"
 fastapi = {extras = ["standard"], version = "^0.115.13"}
+pydantic = "^2.11.7"

 [tool.poetry.group.dev.dependencies]
 ipython = "^8.32.0"