Skip to content

Commit 6a0d976

Browse files
committed
auth refactor
1 parent 6406c1a commit 6a0d976

File tree

3 files changed

+59
-65
lines changed

3 files changed

+59
-65
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,5 @@ restai_prod
181181
lab.py
182182

183183
# pyenv
184-
.python-version
184+
.python-version
185+
.aider*

app/auth.py

Lines changed: 54 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -29,85 +29,77 @@ def get_current_username(
2929
db_wrapper: DBWrapper = Depends(get_db_wrapper)
3030
):
3131
auth_header = request.headers.get('Authorization')
32-
bearer_token = None
3332
credentials = None
33+
3434
if auth_header:
3535
temp_bearer_token = auth_header.split(" ")[1]
36-
if "Bearer" in auth_header:
37-
bearer_token = temp_bearer_token
38-
else:
36+
37+
if "Bearer" in auth_header:
38+
user = db_wrapper.get_user_by_apikey(temp_bearer_token)
39+
40+
if user is None:
41+
raise HTTPException(
42+
status_code=401,
43+
detail="Invalid key"
44+
)
45+
46+
return User.model_validate(user)
47+
elif "Basic" in auth_header:
3948
try:
4049
credentials_b64 = base64.b64decode(temp_bearer_token).decode('utf-8')
4150
username, password = credentials_b64.split(':', 1)
4251
credentials = {
4352
'username': username,
4453
'password': password
4554
}
55+
56+
if RESTAI_AUTH_DISABLE_LOCAL or not credentials or ("username" not in credentials or "password" not in credentials):
57+
raise HTTPException(
58+
status_code=401,
59+
detail="Invalid credentials"
60+
)
61+
62+
user = db_wrapper.get_user_by_username(credentials["username"])
63+
64+
if user is None or user.sso:
65+
raise HTTPException(
66+
status_code=401,
67+
detail="Invalid credentials"
68+
)
69+
70+
is_correct_username = credentials["username"] == user.username
71+
is_correct_password = pwd_context.verify(
72+
credentials["password"], user.hashed_password)
73+
74+
if not (is_correct_username and is_correct_password):
75+
raise HTTPException(
76+
status_code=401,
77+
detail="Incorrect email or password",
78+
headers={"WWW-Authenticate": "Basic"},
79+
)
80+
81+
return User.model_validate(user)
82+
4683
except Exception:
4784
pass
85+
else:
86+
jwt_token = request.cookies.get("restai_token")
4887

49-
jwt_token = request.cookies.get("restai_token")
50-
51-
if bearer_token:
52-
user = db_wrapper.get_user_by_apikey(bearer_token)
88+
if jwt_token:
89+
try:
90+
data = jwt.decode(jwt_token, RESTAI_AUTH_SECRET, algorithms=["HS512"])
5391

54-
if user is None:
55-
raise HTTPException(
56-
status_code=401,
57-
detail="Invalid key"
58-
)
92+
user = db_wrapper.get_user_by_username(data["username"])
5993

60-
return User.model_validate(user)
61-
elif jwt_token:
62-
try:
63-
data = jwt.decode(jwt_token, RESTAI_AUTH_SECRET, algorithms=["HS512"])
94+
return User.model_validate(user)
95+
except Exception:
96+
raise HTTPException(
97+
status_code=401,
98+
detail="Invalid token"
99+
)
64100

65-
user = db_wrapper.get_user_by_username(data["username"])
66101

67-
return User.model_validate(user)
68-
except Exception:
69-
raise HTTPException(
70-
status_code=401,
71-
detail="Invalid token"
72-
)
73-
else:
74-
if RESTAI_AUTH_DISABLE_LOCAL or not credentials or (
75-
"username" not in credentials or "password" not in credentials):
76-
raise HTTPException(
77-
status_code=401,
78-
detail="Invalid credentials"
79-
)
80-
81-
user = db_wrapper.get_user_by_username(credentials["username"])
82-
83-
if user is None:
84-
raise HTTPException(
85-
status_code=401,
86-
detail="Invalid credentials"
87-
)
88-
89-
if user.sso:
90-
raise HTTPException(
91-
status_code=401,
92-
detail="SSO user"
93-
)
94-
95-
if user is not None:
96-
is_correct_username = credentials["username"] == user.username
97-
is_correct_password = pwd_context.verify(
98-
credentials["password"], user.hashed_password)
99-
else:
100-
is_correct_username = False
101-
is_correct_password = False
102-
103-
if not (is_correct_username and is_correct_password):
104-
raise HTTPException(
105-
status_code=401,
106-
detail="Incorrect email or password",
107-
headers={"WWW-Authenticate": "Basic"},
108-
)
109-
110-
return User.model_validate(user)
102+
111103

112104

113105
def get_current_username_admin(user: User = Depends(get_current_username)):

app/models/databasemodels.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ class ProjectDatabase(Base):
4343
human_name = Column(String(255))
4444
human_description = Column(Text)
4545
tools = Column(Text)
46+
creator = Column(Integer)
4647
public = Column(Boolean, default=False)
4748
default_prompt = Column(Text)
48-
creator = Column(Integer, ForeignKey("users.id"))
49+
owner = Column(Integer, ForeignKey("users.id"))
4950
users = relationship('UserDatabase', secondary=users_projects, back_populates='projects')
5051
entrances = relationship("RouterEntrancesDatabase", back_populates="project")
5152

@@ -109,4 +110,4 @@ class EmbeddingDatabase(Base):
109110
options = Column(Text)
110111
privacy = Column(String(255))
111112
description = Column(Text)
112-
dimension = Column(Integer, default=1536)
113+
dimension = Column(Integer, default=1536)

0 commit comments

Comments
 (0)