From d0242bd70733377552c16806663c4b3d96888f33 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 11:15:18 -0500 Subject: [PATCH 1/6] refactor init --- examples/iris_inference_server/__init__.py | 46 ++++------------------ 1 file changed, 8 insertions(+), 38 deletions(-) diff --git a/examples/iris_inference_server/__init__.py b/examples/iris_inference_server/__init__.py index c663160..af7d0cc 100644 --- a/examples/iris_inference_server/__init__.py +++ b/examples/iris_inference_server/__init__.py @@ -1,13 +1,11 @@ -from examples.utils.decorators import mlflow_tracking_uri +from examples.iris_inference_server.routers import inference +from examples.iris_inference_server.load_model import ModelLoader from fastapi import FastAPI -from fastapi import Request -from typing import List from contextlib import asynccontextmanager -import pandas as pd -ml_models = {} + +API_VERSION = "v1" -@mlflow_tracking_uri @asynccontextmanager async def load_ml_model(app: FastAPI): """ @@ -15,43 +13,15 @@ async def load_ml_model(app: FastAPI): This is a placeholder for actual model loading logic. """ try: - import mlflow - # Load your ML model here - print("Loading ML model...") - registered_model_name = "Iris_Classifier_Model" - model_uri = f"models:/{registered_model_name}@Production" - model = mlflow.sklearn.load_model(model_uri=model_uri) - ml_models[registered_model_name] = model + ml_loader = ModelLoader() + ml_loader.load_model("Iris_Classifier_Model", "Production") yield # This is where the model would be used - ml_models.clear() # Clear the model after use + ml_loader.clear_models() # Clear the model after use finally: print("Model loaded successfully.") app = FastAPI(title="Inference Server", lifespan=load_ml_model) - -@app.post("/predict") -async def root(request: Request): - """ - Root endpoint for the inference server. - This endpoint accepts a POST request with a JSON body containing - the features for prediction. - It returns the prediction made by the ML model. - """ - - body = await request.json() - print("Body received:", body) - features = body.get("features", None) - columns = body.get("columns", None) - if not features or not isinstance(features, List): - return {"error": "Invalid input. 'features' must be a list."} - model = ml_models.get("Iris_Classifier_Model", None) - if model: - # Assuming the model has a predict method - features = pd.DataFrame([features], columns=columns) - prediction = model.predict(features) - return {"prediction": prediction.tolist()} - else: - return {"error": "Model not found."} +app.include_router(inference, prefix=f"/{API_VERSION}", tags=["inference"]) From 8574145e8918f1d93477c4d537017cde3af0b791 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 11:15:53 -0500 Subject: [PATCH 2/6] adding pydantic to define schemas --- poetry.lock | 8 ++++---- pyproject.toml | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1d7ae0e..3b2c43a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3417,13 +3417,13 @@ files = [ [[package]] name = "pydantic" -version = "2.11.4" +version = "2.11.7" description = "Data validation using Python type hints" optional = false python-versions = ">=3.9" files = [ - {file = "pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb"}, - {file = "pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d"}, + {file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"}, + {file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"}, ] [package.dependencies] @@ -5358,4 +5358,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "f6f1f88741fb1906871d9d8c28160305438e77c28289d64f99902536462fe7ba" +content-hash = "5d3851dadeff39fae429752d2dc5b5aee09e9f91eee2225f318595a9623be2f3" diff --git a/pyproject.toml b/pyproject.toml index 882fec6..bf5692b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ ucimlrepo = "^0.0.7" keras = "^3.8.0" torch = "^2.6.0" fastapi = {extras = ["standard"], version = "^0.115.13"} +pydantic = "^2.11.7" [tool.poetry.group.dev.dependencies] ipython = "^8.32.0" From 2d5d4cb54cabbd4930d85df55ec7741af4df7189 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 11:16:18 -0500 Subject: [PATCH 3/6] defining ModelLoader class --- examples/iris_inference_server/load_model.py | 52 ++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 examples/iris_inference_server/load_model.py diff --git a/examples/iris_inference_server/load_model.py b/examples/iris_inference_server/load_model.py new file mode 100644 index 0000000..16d4b45 --- /dev/null +++ b/examples/iris_inference_server/load_model.py @@ -0,0 +1,52 @@ +from examples.utils.decorators import mlflow_tracking_uri +import mlflow + + +class ModelLoader: + + # To validate mlflow aliases + __ALLOWED_ENVIRONMENTS = ["Production", "Staging", "Development"] + __ML_MODELS = {} + + @mlflow_tracking_uri + def load_model(self, model_name: str, environment: str = "Production") -> None: + """ + Load a model from MLflow. Use this function to load a model + at the before starting to serve the endpoints. + + :param model_name: Name of the model to load. + :param environment: Environment from which to load the model. + """ + if environment not in self.__ALLOWED_ENVIRONMENTS: + raise ValueError( + f"Invalid environment: {environment}. Allowed values are: {self.__ALLOWED_ENVIRONMENTS}" + ) + + registered_models = mlflow.MlflowClient().get_model_version_by_alias( + name=model_name, alias=environment + ) + + if not registered_models: + raise ValueError( + f"No registered model found for name: {model_name} and environment: {environment}" + ) + + model_uri = f"models:/{model_name}@{environment}" + model = mlflow.sklearn.load_model(model_uri=model_uri) + self.__ML_MODELS[model_name] = model + + @classmethod + def get_model(cls, model_name: str): + """ + Get a loaded model by name. + :param model_name: Name of the model to retrieve. + :return: The loaded model if found, None otherwise. + """ + return cls.__ML_MODELS.get(model_name) + + @classmethod + def clear_models(cls): + """ + Clear all loaded models. + """ + cls.__ML_MODELS.clear() From b600d4db1368642b8cb90950152682859867d117 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 11:16:37 -0500 Subject: [PATCH 4/6] defining inference router --- examples/iris_inference_server/routers.py | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 examples/iris_inference_server/routers.py diff --git a/examples/iris_inference_server/routers.py b/examples/iris_inference_server/routers.py new file mode 100644 index 0000000..5ec69c3 --- /dev/null +++ b/examples/iris_inference_server/routers.py @@ -0,0 +1,39 @@ +from fastapi import APIRouter +from fastapi.responses import JSONResponse +import mlflow +from examples.iris_inference_server.schemas import IrisRequest +from examples.iris_inference_server.schemas import IrisResponse +from examples.iris_inference_server.load_model import ModelLoader + +inference = APIRouter() + + +@inference.get("/health") +def health_check(): + return {"status": "healthy"} + + +@inference.get("/ping") +def ping(): + return {"status": "pong"} + + +@inference.get("/version") +def version(): + return {"version": mlflow.__version__} + + +@inference.post("/invocations") +def invocations(iris_request: IrisRequest) -> IrisResponse: + features = iris_request.model_dump() + model = ModelLoader.get_model("Iris_Classifier_Model") + if model: + # Assuming the model has a predict method + features = iris_request.get_feature_values() + prediction = model.predict([features]) + proba = model.predict_proba([features]) + print("prediction:", prediction) + print("probability:", proba) + return IrisResponse(species=prediction[0], confidence=proba[0].max()) + else: + return JSONResponse(content={"error": "Model not found."}, status_code=404) From 73e1eba79cbd7e94b941e55a9be6639f655ebf03 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 11:16:54 -0500 Subject: [PATCH 5/6] defining input and output schemas --- examples/iris_inference_server/schemas.py | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 examples/iris_inference_server/schemas.py diff --git a/examples/iris_inference_server/schemas.py b/examples/iris_inference_server/schemas.py new file mode 100644 index 0000000..d908093 --- /dev/null +++ b/examples/iris_inference_server/schemas.py @@ -0,0 +1,49 @@ +from pydantic import BaseModel +from pydantic import Field +from pydantic import ConfigDict +from pydantic import field_validator +from typing import List + + +class IrisRequest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + sepal_length: float = Field(alias="sepal length (cm)") + sepal_width: float = Field(alias="sepal width (cm)") + petal_length: float = Field(alias="petal length (cm)") + petal_width: float = Field(alias="petal width (cm)") + + def get_feature_values(self) -> List[float]: + """ + Get the feature values as a list. + """ + return [ + self.sepal_length, + self.sepal_width, + self.petal_length, + self.petal_width, + ] + + def get_feature_names(self) -> List[str]: + """ + Get the feature names as a list. + """ + return [ + "sepal length (cm)", + "sepal width (cm)", + "petal length (cm)", + "petal width (cm)", + ] + + +class IrisResponse(BaseModel): + species: int = Field(description="Predicted species of the iris flower") + confidence: float = Field(description="Confidence score of the prediction") + + @field_validator("species") + @classmethod + def map_int_to_species(cls, species_id: int) -> str: + species_map = {0: "setosa", 1: "versicolor", 2: "virginica"} + + if species_id not in species_map: + raise ValueError(f"Invalid species_id: {species_id}. Must be 0, 1, or 2.") + return species_map.get(species_id) From 42bb9d51985969a6f5e378745049b3f5a29f37c0 Mon Sep 17 00:00:00 2001 From: Manuel Gil Date: Sat, 9 Aug 2025 14:25:21 -0500 Subject: [PATCH 6/6] Creating proper simple FastAPI API --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf5692b..98aa2a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mlflow_for_ml_dev" -version = "1.7.2" +version = "1.8.0" description = "Code examples for the youtube playlist 'MLflow for Machine Learning Development' by Manuel Gil" authors = ["Manuel Gil "] readme = "README.md"