diff --git a/migrations/env.py b/migrations/env.py index 3b98615..62fe160 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -3,7 +3,7 @@ from alembic import context from sqlalchemy import engine_from_config, pool -from print_service.models import Model +from print_service.models import BaseDbModel from print_service.settings import get_settings @@ -20,7 +20,7 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = Model.metadata +target_metadata = BaseDbModel.metadata # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/migrations/versions/90539e2253b3_add_soft_delete.py b/migrations/versions/90539e2253b3_add_soft_delete.py new file mode 100644 index 0000000..87c0364 --- /dev/null +++ b/migrations/versions/90539e2253b3_add_soft_delete.py @@ -0,0 +1,34 @@ +"""add soft delete + +Revision ID: 90539e2253b3 +Revises: a68c6bb2972c +Create Date: 2025-06-01 17:29:08.641697 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '90539e2253b3' +down_revision = 'a68c6bb2972c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('file', 'source', + existing_type=sa.VARCHAR(), + nullable=False) + op.add_column('union_member', sa.Column('is_deleted', sa.Boolean(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('union_member', 'is_deleted') + op.alter_column('file', 'source', + existing_type=sa.VARCHAR(), + nullable=True) + # ### end Alembic commands ### diff --git a/print_service/exceptions.py b/print_service/exceptions.py index 99ed16a..d34d15a 100644 --- a/print_service/exceptions.py +++ b/print_service/exceptions.py @@ -3,9 +3,18 @@ settings = get_settings() - class ObjectNotFound(Exception): - pass + def __init__(self, obj: type, obj_id_or_name: int | str): + super().__init__( + f"Object {obj.__name__} {obj_id_or_name=} not found", + ) + + +class AlreadyExists(Exception): + def __init__(self, obj: type, obj_id_or_name: int | str): + super().__init__( + f"Object {obj.__name__}, {obj_id_or_name=} already exists", + ) class TerminalTokenNotFound(ObjectNotFound): @@ -71,6 +80,14 @@ def __init__(self, content_type: str): f'Only {", ".join(settings.CONTENT_TYPES)} files allowed, but {content_type} was recieved' ) +class PrintCodeExpired(Exception): + def __init__(self): + super().__init__(f'Print code expired') + +class PrintLimitExceed(Exception): + def __init__(self): + super().__init__(f'Print limit exceed') + class AlreadyUploaded(Exception): def __init__(self): diff --git a/print_service/models/__init__.py b/print_service/models/__init__.py index d1ffcfe..f203ead 100644 --- a/print_service/models/__init__.py +++ b/print_service/models/__init__.py @@ -1,92 +1,5 @@ -from __future__ import annotations +from .base import Base, BaseDbModel +from .db import * -import math -from datetime import datetime -from sqlalchemy import Column, DateTime, Integer, String -from sqlalchemy.ext.declarative import as_declarative -from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy.sql.schema import ForeignKey -from sqlalchemy.sql.sqltypes import Boolean - - -@as_declarative() -class Model: - pass - - -class UnionMember(Model): - __tablename__ = 'union_member' - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - surname: Mapped[str] = mapped_column(String, nullable=False) - union_number: Mapped[str] = mapped_column(String, nullable=True) - student_number: Mapped[str] = mapped_column(String, nullable=True) - - files: Mapped[list[File]] = relationship('File', back_populates='owner') - print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='owner') - - -class File(Model): - __tablename__ = 'file' - - id: Mapped[int] = Column(Integer, primary_key=True) - pin: Mapped[str] = Column(String, nullable=False) - file: Mapped[str] = Column(String, nullable=False) - owner_id: Mapped[int] = Column(Integer, ForeignKey('union_member.id'), nullable=False) - option_pages: Mapped[str] = Column(String) - option_copies: Mapped[int] = Column(Integer) - option_two_sided: Mapped[bool] = Column(Boolean) - created_at: Mapped[datetime] = Column(DateTime, nullable=False, default=datetime.utcnow) - updated_at: Mapped[datetime] = Column( - DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow - ) - number_of_pages: Mapped[int] = Column(Integer) - source: Mapped[str] = Column(String, default='unknown', nullable=False) - - owner: Mapped[UnionMember] = relationship('UnionMember', back_populates='files') - print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='file') - - @property - def flatten_pages(self) -> list[int] | None: - '''Возвращает расширенный список из элементов списков внутренних целочисленных точек переданного множества отрезков - "1-5, 3, 2" --> [1, 2, 3, 4, 5, 3, 2]''' - if self.number_of_pages is None: - return None - result = list() - if self.option_pages == '': - return result - for part in self.option_pages.split(','): - x = part.split('-') - result.extend(range(int(x[0]), int(x[-1]) + 1)) - return result - - @property - def sheets_count(self) -> int | None: - '''Возвращает количество элементов списков внутренних целочисленных точек переданного множества отрезков - "1-5, 3, 2" --> 7 - P.S. 1, 2, 3, 4, 5, 3, 2 -- 7 чисел''' - if self.number_of_pages is None: - return None - if not self.flatten_pages: - return ( - math.ceil(self.number_of_pages - (self.option_two_sided * self.number_of_pages / 2)) - * self.option_copies - ) - if self.option_two_sided: - return math.ceil(len(self.flatten_pages) / 2) * self.option_copies - else: - return len(self.flatten_pages) * self.option_copies - - -class PrintFact(Model): - __tablename__ = 'print_fact' - - id: Mapped[int] = Column(Integer, primary_key=True) - file_id: Mapped[int] = Column(Integer, ForeignKey('file.id'), nullable=False) - owner_id: Mapped[int] = Column(Integer, ForeignKey('union_member.id'), nullable=False) - created_at: Mapped[datetime] = Column(DateTime, nullable=False, default=datetime.utcnow) - - owner: Mapped[UnionMember] = relationship('UnionMember', back_populates='print_facts') - file: Mapped[File] = relationship('File', back_populates='print_facts') - sheets_used: Mapped[int] = Column(Integer) +__all__ = ["Base", "BaseDbModel", "UnionMember", "File", "PrintFact"] diff --git a/print_service/models/base.py b/print_service/models/base.py new file mode 100644 index 0000000..66afb57 --- /dev/null +++ b/print_service/models/base.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import re + +from sqlalchemy import not_ +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Query, Session, as_declarative, declared_attr + +from ..exceptions import AlreadyExists, ObjectNotFound + + +@as_declarative() +class Base: + """Base class for all database entities""" + + @declared_attr + def __tablename__(cls) -> str: # pylint: disable=no-self-argument + """Generate database table name automatically. + Convert CamelCase class name to snake_case db table name. + """ + return re.sub(r"(? BaseDbModel: + obj = cls(**kwargs) + session.add(obj) + session.flush() + return obj + + @classmethod + def query(cls, *, with_deleted: bool = False, session: Session) -> Query: + """Get all objects with soft deletes""" + objs = session.query(cls) + if not with_deleted and hasattr(cls, "is_deleted"): + objs = objs.filter(not_(cls.is_deleted)) + return objs + + @classmethod + def get(cls, id: int | str, *, with_deleted=False, session: Session) -> BaseDbModel: + """Get object with soft deletes""" + objs = session.query(cls) + if not with_deleted and hasattr(cls, "is_deleted"): + objs = objs.filter(not_(cls.is_deleted)) + try: + if hasattr(cls, "uuid"): + return objs.filter(cls.uuid == id).one() + return objs.filter(cls.id == id).one() + except NoResultFound: + raise ObjectNotFound(obj=cls, obj_id_or_name=id) + + @classmethod + def update(cls, id: int | str, *, session: Session, **kwargs) -> BaseDbModel: + """Update model with new values from kwargs. + If no new values are given, raise HTTP 409 error. + """ + get_new_values = False + obj = cls.get(id, session=session) + for k, v in kwargs.items(): + cur_v = getattr(obj, k) + if cur_v != v: + setattr(obj, k, v) + get_new_values = True + if not get_new_values: + raise AlreadyExists(cls, id) + session.add(obj) + session.flush() + return obj + + @classmethod + def delete(cls, id: int | str, *, session: Session) -> None: + """Soft delete object if possible, else hard delete""" + obj = cls.get(id, session=session) + if hasattr(obj, "is_deleted"): + obj.is_deleted = True + else: + session.delete(obj) + session.flush() diff --git a/print_service/models/db.py b/print_service/models/db.py new file mode 100644 index 0000000..ca27d35 --- /dev/null +++ b/print_service/models/db.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import math +from datetime import datetime + +from sqlalchemy import Column, DateTime, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql.schema import ForeignKey +from sqlalchemy.sql.sqltypes import Boolean + +from .base import BaseDbModel + + +class UnionMember(BaseDbModel): + __tablename__ = 'union_member' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + surname: Mapped[str] = mapped_column(String, nullable=False) + union_number: Mapped[str] = mapped_column(String, nullable=True) + student_number: Mapped[str] = mapped_column(String, nullable=True) + + files: Mapped[list[File]] = relationship('File', back_populates='owner') + print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='owner') + is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + +class File(BaseDbModel): + __tablename__ = 'file' + + id: Mapped[int] = Column(Integer, primary_key=True) + pin: Mapped[str] = Column(String, nullable=False) + file: Mapped[str] = Column(String, nullable=False) + owner_id: Mapped[int] = Column(Integer, ForeignKey('union_member.id'), nullable=False) + option_pages: Mapped[str] = Column(String) + option_copies: Mapped[int] = Column(Integer) + option_two_sided: Mapped[bool] = Column(Boolean) + created_at: Mapped[datetime] = Column(DateTime, nullable=False, default=datetime.utcnow) + updated_at: Mapped[datetime] = Column( + DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow + ) + number_of_pages: Mapped[int] = Column(Integer) + source: Mapped[str] = Column(String, default='unknown', nullable=False) + + owner: Mapped[UnionMember] = relationship( + 'UnionMember', + primaryjoin="and_(File.owner_id==UnionMember.id, not_(UnionMember.is_deleted))", + back_populates='files', + ) + print_facts: Mapped[list[PrintFact]] = relationship('PrintFact', back_populates='file') + + @property + def flatten_pages(self) -> list[int] | None: + '''Возвращает расширенный список из элементов списков внутренних целочисленных точек переданного множества отрезков + "1-5, 3, 2" --> [1, 2, 3, 4, 5, 3, 2]''' + if self.number_of_pages is None: + return None + result = list() + if self.option_pages == '': + return result + for part in self.option_pages.split(','): + x = part.split('-') + result.extend(range(int(x[0]), int(x[-1]) + 1)) + return result + + @property + def sheets_count(self) -> int | None: + '''Возвращает количество элементов списков внутренних целочисленных точек переданного множества отрезков + "1-5, 3, 2" --> 7 + P.S. 1, 2, 3, 4, 5, 3, 2 -- 7 чисел''' + if self.number_of_pages is None: + return None + if not self.flatten_pages: + return ( + math.ceil(self.number_of_pages - (self.option_two_sided * self.number_of_pages / 2)) + * self.option_copies + ) + if self.option_two_sided: + return math.ceil(len(self.flatten_pages) / 2) * self.option_copies + else: + return len(self.flatten_pages) * self.option_copies + + +class PrintFact(BaseDbModel): + __tablename__ = 'print_fact' + + id: Mapped[int] = Column(Integer, primary_key=True) + file_id: Mapped[int] = Column(Integer, ForeignKey('file.id'), nullable=False) + owner_id: Mapped[int] = Column(Integer, ForeignKey('union_member.id'), nullable=False) + created_at: Mapped[datetime] = Column(DateTime, nullable=False, default=datetime.utcnow) + + owner: Mapped[UnionMember] = relationship( + 'UnionMember', + primaryjoin="and_(PrintFact.owner_id==UnionMember.id, not_(UnionMember.is_deleted))", + back_populates='print_facts', + ) + file: Mapped[File] = relationship('File', back_populates='print_facts') + sheets_used: Mapped[int] = Column(Integer) diff --git a/print_service/routes/admin.py b/print_service/routes/admin.py index da2b7b1..25cef4e 100644 --- a/print_service/routes/admin.py +++ b/print_service/routes/admin.py @@ -2,7 +2,7 @@ import logging from auth_lib.fastapi import UnionAuth -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from redis import Redis from print_service.exceptions import TerminalTokenNotFound diff --git a/print_service/routes/exc_handlers.py b/print_service/routes/exc_handlers.py index ecb8aeb..219c707 100644 --- a/print_service/routes/exc_handlers.py +++ b/print_service/routes/exc_handlers.py @@ -1,4 +1,4 @@ -import requests.models +from email import message import starlette.requests from starlette.responses import JSONResponse @@ -21,6 +21,10 @@ UnionStudentDuplicate, UnprocessableFileInstance, UserNotFound, + PrintCodeExpired, + PrintLimitExceed, + ObjectNotFound, + AlreadyExists ) from print_service.routes.base import app from print_service.settings import get_settings @@ -130,7 +134,6 @@ async def generate_error(req: starlette.requests.Request, exc: PINGenerateError) status_code=500, ) - @app.exception_handler(FileIsNotReceived) async def file_not_received(req: starlette.requests.Request, exc: FileIsNotReceived): return JSONResponse( @@ -207,3 +210,39 @@ async def not_uploaded(req: starlette.requests.Request, exc: IsNotUploaded): ).model_dump(), status_code=415, ) + +@app.exception_handler(PrintLimitExceed) +async def exceed_print_limit(req: starlette.requests.Request, exc: PrintLimitExceed): + return JSONResponse( + content=StatusResponseModel( + status="Error", message=f"{exc}", ru="Превышено максимально число печатей для файла" + ).model_dump(), + status_code=410, + ) + +@app.exception_handler(PrintCodeExpired) +async def expire_pin(req: starlette.requests.Request, exc: PrintCodeExpired): + return JSONResponse( + content=StatusResponseModel( + status="Error", message=f"{exc}", ru="Время жизни Pin закончилось" + ).model_dump(), + status_code=410, + ) + +@app.exception_handler(ObjectNotFound) +async def obj_not_found(req: starlette.requests.Request, exc: ObjectNotFound): + return JSONResponse( + content=StatusResponseModel( + status="Error", message=f"{exc}", ru="Объект не найден" + ).model_dump(), + status_code=404 + ) + +@app.exception_handler(AlreadyExists) +async def obj_exists(req: starlette.requests.Request, exc: AlreadyExists): + return JSONResponse( + content=StatusResponseModel( + status="Error", message=f"{exc}", ru="Объект уже существует" + ).model_dump(), + status_code=403 + ) \ No newline at end of file diff --git a/print_service/routes/file.py b/print_service/routes/file.py index 21e3570..35c55b8 100644 --- a/print_service/routes/file.py +++ b/print_service/routes/file.py @@ -6,7 +6,6 @@ import aiofiles.os from auth_lib.fastapi import UnionAuth from fastapi import APIRouter, File, UploadFile -from fastapi.exceptions import HTTPException from fastapi.params import Depends from fastapi_sqlalchemy import db from pydantic import Field, field_validator @@ -24,11 +23,9 @@ PINNotFound, TooLargeSize, TooManyPages, - UnprocessableFileInstance, - UserNotFound, ) -from print_service.models import File as FileModel -from print_service.models import UnionMember +from print_service.models.db import File as FileModel +from print_service.models.db import UnionMember from print_service.schema import BaseModel from print_service.settings import Settings, get_settings from print_service.utils import checking_for_pdf, generate_filename, generate_pin, get_file @@ -119,7 +116,7 @@ async def send( Полученный пин-код можно использовать в методах POST и GET `/file/{pin}`. """ - user = db.session.query(UnionMember) + user = UnionMember.query(session=db.session) if not settings.ALLOW_STUDENT_NUMBER: user = user.filter(UnionMember.union_number != None) @@ -145,12 +142,16 @@ async def send( except RuntimeError: raise PINGenerateError() filename = generate_filename(inp.filename) - file_model = FileModel(pin=pin, file=filename, source=inp.source) - file_model.owner = user - file_model.option_copies = inp.options.copies - file_model.option_pages = inp.options.pages - file_model.option_two_sided = inp.options.two_sided - db.session.add(file_model) + file_model = FileModel.create( + session=db.session, + pin=pin, + file=filename, + source=inp.source, + owner=user, + option_copies=inp.options.copies, + option_pages=inp.options.pages, + option_two_sided=inp.options.two_sided, + ) db.session.commit() return { @@ -186,7 +187,7 @@ async def upload_file( if file == ...: raise FileIsNotReceived() file_model = ( - db.session.query(FileModel) + FileModel.query(session=db.session) .filter(func.upper(FileModel.pin) == pin.upper()) .order_by(FileModel.created_at.desc()) .one_or_none() @@ -253,7 +254,7 @@ async def update_file_options( можно бесконечное количество раз. Можно изменять настройки по одной.""" options = inp.options.model_dump(exclude_unset=True) file_model = ( - db.session.query(FileModel) + FileModel.query(session=db.session) .filter(func.upper(FileModel.pin) == pin.upper()) .order_by(FileModel.created_at.desc()) .one_or_none() @@ -288,6 +289,7 @@ async def update_file_options( 404: {'model': StatusResponseModel, 'detail': 'Pin not found'}, 415: {'model': StatusResponseModel, 'detail': 'File error'}, 416: {'model': StatusResponseModel, 'detail': 'Invalid page request'}, + 410: {'model': StatusResponseModel, 'detail': 'Print code expired'} }, response_model=ReceiveOutput, ) diff --git a/print_service/routes/user.py b/print_service/routes/user.py index 34a3a3c..533827e 100644 --- a/print_service/routes/user.py +++ b/print_service/routes/user.py @@ -3,14 +3,13 @@ from auth_lib.fastapi import UnionAuth from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException from fastapi_sqlalchemy import db -from pydantic import constr, validate_call +from pydantic import constr from sqlalchemy import and_, func, or_ from print_service import __version__ from print_service.exceptions import UnionStudentDuplicate, UserNotFound -from print_service.models import UnionMember +from print_service.models.db import UnionMember from print_service.schema import BaseModel from print_service.settings import get_settings @@ -22,7 +21,7 @@ # region schemas class UserCreate(BaseModel): - username: constr(strip_whitespace=True, to_upper=True, min_length=1) + surname: constr(strip_whitespace=True, to_upper=True, min_length=1) union_number: Optional[constr(strip_whitespace=True, to_upper=True, min_length=1)] student_number: Optional[constr(strip_whitespace=True, to_upper=True, min_length=1)] @@ -52,7 +51,7 @@ async def check_union_member( """Проверяет наличие пользователя в списке.""" surname = surname.upper() - user = db.session.query(UnionMember) + user = UnionMember.query(session=db.session) if not settings.ALLOW_STUDENT_NUMBER: user = user.filter(UnionMember.union_number != None) user: UnionMember = user.filter( @@ -67,7 +66,7 @@ async def check_union_member( return bool(user) if not user: - raise UserNotFound() + raise UserNotFound(obj=UnionMember, obj_id_or_name=surname) else: return { 'surname': user.surname, @@ -95,7 +94,7 @@ def update_list( for user in input.users: db_user: UnionMember = ( - db.session.query(UnionMember) + UnionMember.query(session=db.session) .filter( or_( and_( @@ -112,18 +111,16 @@ def update_list( ) if db_user: - db_user.surname = user.username - db_user.union_number = user.union_number - db_user.student_number = user.student_number + UnionMember.update(session=db.session, id=db_user.id, **user.model_dump(exclude_unset=False)) else: db.session.add( UnionMember( - surname=user.username, + surname=user.surname, union_number=user.union_number, student_number=user.student_number, ) ) - db.session.flush() + UnionMember.create(session=db.session, **user.model_dump(exclude_unset=False)) db.session.commit() return {"status": "ok", "count": len(input.users)} diff --git a/print_service/settings.py b/print_service/settings.py index c7de741..88868ce 100644 --- a/print_service/settings.py +++ b/print_service/settings.py @@ -4,7 +4,7 @@ from typing import List from auth_lib.fastapi import UnionAuthSettings -from pydantic import AnyUrl, ConfigDict, DirectoryPath, PostgresDsn, RedisDsn +from pydantic import ConfigDict, DirectoryPath, PostgresDsn, RedisDsn from pydantic_settings import BaseSettings @@ -20,11 +20,13 @@ class Settings(UnionAuthSettings, BaseSettings): MAX_PAGE_COUNT: int = 50 STORAGE_TIME: int = 7 * 24 # Время хранения файла в часах STATIC_FOLDER: DirectoryPath | None = None + MAX_PRINTS_PER_PIN: int = 7 #тестовое максимальное число печатей для одного кода ALLOW_STUDENT_NUMBER: bool = False PIN_SYMBOLS: str = string.ascii_uppercase + string.digits PIN_LENGTH: int = 6 + PIN_TTL: int = 3600 #тестовое время жизни кода печати CORS_ALLOW_ORIGINS: list[str] = ['*'] CORS_ALLOW_CREDENTIALS: bool = True diff --git a/print_service/utils/__init__.py b/print_service/utils/__init__.py index 1c44b33..45a5ce2 100644 --- a/print_service/utils/__init__.py +++ b/print_service/utils/__init__.py @@ -1,12 +1,10 @@ import io -import math import random import re from datetime import date, datetime, timedelta from os.path import abspath, exists from fastapi import File -from fastapi.exceptions import HTTPException from PyPDF4 import PdfFileReader from sqlalchemy import func from sqlalchemy.orm.session import Session @@ -16,11 +14,12 @@ InvalidPageRequest, IsNotUploaded, UnprocessableFileInstance, + PrintLimitExceed, + PrintCodeExpired ) -from print_service.models import File -from print_service.models import File as FileModel -from print_service.models import PrintFact -from print_service.routes import exc_handlers +from print_service.models.db import File +from print_service.models.db import File as FileModel +from print_service.models.db import PrintFact from print_service.settings import Settings, get_settings @@ -57,7 +56,7 @@ def generate_filename(original_filename: str): def get_file(dbsession, pin: str or list[str]): pin = [pin.upper()] if isinstance(pin, str) else tuple(p.upper() for p in pin) files: list[FileModel] = ( - dbsession.query(FileModel) + FileModel.query(session=dbsession) .filter(func.upper(FileModel.pin).in_(pin)) .order_by(FileModel.created_at.desc()) .all() @@ -85,8 +84,19 @@ def get_file(dbsession, pin: str or list[str]): if f.flatten_pages: if number_of_pages > max(f.flatten_pages): raise InvalidPageRequest() + #тут должна быть проверка на строк годности и число распечатанных документов(print_facts у FileModel) + if f.created_at + timedelta(hours=settings.PIN_TTL) >= datetime.now(): + raise PrintCodeExpired() + + if len(f.print_facts) > settings.MAX_PRINTS_PER_PIN: + raise PrintLimitExceed() + + file_model = PrintFact(file_id=f.id, owner_id=f.owner_id, sheets_used=f.sheets_count) - dbsession.add(file_model) + + PrintFact.create( + session=dbsession, file_id=f.id, owner_id=f.owner_id, sheets_used=f.sheets_count + ) dbsession.commit() return result diff --git a/test_client.py b/test_client.py index c53cdf7..3aadd34 100644 --- a/test_client.py +++ b/test_client.py @@ -1,14 +1,12 @@ import asyncio - import websockets - async def hello(): async with websockets.connect( - "ws://localhost:8000/qr", extra_headers={"Authorization": 'token ADAQ-123456789'} + "ws://localhost:8000/qr", + additional_headers={"Authorization": "token ADAQ-123456789"} ) as websocket: async for message in websocket: print(message) - -asyncio.run(hello()) +asyncio.run(hello()) \ No newline at end of file diff --git a/tests/test_routes/conftest.py b/tests/test_routes/conftest.py index 8eb798b..10cde43 100644 --- a/tests/test_routes/conftest.py +++ b/tests/test_routes/conftest.py @@ -2,7 +2,7 @@ import pytest -from print_service.models import File, PrintFact, UnionMember +from print_service.models.db import File, PrintFact, UnionMember @pytest.fixture(scope='function') @@ -13,13 +13,13 @@ def union_member_user(dbsession): union_number='6666667', student_number='13033224', ) - dbsession.add(UnionMember(**union_member)) + UnionMember.create(session=dbsession, **union_member) dbsession.commit() yield union_member - db_user = dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).one_or_none() + db_user = UnionMember.query(session=dbsession).filter(UnionMember.id == union_member['id']).one_or_none() assert db_user is not None - dbsession.query(PrintFact).filter(PrintFact.owner_id == union_member['id']).delete() - dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.owner_id == union_member['id']).delete() + UnionMember.query(session=dbsession).filter(UnionMember.id == union_member['id']).delete() dbsession.commit() @@ -32,12 +32,12 @@ def uploaded_file_db(dbsession, union_member_user, client): "options": {"pages": "", "copies": 1, "two_sided": False}, } res = client.post('/file', json=body) - db_file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + db_file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() yield db_file - file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert file is not None - dbsession.query(PrintFact).filter(PrintFact.file_id == file.id).delete() - dbsession.query(File).filter(File.pin == res.json()['pin']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.file_id == file.id).delete() + File.query(session=dbsession).filter(File.pin == res.json()['pin']).delete() dbsession.commit() @@ -60,8 +60,8 @@ def pin_pdf(dbsession, union_member_user, client): res = client.post('/file', json=body) pin = res.json()['pin'] yield pin - file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert file is not None - dbsession.query(PrintFact).filter(PrintFact.file_id == file.id).delete() - dbsession.query(File).filter(File.pin == res.json()['pin']).delete() + PrintFact.query(session=dbsession).filter(PrintFact.file_id == file.id).delete() + File.query(session=dbsession).filter(File.pin == res.json()['pin']).delete() dbsession.commit() diff --git a/tests/test_routes/test_file.py b/tests/test_routes/test_file.py index 864e245..50d71b0 100644 --- a/tests/test_routes/test_file.py +++ b/tests/test_routes/test_file.py @@ -5,7 +5,7 @@ from starlette import status from print_service.exceptions import FileNotFound, InvalidPageRequest, IsNotUploaded -from print_service.models import File +from print_service.models.db import File from print_service.settings import get_settings from print_service.utils import checking_for_pdf, get_file @@ -25,7 +25,7 @@ def test_post_success(union_member_user, client, dbsession): } res = client.post(url, data=json.dumps(body)) assert res.status_code == status.HTTP_200_OK - db_file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() + db_file = File.query(session=dbsession).filter(File.pin == res.json()['pin']).one_or_none() assert db_file is not None assert db_file.source == 'webapp' body2 = { @@ -36,14 +36,13 @@ def test_post_success(union_member_user, client, dbsession): } res2 = client.post(url, data=json.dumps(body2)) assert res2.status_code == status.HTTP_200_OK - db_file2 = dbsession.query(File).filter(File.pin == res2.json()['pin']).one_or_none() + db_file2 = File.query(session=dbsession).filter(File.pin == res2.json()['pin']).one_or_none() assert db_file2 is not None assert db_file2.source == 'unknown' - dbsession.delete(db_file) - dbsession.delete(db_file2) + File.delete(db_file.id, session=dbsession) + File.delete(db_file2.id, session=dbsession) dbsession.commit() - def test_post_unauthorized_user(client): body = { "surname": 'surname', diff --git a/tests/test_routes/test_user.py b/tests/test_routes/test_user.py index 97fdb44..c526e62 100644 --- a/tests/test_routes/test_user.py +++ b/tests/test_routes/test_user.py @@ -33,7 +33,7 @@ def test_post_success(client, dbsession): body = { 'users': [ { - 'username': 'paul', + 'surname': 'paul', 'union_number': '1966', 'student_number': '1967', } @@ -41,8 +41,8 @@ def test_post_success(client, dbsession): } res = client.post(url, data=json.dumps(body)) assert res.status_code == status.HTTP_200_OK - dbsession.query(UnionMember).filter( - UnionMember.surname == body['users'][0]['username'], + UnionMember.query(session=dbsession).filter( + UnionMember.surname == body['users'][0]['surname'], UnionMember.union_number == body['users'][0]['union_number'], UnionMember.student_number == body['users'][0]['student_number'], ).delete() @@ -55,12 +55,12 @@ def test_post_success(client, dbsession): pytest.param( [ { - 'username': 'paul', + 'surname': 'paul', 'union_number': '404man', 'student_number': '30311', }, { - 'username': 'marty', + 'surname': 'marty', 'union_number': '404man', 'student_number': '303112', }, @@ -70,12 +70,12 @@ def test_post_success(client, dbsession): pytest.param( [ { - 'username': 'alice', + 'surname': 'alice', 'union_number': '500', 'student_number': '42', }, { - 'username': 'polly', + 'surname': 'polly', 'union_number': '503', 'student_number': '42', },