Skip to content

Commit

Permalink
Make database calls asynchronous
Browse files Browse the repository at this point in the history
  • Loading branch information
deanOcoin committed Aug 23, 2023
1 parent dca3bb2 commit ffe938f
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 58 deletions.
8 changes: 4 additions & 4 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

@app.post("/authenticate")
async def authenticate(data: UserCredentialsRequest):
u = db.user_manager.get_by_email(data.email)
u = await db.user_manager.get_by_email(data.email)

if not u:
return {"authenticated": False}
Expand All @@ -61,18 +61,18 @@ async def authenticate(data: UserCredentialsRequest):
response = JSONResponse(status_code=200, content={"authenticated": True})

response.set_cookie(key="authenticate", value=to_jwt(str(u["id"])), path="/")
db.user_manager.update_column(str(u["id"]), "last_signed_in", str(get_current_isodate()))
await db.user_manager.update_column(str(u["id"]), "last_signed_in", str(get_current_isodate()))
return response

return {"authenticated": False}


@app.get("/")
async def root(request: Request):
if not db.is_authenticated(request):
if not await db.is_authenticated(request):
return JSONResponse(status_code=200, content={"apiVersion": 1.1, "user":None})

udata = clean_udata(json.loads(db.user_manager.get_data_by_id(from_jwt(str(request.cookies.get("authenticate"))))))
udata = clean_udata(json.loads(await db.user_manager.get_data_by_id(from_jwt(str(request.cookies.get("authenticate"))))))

return JSONResponse(status_code=200, content=(udata)) # add API version to response content

Expand Down
4 changes: 2 additions & 2 deletions backend/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
async def auth_dependency(request: Request) -> Union[bool, dict]:
user_id = from_jwt(str(request.cookies.get("authenticate")))

if not db.is_authenticated(request):
if not await db.is_authenticated(request):
raise HTTPException(status_code=401, detail="Not Authenticated")

return json.loads(db.user_manager.get_data_by_id(user_id))
return json.loads(await db.user_manager.get_data_by_id(user_id))

async def require_access_flag(flag: str):
# If in testing environment, just skip this because the meta server isn't run in test
Expand Down
4 changes: 2 additions & 2 deletions backend/gptroutes/ai_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def generate_summary(
):
sentences = ""

note = db.note_manager.get_by_id(request, id)
note = await db.note_manager.get_by_id(request, id)
blocks = note["blocks"]
for b in blocks:
if b["type"] == "text" or b["type"] == "header":
Expand All @@ -81,7 +81,7 @@ async def generate_summary(
async def generate_quiz(request: Request, id:str, n:int, is_auth: Union[bool, dict] = Depends(auth_dependency)):
sentences = ""

note = db.note_manager.get_by_id(request, id)
note = await db.note_manager.get_by_id(request, id)
blocks = note["blocks"]
for b in blocks:
if b["type"] == "text" or b["type"] == "header":
Expand Down
38 changes: 19 additions & 19 deletions backend/noterdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_session(self):
session.commit()

class UserManager(BaseManager):
def get_data_by_id(self, id: str):
async def get_data_by_id(self, id: str):
with self.get_session() as session:
user = session.query(User).filter(User.id == id).first()

Expand All @@ -49,7 +49,7 @@ def get_data_by_id(self, id: str):

return False

def insert(self, user: dict):
async def insert(self, user: dict):
with self.get_session() as session:
user_obj = User(
id=user.get('id'),
Expand All @@ -68,7 +68,7 @@ def insert(self, user: dict):
session.add(user_obj)
session.commit()

def get_by_email(self, email: str):
async def get_by_email(self, email: str):
with self.get_session() as session:
user = session.query(User).filter(User.email == email).first()

Expand All @@ -78,21 +78,21 @@ def get_by_email(self, email: str):
"joined_on": str(user.joined_on), "history": user.history, "email_verified": user.email_verified,
"has_noter_access": user.has_noter_access, "verification_code": user.verification_code}

def get_notes(self, request: Request):
async def get_notes(self, request: Request):
user_id = from_jwt(str(request.cookies.get("authenticate")))
with self.get_session() as session:
notes = session.query(Note).filter(Note.owner_id == user_id).all()
return [{"id": note.id, "type": note.type, "name": note.name, "path": note.path,
"last_edited": str(note.last_edited), "created_on": str(note.created_on), "blocks": note.blocks} for note in notes]

def get_folders(self, request: Request):
async def get_folders(self, request: Request):
user_id = from_jwt(str(request.cookies.get("authenticate")))
with self.get_session() as session:
folders = session.query(Folder).filter(Folder.owner_id == user_id).all()
return [{"id": folder.id, "type": "folder", "name": folder.name, "path": folder.path,
"last_edited": str(folder.last_edited), "created_on": str(folder.created_on)} for folder in folders]

def update_column(self, user_id, column_name, column_value):
async def update_column(self, user_id, column_name, column_value):
if not is_valid_uuid4(user_id): return False

with self.get_session() as session:
Expand All @@ -103,7 +103,7 @@ def update_column(self, user_id, column_name, column_value):
return True
return False

def delete(self, user_id: str):
async def delete(self, user_id: str):
with self.get_session() as session:
try:
# Remove items with foreign key relationships first
Expand All @@ -119,7 +119,7 @@ def delete(self, user_id: str):


class NoteManager(BaseManager):
def insert(self, note: dict):
async def insert(self, note: dict):
with self.get_session() as session:
note_obj = Note(
id=note.get('id'),
Expand All @@ -134,7 +134,7 @@ def insert(self, note: dict):
session.add(note_obj)
session.commit()

def update_blocks_by_id(self, request: Request, id: str, new_blocks: str):
async def update_blocks_by_id(self, request: Request, id: str, new_blocks: str):
user_id = from_jwt(str(request.cookies.get("authenticate")))
with self.get_session() as session:
notes = session.query(Note).filter(Note.id == id, Note.owner_id == user_id).all()
Expand All @@ -145,7 +145,7 @@ def update_blocks_by_id(self, request: Request, id: str, new_blocks: str):

session.commit()

def get_by_id(self, request: Request, id:str):
async def get_by_id(self, request: Request, id:str):
user_id = from_jwt(str(request.cookies.get("authenticate")))

with self.get_session() as session:
Expand All @@ -167,7 +167,7 @@ def get_by_id(self, request: Request, id:str):


class FolderManager(BaseManager):
def does_path_exist(self, request: Request, fullpath: list):
async def does_path_exist(self, request: Request, fullpath: list):
if len(fullpath) == 0:
return True

Expand All @@ -181,7 +181,7 @@ def does_path_exist(self, request: Request, fullpath: list):

return False

def insert(self, folder: dict):
async def insert(self, folder: dict):
with self.get_session() as session:
folder_obj = Folder(
id=folder.get('id'),
Expand Down Expand Up @@ -233,14 +233,14 @@ def connect(self): # Returns True if connection to database is successful
return False


def is_authenticated(self, request: Request) -> bool:
async def is_authenticated(self, request: Request) -> bool:
user_id = from_jwt(str(request.cookies.get("authenticate")))
with self.get_session() as session:
user = session.query(User).filter(User.id == user_id).first()
return user is not None


def get_item(self, request: Request, id: str):
async def get_item(self, request: Request, id: str):
user_id = from_jwt(str(request.cookies.get("authenticate")))

with self.get_session() as session:
Expand Down Expand Up @@ -273,7 +273,7 @@ def get_item(self, request: Request, id: str):
return False


def delete_item_by_id(self, request: Request, id: str):
async def delete_item_by_id(self, request: Request, id: str):
user_id = from_jwt(str(request.cookies.get("authenticate")))

with self.get_session() as session:
Expand All @@ -286,7 +286,7 @@ def delete_item_by_id(self, request: Request, id: str):
session.commit()


def update_folder(self, folder: Folder, old_path: list, new_path: list, new_name: str):
async def update_folder(self, folder: Folder, old_path: list, new_path: list, new_name: str):
child_items = []

folder.name = new_name
Expand All @@ -302,11 +302,11 @@ def update_folder(self, folder: Folder, old_path: list, new_path: list, new_name
if item.type != "folder":
item.path = new_path + [folder.id]
elif item.type == "folder":
self.update_folder(item, item.path, (new_path + [folder.id]), item.name)
await self.update_folder(item, item.path, (new_path + [folder.id]), item.name)

session.commit()

def update_metadata_by_id(self, request: Request, id: str, new_name: str, new_path: list):
async def update_metadata_by_id(self, request: Request, id: str, new_name: str, new_path: list):
user_id = from_jwt(str(request.cookies.get("authenticate")))

with self.get_session() as session:
Expand All @@ -319,7 +319,7 @@ def update_metadata_by_id(self, request: Request, id: str, new_name: str, new_pa
return

folder = session.query(Folder).filter(Folder.id == id, Folder.owner_id == user_id).first()
self.update_folder(folder, folder.path, new_path, new_name)
await self.update_folder(folder, folder.path, new_path, new_name)


db = DB(os.environ['SQLALCHEMY_URL'])
Expand Down
26 changes: 14 additions & 12 deletions backend/routes/account_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,23 @@ async def request_password_update(request: Request, user: Union[bool, dict] = De
id = user.get("id")

v_code = str(randint_n(16))
if not db.user_manager.update_column(id, "verification_code", v_code): return Response(status_code=400)
smtp_client.send_verification_code(user["email"], v_code)
if not await db.user_manager.update_column(id, "verification_code", v_code): return Response(status_code=400)
print(user["email"])
print(v_code)
print(smtp_client.send_verification_code(user["email"], v_code))

return Response(status_code=200)


@router.post("/items/update/reqemail") # JSON EX: {"email":"[email protected]"}
async def request_email_update(request: Request, new_email: RequestEmailUpdateRequest, user: Union[bool, dict] = Depends(auth_dependency)):
if db.user_manager.get_by_email(new_email.email) is None:
if await db.user_manager.get_by_email(new_email.email) is None:
id = user.get("id")

cur_code = str(randint_n(16))
new_code = str(randint_n(16))

if not db.user_manager.update_column(id, "verification_code", f"{cur_code}#{new_code}"):
if not await db.user_manager.update_column(id, "verification_code", f"{cur_code}#{new_code}"):
return Response(status_code=400)

smtp_client.send_verification_code(user["email"], cur_code)
Expand All @@ -59,8 +61,8 @@ async def update_password(request: Request, user: Union[bool, dict] = Depends(au
user_ver_code = str(user["verification_code"]).split("#")[0]

if str(in_code) == user_ver_code:
if not db.user_manager.update_column(id, "password", new_password): return Response(status_code=400)
if not db.user_manager.update_column(id, "verification_code", ""): return Response(status_code=400)
if not await db.user_manager.update_column(id, "password", new_password): return Response(status_code=400)
if not await db.user_manager.update_column(id, "verification_code", ""): return Response(status_code=400)
return Response(status_code=204) # Valid code - password updated

return Response(status_code=400) # Invalid code
Expand All @@ -73,9 +75,9 @@ async def update_email(request: Request, new_email: EmailUpdateRequest, user: Un
expected_current_email_code, expected_new_email_code = str(user['verification_code']).split('#')

if str(new_email.cur_code) == expected_current_email_code and str(new_email.new_code) == expected_new_email_code:
if not db.user_manager.update_column(id, "email", new_email.email) or \
not db.user_manager.update_column(id, "verification_code", ""):
not db.user_manager.update_column(id, "email_verified", True)
if not await db.user_manager.update_column(id, "email", new_email.email) or \
not await db.user_manager.update_column(id, "verification_code", ""):
not await db.user_manager.update_column(id, "email_verified", True)
return Response(status_code=500)

return Response(status_code=204) # Valid code - password updated
Expand All @@ -86,7 +88,7 @@ async def update_email(request: Request, new_email: EmailUpdateRequest, user: Un
async def update_name(request: Request, name: NameUpdateRequest, user: Union[bool, dict] = Depends(auth_dependency)):
new_name = name.name

if not db.user_manager.update_column(user.get("id"), "name", new_name):
if not await db.user_manager.update_column(user.get("id"), "name", new_name):
return Response(status_code=400)

return Response(status_code=204)
Expand All @@ -95,7 +97,7 @@ async def update_name(request: Request, name: NameUpdateRequest, user: Union[boo
async def update_name(request: Request, pfp_data: PFPUpdateRequest, user: Union[bool, dict] = Depends(auth_dependency)):
new_pfp = pfp_data.image

if not db.user_manager.update_column(user.get("id"), "pfp", new_pfp):
if not await db.user_manager.update_column(user.get("id"), "pfp", new_pfp):
return Response(status_code=400)

return Response(status_code=204)
Expand All @@ -112,7 +114,7 @@ async def resend_verification_email(request: Request, user: Union[bool, dict] =

@router.post("/verify")
async def verify_email(request: Request, id: str):
if not db.user_manager.update_column(id, "email_verified", True):
if not await db.user_manager.update_column(id, "email_verified", True):
return Response(status_code=400)

return Response(status_code=204)
16 changes: 8 additions & 8 deletions backend/routes/create_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ async def create_user(
creds: UserCredentialsRequest,
_ = Depends(require_user_creation_access)
):
if db.user_manager.get_by_email(creds.email) is None:
if await db.user_manager.get_by_email(creds.email) is None:
user = make_user(creds.email, hash_password(creds.password))
db.user_manager.insert(user)
await db.user_manager.insert(user)
user = clean_udata(user)

m_link = f"{os.environ['LANDING_PAGE_URL']}/register?vid={user.get('id')}"
Expand All @@ -44,11 +44,11 @@ async def create_note(
is_auth: Union[bool, dict] = Depends(auth_dependency),
_ = Depends(require_item_creation_access)
):
if not db.folder_manager.does_path_exist(request, new_note.path):
if not await db.folder_manager.does_path_exist(request, new_note.path):
return Response(status_code=400)

note = make_note(request, new_note.name, new_note.path, False)
db.note_manager.insert(note)
await db.note_manager.insert(note)
return JSONResponse(status_code=201, content=note)


Expand All @@ -59,11 +59,11 @@ async def create_studyguide(
is_auth: Union[bool, dict] = Depends(auth_dependency),
_ = Depends(require_item_creation_access)
):
if not db.folder_manager.does_path_exist(request, new_study_guide.path):
if not await db.folder_manager.does_path_exist(request, new_study_guide.path):
return Response(status_code=400)

note = make_note(request, new_study_guide.name, new_study_guide.path, True)
db.note_manager.insert(note)
await db.note_manager.insert(note)
return JSONResponse(status_code=201, content=note)


Expand All @@ -74,9 +74,9 @@ async def create_folder(
is_auth: Union[bool, dict] = Depends(auth_dependency),
_ = Depends(require_item_creation_access)
):
if not db.folder_manager.does_path_exist(request, new_folder.path):
if not await db.folder_manager.does_path_exist(request, new_folder.path):
return Response(status_code=400)

folder = make_folder(request, new_folder.name, new_folder.path)
db.folder_manager.insert(folder)
await db.folder_manager.insert(folder)
return JSONResponse(status_code=201, content=folder)
8 changes: 4 additions & 4 deletions backend/routes/retrieve_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@router.get("/items/{id}")
async def get_item(request: Request, id: str, is_auth: Union[bool, dict] = Depends(auth_dependency)):
item = db.get_item(request, id)
item = await db.get_item(request, id)
if not item: return Response(status_code=404)
return JSONResponse(json.loads(item), status_code=200)

Expand All @@ -25,13 +25,13 @@ async def list_items(request: Request, is_auth: Union[bool, dict] = Depends(auth
try: path = await request.json()
except json.decoder.JSONDecodeError: return Response(status_code=400)

if not db.folder_manager.does_path_exist(request, path): return Response(status_code=400)
if not await db.folder_manager.does_path_exist(request, path): return Response(status_code=400)

curr_users_notes = db.user_manager.get_notes(request)
curr_users_notes = await db.user_manager.get_notes(request)
for n in curr_users_notes:
if str(n["path"]) == str(path): ret.append(n)

curr_users_folders = db.user_manager.get_folders(request)
curr_users_folders = await db.user_manager.get_folders(request)
for f in curr_users_folders:
if str(f["path"]) == str(path): ret.append(f)

Expand Down
2 changes: 1 addition & 1 deletion backend/routes/stripe_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def stripe_webhook(request: Request, stripe_signature: str = Header(None))
has_noter_access = subscription_status == 'active' or subscription_status == 'trialing'

# Update the database
with db.get_session() as session:
async with db.get_session() as session:
session.query(User).filter(User.stripe_id == stripe_customer_id).update({
User.has_noter_access: has_noter_access
})
Expand Down
Loading

0 comments on commit ffe938f

Please sign in to comment.