Skip to content

Commit

Permalink
Invoke the API by running the script
Browse files Browse the repository at this point in the history
  • Loading branch information
Alyetama committed Jun 15, 2022
1 parent 6fa20e1 commit 71d1c30
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
3 changes: 1 addition & 2 deletions birdfsd_yolov5/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
- Run locally

```sh
cd api
uvicorn api:app --reload
python api.py
```

## Example
Expand Down
6 changes: 4 additions & 2 deletions birdfsd_yolov5/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Union

import pandas as pd
import uvicorn
from PIL import Image
from dotenv import load_dotenv
from fastapi import FastAPI, Response, UploadFile
Expand Down Expand Up @@ -158,10 +159,11 @@ def predict_endpoint(file: UploadFile,
api_s3 = api_utils.create_s3_client(api_s3=True)
db = mongodb_db()

model_version, model_name, model_weights, model = model_utils.init_model(
s3)
model_version, model_name, model_weights, model = model_utils.init_model(s3)

if os.getenv('MODEL_REPO'):
page = f'{os.getenv("MODEL_REPO")}/releases/tag/{model_version}'
else:
page = None

uvicorn.run(app, host='127.0.0.1', port=8000)
17 changes: 13 additions & 4 deletions birdfsd_yolov5/api/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""This module is used to get the latest model weights and information."""

from pathlib import Path
from typing import Tuple
from typing import Optional, Tuple

import torch
from minio import Minio
Expand Down Expand Up @@ -42,11 +42,15 @@ def get_latest_model_weights(s3_client: Minio,
return model_version, model_name, model_object_name


def init_model(s3: Minio) -> Tuple[str, str, str, torch.nn.Module]:
def init_model(
s3: Minio,
use_weights: Optional[str] = None
) -> Tuple[str, str, str, torch.nn.Module]:
"""This function initializes the model.
Args:
s3 (Minio): Minio S3 client object.
use_weights (str): Use this weights file instead of the latest model.
Returns:
model_version: The model version.
Expand All @@ -55,8 +59,13 @@ def init_model(s3: Minio) -> Tuple[str, str, str, torch.nn.Module]:
model: The model object.
"""
model_version, model_name, model_weights = get_latest_model_weights(
s3, skip_download=True)
if not use_weights:
model_version, model_name, model_weights = get_latest_model_weights(
s3, skip_download=True)
else:
model_version = Path(use_weights).stem
model_name = Path(use_weights).stem
model_weights = use_weights

if not Path(model_weights).exists():
model_version, model_name, model_weights = get_latest_model_weights(s3)
Expand Down
5 changes: 4 additions & 1 deletion notebooks/BirdFSD_YOLOv5_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@
"source": [
"EPOCHS = 100 #@param {type:\"integer\"}\n",
"BATCH_SIZE = 16 #@param {type:\"integer\"}\n",
"PRETRAINED_WEIGHTS = 'yolov5s' #@param {type:\"string\"}"
"if Path('best.pt').exists():\n",
" PRETRAINED_WEIGHTS = 'best.pt' #@param {type:\"string\"}\n",
"else:\n",
" PRETRAINED_WEIGHTS = 'yolov5s' #@param {type:\"string\"}"
]
},
{
Expand Down

0 comments on commit 71d1c30

Please sign in to comment.