Skip to content

Commit

Permalink
Working API example
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow committed Apr 11, 2023
1 parent fe9d6e2 commit 2b7a372
Show file tree
Hide file tree
Showing 19 changed files with 250 additions and 48 deletions.
35 changes: 33 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pytest-cov = "^4.0.0"
pytest-asyncio = "^0.21.0"
pytest = "*"
fastapi = "^0.95.0"
uvicorn = "^0.21.1"


[tool.pytest.ini_options]
Expand Down
11 changes: 0 additions & 11 deletions trapdata/api/__init__.py

This file was deleted.

39 changes: 39 additions & 0 deletions trapdata/api/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pathlib
from typing import Any, Dict, List, Optional

from pydantic import BaseSettings, HttpUrl, PostgresDsn, validator
from pydantic.networks import AnyHttpUrl

from trapdata.cli import read_settings
from trapdata.settings import Settings as BaseSettings


class Settings(BaseSettings):
PROJECT_NAME: str = "AMI Data Manager"

SENTRY_DSN: Optional[HttpUrl] = None

API_PATH: str = "/api/v1"

ACCESS_TOKEN_EXPIRE_MINUTES: int = 7 * 24 * 60 # 7 days

BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []

# The following variables need to be defined in environment

TEST_DATABASE_URL: Optional[PostgresDsn]

SECRET_KEY: str
# END: required environment variables

# STATIC_ROOT: str = "static"

# @validator("STATIC_ROOT")
# def validate_static_root(cls, v):
# path = cls.user_data_path / v
# path.mkdir(parents=True, exist_ok=True)
# return path


# settings = read_settings(SettingsClass=Settings, SECRET_KEY="secret")
settings = Settings(SECRET_KEY="secret")
27 changes: 0 additions & 27 deletions trapdata/api/deployments.py

This file was deleted.

4 changes: 2 additions & 2 deletions trapdata/api/deps/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
settings = read_settings()


def get_session() -> Generator[orm.Session, None]:
Session = get_session_class(settings.database_url)
def get_session() -> Generator[orm.Session, None, None]:
Session = get_session_class(db_path=settings.database_url)
with Session() as session:
yield session
session.close()
82 changes: 82 additions & 0 deletions trapdata/api/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from fastapi import FastAPI
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse, RedirectResponse

from trapdata.api.config import settings
from trapdata.api.views import api_router


def create_app():
description = f"{settings.PROJECT_NAME} API"
app = FastAPI(
title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_PATH}/openapi.json",
docs_url="/docs/",
description=description,
redoc_url="/redoc/",
)
setup_routers(app)
setup_cors_middleware(app)
serve_static_app(app)
return app


def setup_routers(app: FastAPI) -> None:
app.include_router(api_router, prefix=settings.API_PATH)
# The following operation needs to be at the end of this function
use_route_names_as_operation_ids(app)


def serve_static_app(app):
app.mount(
"/static/crops",
StaticFiles(directory=settings.user_data_path / "crops"),
name="crops",
)
app.mount(
"/",
StaticFiles(directory="trapdata/webui/public"),
name="static",
)

@app.middleware("http")
async def _add_404_middleware(request: Request, call_next):
"""Serves static assets on 404"""
response = await call_next(request)
path = request["path"]
if path.startswith(settings.API_PATH) or path.startswith("/docs"):
return response
if response.status_code == 404:
return FileResponse("trapdata/webui/public/index.html")
return response


def setup_cors_middleware(app):
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
allow_credentials=True,
allow_methods=["*"],
expose_headers=["Content-Range", "Range"],
allow_headers=["Authorization", "Range", "Content-Range"],
)


def use_route_names_as_operation_ids(app: FastAPI) -> None:
"""
Simplify operation IDs so that generated API clients have simpler function
names.
Should be called only after all routes have been added.
"""
route_names = set()
for route in app.routes:
if isinstance(route, APIRoute):
if route.name in route_names:
raise Exception("Route function names should be unique")
route.operation_id = route.name
route_names.add(route.name)
20 changes: 20 additions & 0 deletions trapdata/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from trapdata import logger
from trapdata.api.factory import create_app

app = create_app()


def run():
import uvicorn

logger.info("Starting uvicorn in reload mode")
uvicorn.run(
"main:app",
host="0.0.0.0",
reload=True,
port=int("8000"),
)


if __name__ == "__main__":
run()
8 changes: 8 additions & 0 deletions trapdata/api/views/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from fastapi import APIRouter

from trapdata.api.views import deployments, stats

api_router = APIRouter()

api_router.include_router(stats.router, tags=["stats"])
api_router.include_router(deployments.router, tags=["deployments"])
39 changes: 39 additions & 0 deletions trapdata/api/views/deployments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func, orm, select
from starlette.responses import Response

from trapdata.api.config import settings
from trapdata.api.deps.db import get_session
from trapdata.api.deps.request_params import parse_react_admin_params
from trapdata.api.request_params import RequestParams
from trapdata.db import Base
from trapdata.db.models.deployments import DeploymentListItem, list_deployments

router = APIRouter(prefix="/deployments")


@router.get("", response_model=List[DeploymentListItem])
async def get_deployments(
response: Response,
session: orm.Session = Depends(get_session),
# request_params: RequestParams = Depends(parse_react_admin_params(Base)),
) -> Any:
deployments = list_deployments(session)
return deployments


@router.post("/process", response_model=List[DeploymentListItem])
async def process_deployment(
response: Response,
session: orm.Session = Depends(get_session),
# request_params: RequestParams = Depends(parse_react_admin_params(Base)),
) -> Any:
from trapdata.ml.pipeline import start_pipeline

start_pipeline(
session=session, image_base_path=settings.image_base_path, settings=settings
)
deployments = list_deployments(session)
return deployments
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 8 additions & 3 deletions trapdata/api/utils.py → trapdata/api/views/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

from fastapi import APIRouter

from app.schemas.msg import Msg
router = APIRouter(prefix="/stats")

router = APIRouter()

from pydantic import BaseModel


class Msg(BaseModel):
msg: str


@router.get(
"/hello-world",
"/",
response_model=Msg,
status_code=200,
include_in_schema=False,
Expand Down
File renamed without changes.
10 changes: 10 additions & 0 deletions trapdata/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def gui():
run()


@cli.command()
def api():
"""
Launch API server
"""
from trapdata.api.main import run as start_api

start_api()


@cli.command("import")
def import_data(image_base_path: Optional[pathlib.Path] = None, queue: bool = True):
"""
Expand Down
3 changes: 2 additions & 1 deletion trapdata/db/models/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
is used as the deployment name.
"""
import pathlib
from typing import Optional

import sqlalchemy as sa
from pydantic import BaseModel
Expand All @@ -16,7 +17,7 @@


class DeploymentListItem(BaseModel):
# id: int
id: Optional[int] = None
name: str
num_events: int
num_source_images: int
Expand Down
7 changes: 5 additions & 2 deletions trapdata/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import sqlalchemy
from pydantic import BaseSettings, Field, ValidationError, validator
from pydantic.main import ModelMetaclass
from rich import print as rprint

from trapdata import ml
Expand Down Expand Up @@ -199,9 +200,11 @@ def kivy_settings_source(settings: BaseSettings) -> dict[str, str]:


@lru_cache
def read_settings(*args, **kwargs):
def read_settings(
settings_class: ModelMetaclass = Settings, *args, **kwargs
) -> ModelMetaclass:
try:
return Settings(*args, **kwargs)
return settings_class(*args, **kwargs)
except ValidationError as e:
# @TODO the validation errors could be printed in a more helpful way:
rprint(cli_help_message)
Expand Down
1 change: 1 addition & 0 deletions trapdata/webui/public/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:)

0 comments on commit 2b7a372

Please sign in to comment.