diff --git a/birdfsd_yolov5/api/README.md b/birdfsd_yolov5/api/README.md index bd01e1c..ea6c445 100644 --- a/birdfsd_yolov5/api/README.md +++ b/birdfsd_yolov5/api/README.md @@ -3,8 +3,7 @@ - Run locally ```sh -cd api -uvicorn api:app --reload +python api.py ``` ## Example diff --git a/birdfsd_yolov5/api/api.py b/birdfsd_yolov5/api/api.py index a74722d..fc5d8d4 100644 --- a/birdfsd_yolov5/api/api.py +++ b/birdfsd_yolov5/api/api.py @@ -11,6 +11,7 @@ from typing import Any, Union import pandas as pd +import uvicorn from PIL import Image from dotenv import load_dotenv from fastapi import FastAPI, Response, UploadFile @@ -158,10 +159,11 @@ def predict_endpoint(file: UploadFile, api_s3 = api_utils.create_s3_client(api_s3=True) db = mongodb_db() - model_version, model_name, model_weights, model = model_utils.init_model( - s3) + model_version, model_name, model_weights, model = model_utils.init_model(s3) if os.getenv('MODEL_REPO'): page = f'{os.getenv("MODEL_REPO")}/releases/tag/{model_version}' else: page = None + + uvicorn.run(app, host='127.0.0.1', port=8000) diff --git a/birdfsd_yolov5/api/model_utils.py b/birdfsd_yolov5/api/model_utils.py index a9726fd..200f6bb 100644 --- a/birdfsd_yolov5/api/model_utils.py +++ b/birdfsd_yolov5/api/model_utils.py @@ -2,7 +2,7 @@ """This module is used to get the latest model weights and information.""" from pathlib import Path -from typing import Tuple +from typing import Optional, Tuple import torch from minio import Minio @@ -42,11 +42,15 @@ def get_latest_model_weights(s3_client: Minio, return model_version, model_name, model_object_name -def init_model(s3: Minio) -> Tuple[str, str, str, torch.nn.Module]: +def init_model( + s3: Minio, + use_weights: Optional[str] = None +) -> Tuple[str, str, str, torch.nn.Module]: """This function initializes the model. Args: s3 (Minio): Minio S3 client object. + use_weights (str): Use this weights file instead of the latest model. Returns: model_version: The model version. @@ -55,8 +59,13 @@ def init_model(s3: Minio) -> Tuple[str, str, str, torch.nn.Module]: model: The model object. """ - model_version, model_name, model_weights = get_latest_model_weights( - s3, skip_download=True) + if not use_weights: + model_version, model_name, model_weights = get_latest_model_weights( + s3, skip_download=True) + else: + model_version = Path(use_weights).stem + model_name = Path(use_weights).stem + model_weights = use_weights if not Path(model_weights).exists(): model_version, model_name, model_weights = get_latest_model_weights(s3) diff --git a/notebooks/BirdFSD_YOLOv5_train.ipynb b/notebooks/BirdFSD_YOLOv5_train.ipynb index a80b941..144a426 100644 --- a/notebooks/BirdFSD_YOLOv5_train.ipynb +++ b/notebooks/BirdFSD_YOLOv5_train.ipynb @@ -273,7 +273,10 @@ "source": [ "EPOCHS = 100 #@param {type:\"integer\"}\n", "BATCH_SIZE = 16 #@param {type:\"integer\"}\n", - "PRETRAINED_WEIGHTS = 'yolov5s' #@param {type:\"string\"}" + "if Path('best.pt').exists():\n", + " PRETRAINED_WEIGHTS = 'best.pt' #@param {type:\"string\"}\n", + "else:\n", + " PRETRAINED_WEIGHTS = 'yolov5s' #@param {type:\"string\"}" ] }, {