1
1
import time
2
+ from datetime import datetime , timedelta
2
3
from pathlib import Path
3
4
5
+ import aiohttp
4
6
import alembic .command
5
7
import alembic .config
6
8
import fastapi
7
9
import sqlmodel
8
- from fastapi import Depends
10
+ from cryptography .hazmat .primitives import hashes
11
+ from cryptography .hazmat .primitives .kdf .hkdf import HKDF
12
+ from fastapi import Depends , HTTPException , Security
9
13
from fastapi .middleware .cors import CORSMiddleware
14
+ from fastapi .security import APIKeyCookie
15
+ from jose import jwe , jwt
10
16
from loguru import logger
11
17
from oasst_inference_server import client_handler , deps , interface , models , worker_handler
12
18
from oasst_inference_server .chat_repository import ChatRepository
13
19
from oasst_inference_server .settings import settings
14
- from oasst_shared .schemas import inference
20
+ from oasst_shared .schemas import inference , protocol
15
21
from prometheus_fastapi_instrumentator import Instrumentator
16
22
17
23
app = fastapi .FastAPI ()
24
+ oauth2_scheme = APIKeyCookie (name = settings .auth_cookie_name )
18
25
19
26
20
27
# add prometheus metrics at /metrics
@@ -48,7 +55,7 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
48
55
root_token = settings .root_token
49
56
if token == root_token :
50
57
return token
51
- raise fastapi . HTTPException (
58
+ raise HTTPException (
52
59
status_code = fastapi .status .HTTP_401_UNAUTHORIZED ,
53
60
detail = "Invalid token" ,
54
61
)
@@ -106,6 +113,74 @@ def maybe_add_debug_api_keys():
106
113
raise
107
114
108
115
116
+ @app .get ("/auth/login/discord" )
117
+ async def login_discord ():
118
+ redirect_uri = f"{ settings .api_root } /auth/callback/discord"
119
+ auth_url = f"https://discord.com/api/oauth2/authorize?client_id={ settings .auth_discord_client_id } &redirect_uri={ redirect_uri } &response_type=code&scope=identify"
120
+ raise HTTPException (status_code = 302 , headers = {"location" : auth_url })
121
+
122
+
123
+ @app .get ("/auth/callback/discord" , response_model = protocol .Token )
124
+ async def callback_discord (
125
+ code : str ,
126
+ db : sqlmodel .Session = Depends (deps .create_session ),
127
+ ):
128
+ redirect_uri = f"{ settings .api_root } /auth/callback/discord"
129
+
130
+ async with aiohttp .ClientSession (raise_for_status = True ) as session :
131
+ # Exchange the auth code for a Discord access token
132
+ async with session .post (
133
+ "https://discord.com/api/oauth2/token" ,
134
+ data = {
135
+ "client_id" : settings .auth_discord_client_id ,
136
+ "client_secret" : settings .auth_discord_client_secret ,
137
+ "grant_type" : "authorization_code" ,
138
+ "code" : code ,
139
+ "redirect_uri" : redirect_uri ,
140
+ "scope" : "identify" ,
141
+ },
142
+ ) as token_response :
143
+ token_response_json = await token_response .json ()
144
+
145
+ try :
146
+ access_token = token_response_json ["access_token" ]
147
+ except KeyError :
148
+ raise HTTPException (status_code = 400 , detail = "Invalid access token response from Discord" )
149
+
150
+ # Retrieve user's Discord information using access token
151
+ async with session .get (
152
+ "https://discord.com/api/users/@me" , headers = {"Authorization" : f"Bearer { access_token } " }
153
+ ) as user_response :
154
+ user_response_json = await user_response .json ()
155
+
156
+ try :
157
+ discord_id = user_response_json ["id" ]
158
+ discord_username = user_response_json ["username" ]
159
+ except KeyError :
160
+ raise HTTPException (status_code = 400 , detail = "Invalid user info response from Discord" )
161
+
162
+ # Try to find a user in our DB linked to the Discord user
163
+ user : models .DbUser = query_user_by_provider_id (db , discord_id = discord_id )
164
+
165
+ # Create if no user exists
166
+ if not user :
167
+ user = models .DbUser (provider = "discord" , provider_account_id = discord_id , display_name = discord_username )
168
+
169
+ db .add (user )
170
+ db .commit ()
171
+ db .refresh (user )
172
+
173
+ # Discord account is authenticated and linked to a user; create JWT
174
+ access_token = create_access_token (
175
+ {"user_id" : user .id },
176
+ settings .auth_secret ,
177
+ settings .auth_algorithm ,
178
+ settings .auth_access_token_expire_minutes ,
179
+ )
180
+
181
+ return protocol .Token (access_token = access_token , token_type = "bearer" )
182
+
183
+
109
184
@app .get ("/chat" )
110
185
async def list_chats (cr : ChatRepository = Depends (deps .create_chat_repository )) -> interface .ListChatsResponse :
111
186
"""Lists all chats."""
@@ -142,13 +217,11 @@ async def get_chat(id: str, cr: ChatRepository = Depends(deps.create_chat_reposi
142
217
@app .put ("/worker" )
143
218
def create_worker (
144
219
request : interface .CreateWorkerRequest ,
145
- root_token : str = fastapi . Depends (get_root_token ),
146
- session : sqlmodel .Session = fastapi . Depends (deps .create_session ),
220
+ root_token : str = Depends (get_root_token ),
221
+ session : sqlmodel .Session = Depends (deps .create_session ),
147
222
):
148
223
"""Allows a client to register a worker."""
149
- worker = models .DbWorker (
150
- name = request .name ,
151
- )
224
+ worker = models .DbWorker (name = request .name )
152
225
session .add (worker )
153
226
session .commit ()
154
227
session .refresh (worker )
@@ -157,8 +230,8 @@ def create_worker(
157
230
158
231
@app .get ("/worker" )
159
232
def list_workers (
160
- root_token : str = fastapi . Depends (get_root_token ),
161
- session : sqlmodel .Session = fastapi . Depends (deps .create_session ),
233
+ root_token : str = Depends (get_root_token ),
234
+ session : sqlmodel .Session = Depends (deps .create_session ),
162
235
):
163
236
"""Lists all workers."""
164
237
workers = session .exec (sqlmodel .select (models .DbWorker )).all ()
@@ -168,11 +241,52 @@ def list_workers(
168
241
@app .delete ("/worker/{worker_id}" )
169
242
def delete_worker (
170
243
worker_id : str ,
171
- root_token : str = fastapi . Depends (get_root_token ),
172
- session : sqlmodel .Session = fastapi . Depends (deps .create_session ),
244
+ root_token : str = Depends (get_root_token ),
245
+ session : sqlmodel .Session = Depends (deps .create_session ),
173
246
):
174
247
"""Deletes a worker."""
175
248
worker = session .get (models .DbWorker , worker_id )
176
249
session .delete (worker )
177
250
session .commit ()
178
251
return fastapi .Response (status_code = 200 )
252
+
253
+
254
+ def query_user_by_provider_id (db : sqlmodel .Session , discord_id : str | None = None ) -> models .DbUser | None :
255
+ """Returns the user associated with a given provider ID if any."""
256
+ user_qry = db .query (models .DbUser )
257
+
258
+ if discord_id :
259
+ user_qry = user_qry .filter (models .DbUser .provider == "discord" ).filter (
260
+ models .DbUser .provider_account_id == discord_id
261
+ )
262
+ # elif other IDs...
263
+ else :
264
+ return None
265
+
266
+ user : models .DbUser = user_qry .first ()
267
+ return user
268
+
269
+
270
+ def create_access_token (data : dict , secret : str , algorithm : str , expire_minutes : int ) -> str :
271
+ """Create encoded JSON Web Token (JWT) using the given data."""
272
+ expires_delta = timedelta (minutes = expire_minutes )
273
+ to_encode = data .copy ()
274
+ expire = datetime .utcnow () + expires_delta
275
+ to_encode .update ({"exp" : expire })
276
+ encoded_jwt = jwt .encode (to_encode , secret , algorithm = algorithm )
277
+ return encoded_jwt
278
+
279
+
280
+ def decode_user_access_token (token : str = Security (oauth2_scheme )) -> dict :
281
+ """Decode the current user JWT token and return the payload."""
282
+ # We first generate a key from the auth secret
283
+ hkdf = HKDF (
284
+ algorithm = hashes .SHA256 (),
285
+ length = settings .auth_length ,
286
+ salt = settings .auth_salt ,
287
+ info = settings .auth_info ,
288
+ )
289
+ key = hkdf .derive (settings .auth_secret )
290
+ # Next we decrypt the JWE token
291
+ payload = jwe .decrypt (token , key )
292
+ return payload
0 commit comments