Skip to content

Commit

Permalink
added 2 funcs for predict
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILIPP111007 committed Apr 12, 2024
1 parent 13e7956 commit aea4b89
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 30 deletions.
107 changes: 78 additions & 29 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from model import Model
from predict import predict
from config import CONFIG


from schema import (
InferenceResponse,
InferenceInput,
InferenceOutput,
ErrorResponse,
ModelDayInput,
)


Expand Down Expand Up @@ -55,41 +58,41 @@ async def lifespan(app: FastAPI):
)


@app.post(
"/api/v1/predict",
response_model=InferenceResponse,
responses={422: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
def do_predict(request: Request, body: InferenceInput):
"""
Perform prediction on input data
"""
# @app.post(
# "/api/v1/predict_day",
# response_model=InferenceResponse,
# responses={422: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
# )
# def do_predict_day(request: Request, body: InferenceInput):
# """
# Perform prediction on input data
# """

logger.info("API predict called")
logger.info(f"input: {body}")
# logger.info("API predict called")
# logger.info(f"input: {body}")

# prepare input data
X = body
# # prepare input data
# X = body

# run model inference
y = predict(app.package, [X])
# generate prediction based on probablity
pred = ["setosa", "versicolor", "virginica"][y.index(max(y))]
# # run model inference
# y = predict(app.package, [X])
# # generate prediction based on probablity
# pred = ["setosa", "versicolor", "virginica"][y.index(max(y))]

# round probablities for json
y = list(map(lambda v: round(v, CONFIG["ROUND_DIGIT"]), y))
# # round probablities for json
# y = list(map(lambda v: round(v, CONFIG["ROUND_DIGIT"]), y))

# prepare json for returning
logger.info(f"results: {y}")
# # prepare json for returning
# logger.info(f"results: {y}")

return InferenceResponse(
data=InferenceOutput(
predicted_value=y[0],
predicted_confidence_interval_lower_bound=y[1],
predicted_confidence_interval_upper_bound=y[2],
text=pred,
)
)
# return InferenceResponse(
# data=InferenceOutput(
# predicted_value=y[0],
# predicted_confidence_interval_lower_bound=y[1],
# predicted_confidence_interval_upper_bound=y[2],
# text=pred,
# )
# )


@app.get("/api/v1/about")
Expand All @@ -112,6 +115,52 @@ def bash(command):
app.include_router(api_test, prefix="/api/v1", tags=["tests"])
#############################################


@app.post("/api/v1/predict_interval")
def do_predict_interval():
"""
TODO: Perform prediction on input data
"""
from random import uniform

return {
"data": {
"dates": [
"2020-01-01",
"2020-01-02",
"2020-01-03",
"2020-01-04",
"2020-01-05",
"2020-01-06",
],
"x": [1, 2, 3, 4, 5, 6],
"y_pred": [
1.8,
2.6,
uniform(1.3, 2.2),
uniform(2.2, 3.8),
uniform(3.3, 3.5),
uniform(1.0, 1.3),
],
"y_true": [2, 3],
}
}


@app.post(
"/api/v1/predict_day",
responses={422: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
def do_predict_day(data: ModelDayInput):
"""
TODO: Perform prediction on input data
"""
from random import uniform

logger.info("API predict called")
return {"data": uniform(1.0, 100.0)}


if __name__ == "__main__":
uvicorn.run(
"main:app", host="0.0.0.0", port=8080, reload=True, log_config="log.ini"
Expand Down
46 changes: 45 additions & 1 deletion app/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from datetime import datetime
from datetime import datetime, date


class NewsData(BaseModel):
Expand Down Expand Up @@ -50,5 +50,49 @@ class InferenceResponse(BaseModel):
data: InferenceOutput


########
########
########
################
########
########
################
########
########
################
########
########
########


class TestInput(BaseModel):
"""Test schema"""

date: str


class ModelDayInput(BaseModel):
# Necessary
CAPITALIZATION: float
CLOSE: float
DIVISOR: float
HIGH: float
LOW: float
OPEN: float
TRADEDATE: date
finance: str
economic: str
politic: str

# Optional
NAME: str
SHORTNAME: str
SECID: str
BOARDID: str
DURATION: str
YIELD: str
DECIMALS: str
CURRENCYID: str
VOLUME: str
TRADINGSESSION: str
VALUE: str

0 comments on commit aea4b89

Please sign in to comment.