-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
62 lines (49 loc) · 1.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from typing import Optional, Annotated
from classification_model import predict
import uvicorn
# Template loader for the HTML pages rendered by the endpoints below; the
# directory is a relative path, so it is resolved against the working directory.
templates = Jinja2Templates(directory="templates")

app = FastAPI()
# Serve the ./static directory under the /static URL prefix; the route name
# "static" is what templates use to reference these files.
app.mount(path="/static/", app=StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html``, passing the feature names as ``params``.

    The order of ``feature_names`` matches the order in which
    ``classify_endpoint`` hands the values to ``predict``.
    """
    feature_names = [
        "petal_length",
        "petal_width",
        "sepal_length",
        "sepal_width",
    ]
    context = {"request": request, "params": feature_names}
    return templates.TemplateResponse("index.html", context)
@app.get("/valid-float", response_class=HTMLResponse)
async def validate_floats(request: Request):
    """Validate one query parameter as a finite float and re-render its field.

    The first query parameter is taken as ``(field name, raw value)``, e.g.
    ``/valid-float?petal_length=1.4``.

    Returns:
        ``param-field.html`` rendered with ``error=None`` and the echoed
        ``value`` when the value parses as a finite float; otherwise rendered
        with an error message and no ``value``.

    Raises:
        HTTPException: 400 when the request carries no query parameters
            (previously this surfaced as an unhandled IndexError / 500).
    """
    import math

    from fastapi import HTTPException

    try:
        param, val = next(iter(request.query_params.items()))
    except StopIteration:
        raise HTTPException(
            status_code=400,
            detail="Expected a query parameter to validate",
        ) from None

    # str.isalpha() is not a numeric check: it accepted "", "1.2.3" and "abc1".
    # Parse for real instead. float() also accepts "nan"/"inf", which the old
    # check rejected (they are alphabetic), so keep rejecting non-finite values.
    try:
        is_valid = math.isfinite(float(val))
    except ValueError:
        is_valid = False

    if is_valid:
        return templates.TemplateResponse(
            "param-field.html",
            {"request": request, "param": param, "error": None, "value": val},
        )
    return templates.TemplateResponse(
        "param-field.html",
        {
            "request": request,
            "param": param,
            "error": "This param must be a float64 value",
        },
    )
@app.post("/classify", response_class=HTMLResponse)
async def classify_endpoint(
    request: Request,
    petal_length: Annotated[str, Form()],
    petal_width: Annotated[str, Form()],
    sepal_length: Annotated[str, Form()],
    sepal_width: Annotated[str, Form()],
):
    """Classify the submitted form values and render ``classification.html``.

    The four form fields are passed to ``predict`` as a single list, in the
    order petal_length, petal_width, sepal_length, sepal_width. They arrive as
    raw strings; any numeric conversion is left to ``predict`` (NOTE(review):
    confirm ``classification_model.predict`` accepts strings).
    """
    features = [petal_length, petal_width, sepal_length, sepal_width]
    predicted_class = predict(features)
    context = {"request": request, "class": predicted_class}
    return templates.TemplateResponse("classification.html", context)
if __name__ == "__main__":
    # Development server: listen on all interfaces and auto-reload on changes.
    # reload=True needs the app passed as an import string ("module:attr").
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)