Skip to content

Commit 0b9f4ac

Browse files
olliestanley and yk authored
Introduce users and Discord auth to inference server (#1772)
Resolves #1471. The idea here is to introduce a flexible system which allows arbitrarily many authentication methods to be added in the future, starting with Discord. We use a Discord OAuth2 flow to log in a user (or register a new user), then issue a JSON Web Token (JWT) which can be decoded to obtain a user ID directly. Therefore, any new auth method can follow the same pattern, and auth can be handled entirely on the backend. This additionally includes a separate Redis instance for the inference stack, resolves the pre-commit failure in `main`, and adds missing dependencies for the inference server.
- Ideally we would use `fastapi-discord`, which would handle the Discord side for us; however, there is no version of it which is compatible with FastAPI versions above `0.84.0`, and we depend on FastAPI `0.88.0`.
Some code relating to creating/decoding JWTs is duplicated from existing backend code and may be possible to move to `oasst_shared`. We also really need to clean up `main.py` in inference. After this initial inclusion of users to the server is merged, we will require two follow-on issues:
- Modify DB schemas and code to associate client actions with DB users
- Lock relevant API calls behind authentication by validating provided JWTs
The new settings will also need configuring in the environments for the server to operate properly.
--------- Co-authored-by: Yannic Kilcher <[email protected]>
1 parent e4b0c84 commit 0b9f4ac

File tree

7 files changed

+219
-25
lines changed

7 files changed

+219
-25
lines changed

docker-compose.yaml

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,21 @@ services:
153153
retries: 10
154154
profiles: ["inference"]
155155

156+
inference-redis:
157+
image: redis
158+
restart: always
159+
profiles: ["inference"]
160+
ports:
161+
- 6389:6379
162+
healthcheck:
163+
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
164+
interval: 2s
165+
timeout: 2s
166+
retries: 10
167+
command: redis-server /usr/local/etc/redis/redis.conf
168+
volumes:
169+
- ./redis.conf:/usr/local/etc/redis/redis.conf
170+
156171
inference-server:
157172
build:
158173
dockerfile: docker/inference/Dockerfile.server
@@ -161,7 +176,7 @@ services:
161176
image: oasst-inference-server:dev
162177
environment:
163178
PORT: 8000
164-
REDIS_HOST: redis
179+
REDIS_HOST: inference-redis
165180
POSTGRES_HOST: inference-db
166181
POSTGRES_DB: oasst_inference
167182
DEBUG_API_KEYS: '["0000"]'
@@ -172,7 +187,7 @@ services:
172187
ports:
173188
- "8000:8000"
174189
depends_on:
175-
redis:
190+
inference-redis:
176191
condition: service_healthy
177192
inference-db:
178193
condition: service_healthy
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""added users table
2+
3+
Revision ID: b365a18db6fd
4+
Revises: 4bead1c4cf52
5+
Create Date: 2023-02-21 22:05:52.620014
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "b365a18db6fd"
14+
down_revision = "4bead1c4cf52"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
    """Create the ``user`` table and its indexes.

    The table stores one row per authenticated account: an ``id`` primary key,
    the auth ``provider``, the ``provider_account_id`` issued by that provider
    and a ``display_name`` of at most 256 characters.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "user",
        sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_user_provider"), "user", ["provider"], unique=False)
    op.create_index(op.f("ix_user_provider_account_id"), "user", ["provider_account_id"], unique=False)
    # The uniqueness constraint must cover the (provider, provider_account_id)
    # pair. Making provider_account_id unique on its own would reject a second
    # user whose account id merely coincides with one issued by a different
    # provider. The index keeps the name "provider" so that downgrade() can
    # still drop it by that name.
    op.create_index("provider", "user", ["provider", "provider_account_id"], unique=True)
    # ### end Alembic commands ###
33+
34+
35+
def downgrade() -> None:
    """Drop the ``user`` table together with the three indexes defined on it."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Indexes are dropped first (unique index, then the two per-column
    # indexes), and the table itself last.
    index_names = (
        "provider",
        op.f("ix_user_provider_account_id"),
        op.f("ix_user_provider"),
    )
    for index_name in index_names:
        op.drop_index(index_name, table_name="user")

    op.drop_table("user")
    # ### end Alembic commands ###

inference/server/main.py

Lines changed: 126 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
import time
2+
from datetime import datetime, timedelta
23
from pathlib import Path
34

5+
import aiohttp
46
import alembic.command
57
import alembic.config
68
import fastapi
79
import sqlmodel
8-
from fastapi import Depends
10+
from cryptography.hazmat.primitives import hashes
11+
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
12+
from fastapi import Depends, HTTPException, Security
913
from fastapi.middleware.cors import CORSMiddleware
14+
from fastapi.security import APIKeyCookie
15+
from jose import jwe, jwt
1016
from loguru import logger
1117
from oasst_inference_server import client_handler, deps, interface, models, worker_handler
1218
from oasst_inference_server.chat_repository import ChatRepository
1319
from oasst_inference_server.settings import settings
14-
from oasst_shared.schemas import inference
20+
from oasst_shared.schemas import inference, protocol
1521
from prometheus_fastapi_instrumentator import Instrumentator
1622

1723
app = fastapi.FastAPI()
24+
oauth2_scheme = APIKeyCookie(name=settings.auth_cookie_name)
1825

1926

2027
# add prometheus metrics at /metrics
@@ -48,7 +55,7 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
4855
root_token = settings.root_token
4956
if token == root_token:
5057
return token
51-
raise fastapi.HTTPException(
58+
raise HTTPException(
5259
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
5360
detail="Invalid token",
5461
)
@@ -106,6 +113,74 @@ def maybe_add_debug_api_keys():
106113
raise
107114

108115

116+
@app.get("/auth/login/discord")
async def login_discord():
    """Redirect the caller to Discord's OAuth2 authorization page.

    The authorization URL asks for the ``identify`` scope and tells Discord to
    send the user back to this server's ``/auth/callback/discord`` endpoint.

    Raises:
        HTTPException: always; the redirect is delivered as a 302 response
            whose ``location`` header is the Discord authorization URL.
    """
    from urllib.parse import urlencode

    redirect_uri = f"{settings.api_root}/auth/callback/discord"

    # Percent-encode every query value. The redirect URI contains reserved
    # characters ("://" and "/") that must not appear raw inside a query
    # string, and the client id comes from configuration.
    query = urlencode(
        {
            "client_id": settings.auth_discord_client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": "identify",
        }
    )
    auth_url = f"https://discord.com/api/oauth2/authorize?{query}"

    # NOTE(review): no OAuth2 `state` parameter is sent, so the callback cannot
    # verify that the flow was started here — consider adding one.
    raise HTTPException(status_code=302, headers={"location": auth_url})
121+
122+
123+
@app.get("/auth/callback/discord", response_model=protocol.Token)
async def callback_discord(
    code: str,
    db: sqlmodel.Session = Depends(deps.create_session),
):
    """Complete the Discord OAuth2 flow and issue an access token for the user.

    Exchanges the authorization ``code`` for a Discord access token, fetches
    the Discord profile, finds or creates the matching ``DbUser`` row and
    returns a JWT carrying that user's ID.

    Args:
        code: Authorization code Discord appended to the callback URL.
        db: Database session used to look up or create the user.

    Returns:
        A ``protocol.Token`` holding the JWT and the token type ``"bearer"``.

    Raises:
        HTTPException: 400 if Discord answers with an error status or with a
            response lacking the expected fields; 502 if Discord cannot be
            reached.
    """
    redirect_uri = f"{settings.api_root}/auth/callback/discord"

    try:
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            # Exchange the auth code for a Discord access token
            async with session.post(
                "https://discord.com/api/oauth2/token",
                data={
                    "client_id": settings.auth_discord_client_id,
                    "client_secret": settings.auth_discord_client_secret,
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": redirect_uri,
                    "scope": "identify",
                },
            ) as token_response:
                token_response_json = await token_response.json()

            try:
                discord_access_token = token_response_json["access_token"]
            except KeyError:
                raise HTTPException(status_code=400, detail="Invalid access token response from Discord")

            # Retrieve user's Discord information using access token
            async with session.get(
                "https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {discord_access_token}"}
            ) as user_response:
                user_response_json = await user_response.json()
    except aiohttp.ClientResponseError as err:
        # raise_for_status=True turns any 4xx/5xx answer (for example an
        # expired or reused code) into an exception; report it as a client
        # error instead of letting it surface as a 500.
        raise HTTPException(status_code=400, detail="Authentication request to Discord failed") from err
    except aiohttp.ClientError as err:
        # Connection-level failures: Discord could not be reached at all.
        raise HTTPException(status_code=502, detail="Authentication request to Discord failed") from err

    try:
        discord_id = user_response_json["id"]
        discord_username = user_response_json["username"]
    except KeyError:
        raise HTTPException(status_code=400, detail="Invalid user info response from Discord")

    # Try to find a user in our DB linked to the Discord user
    user: models.DbUser = query_user_by_provider_id(db, discord_id=discord_id)

    # Create if no user exists
    if not user:
        user = models.DbUser(provider="discord", provider_account_id=discord_id, display_name=discord_username)

        db.add(user)
        db.commit()
        db.refresh(user)

    # Discord account is authenticated and linked to a user; create JWT.
    # Kept in its own variable so it is never confused with Discord's token.
    user_access_token = create_access_token(
        {"user_id": user.id},
        settings.auth_secret,
        settings.auth_algorithm,
        settings.auth_access_token_expire_minutes,
    )

    return protocol.Token(access_token=user_access_token, token_type="bearer")
182+
183+
109184
@app.get("/chat")
110185
async def list_chats(cr: ChatRepository = Depends(deps.create_chat_repository)) -> interface.ListChatsResponse:
111186
"""Lists all chats."""
@@ -142,13 +217,11 @@ async def get_chat(id: str, cr: ChatRepository = Depends(deps.create_chat_reposi
142217
@app.put("/worker")
143218
def create_worker(
144219
request: interface.CreateWorkerRequest,
145-
root_token: str = fastapi.Depends(get_root_token),
146-
session: sqlmodel.Session = fastapi.Depends(deps.create_session),
220+
root_token: str = Depends(get_root_token),
221+
session: sqlmodel.Session = Depends(deps.create_session),
147222
):
148223
"""Allows a client to register a worker."""
149-
worker = models.DbWorker(
150-
name=request.name,
151-
)
224+
worker = models.DbWorker(name=request.name)
152225
session.add(worker)
153226
session.commit()
154227
session.refresh(worker)
@@ -157,8 +230,8 @@ def create_worker(
157230

158231
@app.get("/worker")
159232
def list_workers(
160-
root_token: str = fastapi.Depends(get_root_token),
161-
session: sqlmodel.Session = fastapi.Depends(deps.create_session),
233+
root_token: str = Depends(get_root_token),
234+
session: sqlmodel.Session = Depends(deps.create_session),
162235
):
163236
"""Lists all workers."""
164237
workers = session.exec(sqlmodel.select(models.DbWorker)).all()
@@ -168,11 +241,52 @@ def list_workers(
168241
@app.delete("/worker/{worker_id}")
def delete_worker(
    worker_id: str,
    root_token: str = Depends(get_root_token),
    session: sqlmodel.Session = Depends(deps.create_session),
):
    """Deletes a worker.

    Args:
        worker_id: Primary key of the worker to remove.
        root_token: Resolved by ``get_root_token``; its presence restricts the
            endpoint to callers holding the root token.
        session: Database session used for the lookup and the delete.

    Returns:
        An empty 200 response once the worker has been deleted.

    Raises:
        HTTPException: 404 if no worker with ``worker_id`` exists.
    """
    worker = session.get(models.DbWorker, worker_id)
    # session.get() yields None for an unknown key; passing None on to
    # session.delete() would raise and surface as a 500.
    if worker is None:
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_404_NOT_FOUND,
            detail="Worker not found",
        )
    session.delete(worker)
    session.commit()
    return fastapi.Response(status_code=200)
252+
253+
254+
def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = None) -> models.DbUser | None:
    """Look up the user linked to an external provider account.

    Args:
        db: Session the query is issued on.
        discord_id: Discord account ID to search for.

    Returns:
        The first matching user, or ``None`` when no provider ID was supplied
        or no user matches.
    """
    base_query = db.query(models.DbUser)

    # Discord is the only provider handled so far; further providers would add
    # their own ID parameter and branch here.
    if not discord_id:
        return None

    discord_matches = base_query.filter(models.DbUser.provider == "discord").filter(
        models.DbUser.provider_account_id == discord_id
    )
    return discord_matches.first()
268+
269+
270+
def create_access_token(data: dict, secret: str, algorithm: str, expire_minutes: int) -> str:
    """Build a signed JSON Web Token (JWT) from ``data`` plus an expiry claim.

    Args:
        data: Claims to embed in the token; the mapping is not modified.
        secret: Key handed to ``jwt.encode`` for signing.
        algorithm: Signing algorithm name handed to ``jwt.encode``.
        expire_minutes: Token lifetime, counted from the current UTC time.

    Returns:
        The encoded token.
    """
    expires_at = datetime.utcnow() + timedelta(minutes=expire_minutes)
    # A fresh dict keeps the caller's mapping untouched; "exp" overrides any
    # expiry the caller may have supplied.
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, secret, algorithm=algorithm)
278+
279+
280+
def decode_user_access_token(token: str = Security(oauth2_scheme)) -> dict:
    """Decrypt the current user's JWE token and return its payload.

    Args:
        token: Encrypted token, read from the auth cookie by ``oauth2_scheme``.

    Returns:
        The decrypted token payload parsed from JSON into a dict.

    Raises:
        HTTPException: 401 if the token cannot be decrypted or its payload is
            not a JSON object.
    """
    import json

    from jose.exceptions import JOSEError

    # HKDF.derive() accepts bytes only, whereas the secret is configured as a
    # string; encode it before use.
    secret = settings.auth_secret
    if isinstance(secret, str):
        secret = secret.encode()

    # We first generate a key from the auth secret
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=settings.auth_length,
        salt=settings.auth_salt,
        info=settings.auth_info,
    )
    key = hkdf.derive(secret)

    # Next we decrypt the JWE token. jwe.decrypt() hands back the raw
    # plaintext, so it still has to be parsed into the promised dict.
    # NOTE(review): create_access_token in this module issues tokens with
    # jwt.encode (signed, not encrypted); jwe.decrypt will not accept those —
    # confirm which kind of token this cookie is expected to carry.
    try:
        payload = json.loads(jwe.decrypt(token, key))
    except (JOSEError, TypeError, ValueError) as err:
        raise HTTPException(
            status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        ) from err

    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )
    return payload

inference/server/oasst_inference_server/models.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sqlalchemy.dialects.postgresql as pg
77
from oasst_inference_server import interface
88
from oasst_shared.schemas import inference
9-
from sqlmodel import Field, Relationship, SQLModel
9+
from sqlmodel import Field, Index, Relationship, SQLModel
1010

1111

1212
class DbMessage(SQLModel, table=True):
@@ -81,5 +81,16 @@ class DbWorker(SQLModel, table=True):
8181

8282
in_compliance_check: bool = Field(default=False, sa_column=sa.Column(sa.Boolean, server_default=sa.text("false")))
8383
next_compliance_check: datetime.datetime | None = Field(None)
84-
8584
events: list[DbWorkerEvent] = Relationship(back_populates="worker")
85+
86+
87+
class DbUser(SQLModel, table=True):
    """A user of the inference server, identified by an external auth provider account."""

    __tablename__ = "user"
    # Index(name, *columns): the first positional argument is the index NAME,
    # so both columns must follow it for the uniqueness constraint to cover
    # the (provider, provider_account_id) pair rather than the account id
    # alone. The name "provider" is kept as the existing index name.
    # NOTE(review): the Alembic migration creating this index must define the
    # same column list.
    __table_args__ = (Index("provider", "provider", "provider_account_id", unique=True),)

    # Random UUID4 string generated on construction.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)

    # Name of the auth provider (e.g. "discord") and the account ID it issued.
    provider: str = Field(..., index=True)
    provider_account_id: str = Field(..., index=True)

    display_name: str = Field(nullable=False, max_length=256)

inference/server/oasst_inference_server/settings.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,20 @@ def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
4343
do_compliance_checks: bool = True
4444
compliance_check_interval: int = 60
4545

46+
api_root: str = "https://inference.prod.open-assistant.io"
47+
48+
use_auth: bool = True
49+
50+
auth_info: bytes = b"NextAuth.js Generated Encryption Key"
51+
auth_salt: bytes = b""
52+
auth_length: int = 32
53+
auth_secret: str = ""
54+
auth_algorithm: str = "HS256"
55+
auth_access_token_expire_minutes: int = 60
56+
auth_cookie_name: str = "temp"
57+
58+
auth_discord_client_id: str = ""
59+
auth_discord_client_secret: str = ""
60+
4661

4762
settings = Settings()

inference/server/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
aiohttp
12
alembic
3+
cryptography==39.0.0
24
fastapi[all]==0.88.0
35
loguru
46
nvidia-ml-py
57
prometheus-fastapi-instrumentator
68
psutil
79
psycopg2-binary
810
pydantic
11+
pynvml
12+
python-jose[cryptography]==3.3.0
913
redis
1014
sqlmodel
1115
sse-starlette

oasst-shared/oasst_shared/schemas/inference.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def __init__(self, **data):
6464

6565
class WorkerConfig(pydantic.BaseModel):
6666
model_name: str = DEFAULT_MODEL_NAME
67-
hardware_info: WorkerHardwareInfo = pydantic.Field(
68-
default_factory=WorkerHardwareInfo
69-
)
67+
hardware_info: WorkerHardwareInfo = pydantic.Field(default_factory=WorkerHardwareInfo)
7068

7169
@property
7270
def compat_hash(self) -> str:
@@ -81,9 +79,7 @@ class WorkParameters(pydantic.BaseModel):
8179
top_p: float = 0.9
8280
temperature: float = 1.0
8381
repetition_penalty: float | None = None
84-
seed: int = pydantic.Field(
85-
default_factory=lambda: random.randint(0, 0xFFFF_FFFF_FFFF_FFFF - 1)
86-
)
82+
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 0xFFFF_FFFF_FFFF_FFFF - 1))
8783

8884

8985
class MessageState(str, enum.Enum):
@@ -111,9 +107,7 @@ class Thread(pydantic.BaseModel):
111107

112108
class WorkRequest(pydantic.BaseModel):
113109
thread: Thread = pydantic.Field(..., repr=False)
114-
created_at: datetime.datetime = pydantic.Field(
115-
default_factory=datetime.datetime.utcnow
116-
)
110+
created_at: datetime.datetime = pydantic.Field(default_factory=datetime.datetime.utcnow)
117111
parameters: WorkParameters = pydantic.Field(default_factory=WorkParameters)
118112

119113

0 commit comments

Comments
 (0)