-
Notifications
You must be signed in to change notification settings - Fork 30
/
main.py
147 lines (119 loc) · 4.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# import sys, os
# sys.path.insert(
# 0, os.path.abspath("../")
# ) # Adds the parent directory to the system path
import os
import secrets
import traceback
import llm as llm
from utils import getenv, set_env_variables
import json, time
import litellm
from litellm import BudgetManager
litellm.max_budget = 1000
budget_manager = BudgetManager(project_name=os.getenv("PROJECT_NAME"), client_type="hosted")
from fastapi import FastAPI, Request, status, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
master_key = os.getenv("LITELLM_PROXY_MASTER_KEY", "sk-litellm-master-key")
user_api_keys = set(budget_manager.get_users())
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
######## AUTH UTILITIES ################
def user_api_key_auth(api_key: str = Depends(oauth2_scheme)):
if api_key == master_key:
return
if api_key not in user_api_keys:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid user key"},
# TODO: this will be {'detail': {'error': 'something'}}
)
def key_auth(api_key: str = Depends(oauth2_scheme)):
if api_key != master_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid admin key"},
# TODO: this will be {'detail': {'error': 'something'}}
)
######## CHAT COMPLETIONS ################
# for streaming
def data_generator(response):
for chunk in response:
# print(f"chunk: {chunk}")
yield f"data: {json.dumps(chunk)}\n\n"
# for completion
@app.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request):
key = request.headers.get("Authorization").replace("Bearer ", "") # type: ignore
data = await request.json()
print(f"received request data: {data}")
data["user_key"] = key
data["budget_manager"] = budget_manager
data["master_key"] = master_key
set_env_variables(data)
# handle how users send streaming
if 'stream' in data:
if type(data['stream']) == str: # if users send stream as str convert to bool
# convert to bool
if data['stream'].lower() == "true":
data['stream'] = True # convert to boolean
response = llm.completion(**data)
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response
@app.get("/models/available")
def get_available_models():
return {"models": litellm.utils.get_valid_models()}
@app.get("/models") # if project requires model list
def model_list():
available_models = litellm.utils.get_valid_models()
data = []
for model in available_models:
{
"id": model,
"object": model,
"created": int(time.time()),
"owned_by": "openai"
}
return dict(
data=data,
object="list",
)
@app.get("/health")
async def health():
return {"status": "ok"}
######## KEY MANAGEMENT ################
@app.get("/key/cost", dependencies=[Depends(user_api_key_auth)])
async def report_current(request: Request):
key = request.headers.get("Authorization").replace("Bearer ", "") # type: ignore
return budget_manager.get_model_cost(key)
@app.post("/key/new", dependencies=[Depends(key_auth)])
async def generate_key(request: Request):
try:
data = await request.json()
data.get("total_budget")
except:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
total_budget = data["total_budget"]
api_key = f"sk-litellm-{secrets.token_urlsafe(16)}"
try:
budget_manager.create_budget(
total_budget=total_budget, user=api_key, duration="monthly"
)
user_api_keys.add(api_key)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"api_key": api_key, "total_budget": total_budget, "duration": "monthly"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=getenv("PORT", 8080))