Skip to content

Commit c46dc9d

Browse files
Added optimizers and losses endpoint.
1 parent 89bb5ea commit c46dc9d

File tree

5 files changed

+120
-109
lines changed

5 files changed

+120
-109
lines changed

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.13.0
1+
3.12.0

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ dependencies = [
88
"boto3~=1.33.11",
99
"fastapi~=0.103.1",
1010
"mlflow>=2.20.2",
11-
"numpy>=2.2.3",
1211
"pandas>=2.2.3",
1312
"pydantic~=1.10.13",
1413
"python-dotenv~=1.0.0",

requirements.txt

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/mlflow_api/main.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import pandas as pd
22
import uvicorn
33
from fastapi import FastAPI, Response, UploadFile
4+
from fastapi.exception_handlers import HTTPException
45
from starlette.responses import JSONResponse
56
from fastapi.middleware.cors import CORSMiddleware
67
from mlflow_api.mlflow_client import Client
78
from pydantic import BaseModel
89
from mlflow_api.schemas import Models, Parameters, Metrics, Dataset, Images, Versions
910
from dotenv import load_dotenv
11+
import torch.optim as optim
12+
import torch.nn as nn
1013

1114
load_dotenv()
1215

@@ -105,6 +108,53 @@ async def model_package(name: str):
105108
)
106109

107110

111+
@app.get("/optimizers/{framework}")
112+
async def optimizers(framework: str):
113+
if framework not in ["torch", "keras"]:
114+
raise HTTPException(400, "Allowed frameworks: ['torch', 'keras']")
115+
116+
if framework == "torch":
117+
opt = [op for op in dir(optim) if "_" not in op]
118+
return JSONResponse(opt)
119+
else:
120+
return JSONResponse([
121+
"SGD",
122+
"RMSprop",
123+
"Adagrad",
124+
"Adadelta",
125+
"Adam",
126+
"Adamax",
127+
"Nadam",
128+
"Ftrl"
129+
])
130+
131+
132+
@app.get("/losses/{framework}")
133+
async def losses(framework: str):
134+
if framework not in ["torch", "keras"]:
135+
raise HTTPException(400, "Allowed frameworks: ['torch', 'keras']")
136+
137+
if framework == "torch":
138+
return JSONResponse(["L1Loss", "MSELoss", "CrossEntropyLoss", "CTCLoss", "NLLLoss", "PoissonNLLLoss", "GaussianNLLLoss", "KLDivLoss", "BCELoss", "BCEWithLogitsLoss", "MarginRankingLoss", "HingeEmbeddingLoss", "MultiLabelMarginLoss", "HuberLoss", "SmoothL1Loss", "SoftMarginLoss", "MultiLabelSoftMarginLoss", "CosineEmbeddingLoss", "MultiMarginLoss", "TripletMarginLoss", "TripletMarginWithDistanceLoss"])
139+
else:
140+
return JSONResponse([
141+
"mean_squared_error",
142+
"mean_absolute_error",
143+
"mean_absolute_percentage_error",
144+
"mean_squared_logarithmic_error",
145+
"categorical_crossentropy",
146+
"sparse_categorical_crossentropy",
147+
"binary_crossentropy",
148+
"hinge",
149+
"squared_hinge",
150+
"categorical_hinge",
151+
"logcosh",
152+
"kullback_leibler_divergence",
153+
"poisson",
154+
"cosine_similarity"
155+
])
156+
157+
108158
@app.post("/model/predict")
109159
async def model_predict(name: str, file: UploadFile):
110160
df = pd.read_csv(file.file)

0 commit comments

Comments
 (0)