From fce6f118a52cf6d8c1c9696ac6bdec321aa47f35 Mon Sep 17 00:00:00 2001 From: ChristianOertlin Date: Fri, 5 Apr 2024 15:51:20 +0200 Subject: [PATCH] refactore(add store) (#125)(major) # Description Rewrites of the sql backend to phase out sqlmodel and update to pydantic2 --- genotype_api/api/app.py | 12 +- genotype_api/api/endpoints/analyses.py | 23 +- genotype_api/api/endpoints/plates.py | 25 +- genotype_api/api/endpoints/samples.py | 39 ++- genotype_api/api/endpoints/snps.py | 32 +- genotype_api/api/endpoints/users.py | 23 +- genotype_api/config.py | 10 +- genotype_api/database/base_handler.py | 11 + genotype_api/database/crud/create.py | 102 +++---- genotype_api/database/crud/delete.py | 44 ++- genotype_api/database/crud/read.py | 288 ++++++++---------- genotype_api/database/crud/update.py | 158 +++++----- genotype_api/database/database.py | 64 ++++ genotype_api/database/models.py | 152 ++++----- genotype_api/database/session_handler.py | 19 -- genotype_api/database/store.py | 28 ++ genotype_api/dto/analysis.py | 14 +- genotype_api/dto/dto.py | 5 - genotype_api/dto/plate.py | 21 +- genotype_api/dto/sample.py | 34 ++- genotype_api/dto/snp.py | 8 +- genotype_api/dto/user.py | 8 +- genotype_api/exceptions.py | 4 + genotype_api/models.py | 2 +- genotype_api/security.py | 59 ++-- genotype_api/services/__init__.py | 0 .../services/endpoint_services/__init__.py | 0 .../analysis_service.py | 36 +-- .../endpoint_services/base_service.py | 8 + .../plate_service.py | 61 ++-- .../sample_service.py | 71 ++--- .../snp_service.py | 18 +- .../user_service.py | 34 +-- requirements.txt | 10 +- 34 files changed, 695 insertions(+), 728 deletions(-) create mode 100644 genotype_api/database/base_handler.py create mode 100644 genotype_api/database/database.py delete mode 100644 genotype_api/database/session_handler.py create mode 100644 genotype_api/database/store.py delete mode 100644 genotype_api/dto/dto.py create mode 100644 genotype_api/services/__init__.py create mode 100644 genotype_api/services/endpoint_services/__init__.py rename genotype_api/services/{analysis_service => endpoint_services}/analysis_service.py (63%) create mode 100644 genotype_api/services/endpoint_services/base_service.py rename genotype_api/services/{plate_service => endpoint_services}/plate_service.py (69%) rename genotype_api/services/{sample_service => endpoint_services}/sample_service.py (67%) rename genotype_api/services/{snp_service => endpoint_services}/snp_service.py (64%) rename genotype_api/services/{user_service => endpoint_services}/user_service.py (63%) diff --git a/genotype_api/api/app.py b/genotype_api/api/app.py index e83b4f8..2e1bcf0 100644 --- a/genotype_api/api/app.py +++ b/genotype_api/api/app.py @@ -7,8 +7,8 @@ from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware -from genotype_api.config import security_settings -from genotype_api.database.session_handler import create_db_and_tables +from genotype_api.config import security_settings, settings +from genotype_api.database.database import create_all_tables, initialise_database, close_session from genotype_api.api.endpoints import samples, snps, users, plates, analyses from sqlalchemy.exc import NoResultFound @@ -74,4 +74,10 @@ def welcome(): @app.on_event("startup") def on_startup(): - create_db_and_tables() + initialise_database(settings.db_uri) + create_all_tables() + + +@app.on_event("shutdown") +async def on_shutdown(): + close_session() diff --git a/genotype_api/api/endpoints/analyses.py b/genotype_api/api/endpoints/analyses.py index 327d2b5..26b9ba7 100644 --- a/genotype_api/api/endpoints/analyses.py +++ b/genotype_api/api/endpoints/analyses.py @@ -4,29 +4,30 @@ from fastapi import APIRouter, Depends, File, Query, UploadFile, status, HTTPException from fastapi.responses import JSONResponse -from sqlmodel import Session - -from genotype_api.database.models import User +from genotype_api.database.store import Store, get_store from genotype_api.dto.analysis import AnalysisResponse +from genotype_api.dto.user import CurrentUser -from genotype_api.database.session_handler import get_session from genotype_api.exceptions import AnalysisNotFoundError from genotype_api.security import get_active_user -from genotype_api.services.analysis_service.analysis_service import AnalysisService +from genotype_api.services.endpoint_services.analysis_service import ( + AnalysisService, +) + router = APIRouter() -def get_analysis_service(session: Session = Depends(get_session)): - return AnalysisService(session) +def get_analysis_service(store: Store = Depends(get_store)) -> AnalysisService: + return AnalysisService(store) @router.get("/{analysis_id}", response_model=AnalysisResponse) def read_analysis( analysis_id: int, analysis_service: AnalysisService = Depends(get_analysis_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Return analysis.""" try: @@ -43,7 +44,7 @@ def read_analyses( skip: int = 0, limit: int = Query(default=100, lte=100), analysis_service: AnalysisService = Depends(get_analysis_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Return all analyses.""" try: @@ -59,7 +60,7 @@ def read_analyses( def delete_analysis( analysis_id: int, analysis_service: AnalysisService = Depends(get_analysis_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Delete analysis based on analysis id.""" try: @@ -78,7 +79,7 @@ def delete_analysis( def upload_sequence_analysis( file: UploadFile = File(...), analysis_service: AnalysisService = Depends(get_analysis_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Reading VCF file, creating and uploading sequence analyses and sample objects to the database.""" diff --git a/genotype_api/api/endpoints/plates.py b/genotype_api/api/endpoints/plates.py index a8038d0..9618d19 100644 --- a/genotype_api/api/endpoints/plates.py +++ b/genotype_api/api/endpoints/plates.py @@ -4,21 +4,22 @@ from typing import Literal from fastapi import APIRouter, Depends, File, Query, UploadFile, status, HTTPException from fastapi.responses import JSONResponse -from sqlmodel import Session from genotype_api.database.filter_models.plate_models import PlateOrderParams -from genotype_api.database.models import User -from genotype_api.database.session_handler import get_session + + +from genotype_api.database.store import Store, get_store from genotype_api.dto.plate import PlateResponse +from genotype_api.dto.user import CurrentUser from genotype_api.exceptions import PlateNotFoundError from genotype_api.security import get_active_user -from genotype_api.services.plate_service.plate_service import PlateService +from genotype_api.services.endpoint_services.plate_service import PlateService router = APIRouter() -def get_plate_service(session: Session = Depends(get_session)) -> PlateService: - return PlateService(session) +def get_plate_service(store: Store = Depends(get_store)) -> PlateService: + return PlateService(store) @router.post( @@ -29,7 +30,7 @@ def get_plate_service(session: Session = Depends(get_session)) -> PlateService: def upload_plate( file: UploadFile = File(...), plate_service: PlateService = Depends(get_plate_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): return plate_service.upload_plate(file) @@ -44,10 +45,10 @@ def sign_off_plate( method_document: str = Query(...), method_version: str = Query(...), plate_service: PlateService = Depends(get_plate_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Sign off a plate. - This means that current User sign off that the plate is checked + This means that current CurrentUser sign off that the plate is checked Add Depends with current user """ @@ -84,7 +85,7 @@ def sign_off_plate( def read_plate( plate_id: int, plate_service: PlateService = Depends(get_plate_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Display information about a plate.""" try: @@ -107,7 +108,7 @@ async def read_plates( skip: int | None = 0, limit: int | None = 10, plate_service: PlateService = Depends(get_plate_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Display all plates""" order_params = PlateOrderParams( @@ -125,7 +126,7 @@ async def read_plates( def delete_plate( plate_id: int, plate_service: PlateService = Depends(get_plate_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Delete plate.""" try: diff --git a/genotype_api/api/endpoints/samples.py b/genotype_api/api/endpoints/samples.py index 73dcb20..88858f1 100644 --- a/genotype_api/api/endpoints/samples.py +++ b/genotype_api/api/endpoints/samples.py @@ -4,27 +4,24 @@ from fastapi import APIRouter, Depends, Query from fastapi.responses import JSONResponse -from sqlmodel import Session from starlette import status from genotype_api.constants import Sexes, Types from genotype_api.database.filter_models.sample_models import SampleFilterParams -from genotype_api.database.models import ( - Sample, - User, -) -from genotype_api.database.session_handler import get_session -from genotype_api.dto.sample import SampleResponse + +from genotype_api.database.store import Store, get_store +from genotype_api.dto.sample import SampleResponse, SampleCreate +from genotype_api.dto.user import CurrentUser from genotype_api.exceptions import SampleNotFoundError, SampleExistsError from genotype_api.models import MatchResult, SampleDetail from genotype_api.security import get_active_user -from genotype_api.services.sample_service.sample_service import SampleService +from genotype_api.services.endpoint_services.sample_service import SampleService router = APIRouter() -def get_sample_service(session: Session = Depends(get_session)) -> SampleService: - return SampleService(session) +def get_sample_service(store: Store = Depends(get_store)) -> SampleService: + return SampleService(store) @router.get( @@ -34,7 +31,7 @@ def get_sample_service(session: Session = Depends(get_session)) -> SampleService def read_sample( sample_id: str, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): try: return sample_service.get_sample(sample_id) @@ -48,12 +45,12 @@ def read_sample( "/", ) def create_sample( - sample: Sample, + sample: SampleCreate, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): try: - sample_service.create_sample(sample=sample) + sample_service.create_sample(sample_create=sample) new_sample: SampleResponse = sample_service.get_sample(sample_id=sample.id) if not new_sample: return JSONResponse( @@ -90,7 +87,7 @@ def read_samples( commented: bool | None = False, status_missing: bool | None = False, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Returns a list of samples matching the provided filters.""" filter_params = SampleFilterParams( @@ -113,7 +110,7 @@ def update_sex( genotype_sex: Sexes | None = None, sequence_sex: Sexes | None = None, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Updating sex field on sample and sample analyses.""" try: @@ -135,7 +132,7 @@ def update_comment( sample_id: str, comment: str = Query(...), sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Updating comment field on sample.""" try: @@ -155,7 +152,7 @@ def set_sample_status( sample_id: str, sample_service: SampleService = Depends(get_sample_service), status: Literal["pass", "fail", "cancel"] | None = None, - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Check sample analyses and update sample status accordingly.""" try: @@ -175,7 +172,7 @@ def match( date_min: date | None = date.min, date_max: date | None = date.max, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ) -> list[MatchResult]: """Match sample genotype against all other genotypes.""" return sample_service.get_match_results( @@ -196,7 +193,7 @@ def match( def get_status_detail( sample_id: str, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): try: return sample_service.get_status_detail(sample_id) @@ -210,7 +207,7 @@ def get_status_detail( def delete_sample( sample_id: str, sample_service: SampleService = Depends(get_sample_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Delete sample and its Analyses.""" sample_service.delete_sample(sample_id) diff --git a/genotype_api/api/endpoints/snps.py b/genotype_api/api/endpoints/snps.py index 7d0b748..2e140fb 100644 --- a/genotype_api/api/endpoints/snps.py +++ b/genotype_api/api/endpoints/snps.py @@ -1,29 +1,23 @@ """Routes for the snps""" -from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile -from sqlmodel import Session -from sqlmodel.sql.expression import Select, SelectOfScalar -from starlette.responses import JSONResponse +from fastapi import APIRouter, Depends, Query, UploadFile + -from genotype_api.database.crud.create import create_snps -from genotype_api.database.crud.read import get_snps, get_snps_by_limit_and_skip -from genotype_api.database.models import SNP, User -from genotype_api.database.crud import delete -from genotype_api.database.session_handler import get_session +from starlette.responses import JSONResponse +from genotype_api.database.store import Store, get_store from genotype_api.dto.snp import SNPResponse +from genotype_api.dto.user import CurrentUser from genotype_api.exceptions import SNPExistsError from genotype_api.security import get_active_user -from genotype_api.services.snp_reader_service.snp_reader import SNPReaderService -from genotype_api.services.snp_service.snp_service import SNPService -SelectOfScalar.inherit_cache = True -Select.inherit_cache = True +from genotype_api.services.endpoint_services.snp_service import SNPService + router = APIRouter() -def get_snp_service(session: Session = Depends(get_session)) -> SNPService: - return SNPService(session) +def get_snp_service(store: Store = Depends(get_store)) -> SNPService: + return SNPService(store) @router.get("/", response_model=list[SNPResponse]) @@ -31,16 +25,16 @@ def read_snps( skip: int = 0, limit: int = Query(default=100, lte=100), snp_service: SNPService = Depends(get_snp_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): return snp_service.get_snps(skip=skip, limit=limit) -@router.post("/", response_model=list[SNP]) +@router.post("/", response_model=list[SNPResponse]) async def upload_snps( snps_file: UploadFile, snp_service: SNPService = Depends(get_snp_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): try: return snp_service.upload_snps(snps_file) @@ -51,7 +45,7 @@ async def upload_snps( @router.delete("/") def delete_snps( snp_service: SNPService = Depends(get_snp_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): """Delete all SNPs""" diff --git a/genotype_api/api/endpoints/users.py b/genotype_api/api/endpoints/users.py index 2977478..096be0c 100644 --- a/genotype_api/api/endpoints/users.py +++ b/genotype_api/api/endpoints/users.py @@ -2,30 +2,29 @@ from fastapi import APIRouter, Depends, Query, HTTPException from pydantic import EmailStr -from sqlmodel import Session + from starlette import status from starlette.responses import JSONResponse -from genotype_api.database.models import User -from genotype_api.database.session_handler import get_session -from genotype_api.dto.user import UserRequest, UserResponse +from genotype_api.database.store import get_store, Store +from genotype_api.dto.user import UserRequest, UserResponse, CurrentUser from genotype_api.exceptions import UserNotFoundError, UserArchiveError, UserExistsError from genotype_api.security import get_active_user -from genotype_api.services.user_service.user_service import UserService +from genotype_api.services.endpoint_services.user_service import UserService router = APIRouter() -def get_user_service(session: Session = Depends(get_session)) -> UserService: - return UserService(session) +def get_user_service(store: Store = Depends(get_store)) -> UserService: + return UserService(store) @router.get("/{user_id}", response_model=UserResponse) def read_user( user_id: int, user_service: UserService = Depends(get_user_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ) -> UserResponse: try: return user_service.get_user(user_id) @@ -37,7 +36,7 @@ def read_user( def delete_user( user_id: int, user_service: UserService = Depends(get_user_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ) -> JSONResponse: try: user_service.delete_user(user_id) @@ -56,7 +55,7 @@ def change_user_email( user_id: int, email: EmailStr, user_service: UserService = Depends(get_user_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ) -> UserResponse: try: return user_service.update_user_email(user_id=user_id, email=email) @@ -69,7 +68,7 @@ def read_users( skip: int = 0, limit: int = Query(default=100, lte=100), user_service: UserService = Depends(get_user_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ) -> list[UserResponse]: return user_service.get_users(skip=skip, limit=limit) @@ -79,7 +78,7 @@ def read_users( def create_user( user: UserRequest, user_service: UserService = Depends(get_user_service), - current_user: User = Depends(get_active_user), + current_user: CurrentUser = Depends(get_active_user), ): try: return user_service.create_user(user) diff --git a/genotype_api/config.py b/genotype_api/config.py index d39a359..2602846 100644 --- a/genotype_api/config.py +++ b/genotype_api/config.py @@ -1,6 +1,6 @@ from pathlib import Path -from pydantic import BaseSettings +from pydantic_settings import BaseSettings GENOTYPE_PACKAGE = Path(__file__).parent PACKAGE_ROOT: Path = GENOTYPE_PACKAGE.parent @@ -22,10 +22,10 @@ class Config: class SecuritySettings(BaseSettings): """Settings for serving the genotype-api app""" - client_id = "" - algorithm = "" - jwks_uri = "https://www.googleapis.com/oauth2/v3/certs" - api_root_path = "/" + client_id: str = "" + algorithm: str = "" + jwks_uri: str = "https://www.googleapis.com/oauth2/v3/certs" + api_root_path: str = "/" class Config: env_file = str(ENV_FILE) diff --git a/genotype_api/database/base_handler.py b/genotype_api/database/base_handler.py new file mode 100644 index 0000000..77d3e83 --- /dev/null +++ b/genotype_api/database/base_handler.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from sqlalchemy.orm import Session + + +@dataclass +class BaseHandler: + """All queries in one base class.""" + + def __init__(self, session: Session): + self.session = session diff --git a/genotype_api/database/crud/create.py b/genotype_api/database/crud/create.py index 1943ef3..2579d59 100644 --- a/genotype_api/database/crud/create.py +++ b/genotype_api/database/crud/create.py @@ -1,67 +1,55 @@ import logging -from fastapi import HTTPException -from sqlmodel import Session -from sqlmodel.sql.expression import Select, SelectOfScalar +from genotype_api.database.base_handler import BaseHandler from genotype_api.database.models import Analysis, Plate, Sample, User, SNP -from genotype_api.dto.dto import PlateCreate from genotype_api.dto.user import UserRequest from genotype_api.exceptions import SampleExistsError -SelectOfScalar.inherit_cache = True -Select.inherit_cache = True - LOG = logging.getLogger(__name__) -def create_analysis(session: Session, analysis: Analysis) -> Analysis: - session.add(analysis) - session.commit() - session.refresh(analysis) - return analysis - - -def create_plate(session: Session, plate: PlateCreate) -> Plate: - db_plate = Plate.from_orm(plate) - db_plate.analyses = plate.analyses # not sure why from_orm wont pick up the analyses - session.add(db_plate) - session.commit() - session.refresh(db_plate) - LOG.info(f"Creating plate with id {db_plate.plate_id}.") - return db_plate - - -def create_sample(session: Session, sample: Sample) -> Sample: - """Creates a sample in the database.""" - - sample_in_db = session.get(Sample, sample.id) - if sample_in_db: - raise SampleExistsError - session.add(sample) - session.commit() - session.refresh(sample) - return sample - - -def create_analyses_samples(session: Session, analyses: list[Analysis]) -> list[Sample]: - """creating samples in an analysis if not already in db.""" - return [ - create_sample(session=session, sample=Sample(id=analysis.sample_id)) - for analysis in analyses - if not session.get(Sample, analysis.sample_id) - ] - - -def create_user(session: Session, user: UserRequest): - db_user = User.from_orm(user) - session.add(db_user) - session.commit() - session.refresh(db_user) - return db_user - - -def create_snps(session: Session, snps: list[SNP]) -> list[SNP]: - session.add_all(snps) - session.commit() - return snps +class CreateHandler(BaseHandler): + + def create_analysis(self, analysis: Analysis) -> Analysis: + self.session.add(analysis) + self.session.commit() + self.session.refresh(analysis) + return analysis + + def create_plate(self, plate: Plate) -> Plate: + self.session.add(plate) + self.session.commit() + self.session.refresh(plate) + LOG.info(f"Creating plate with id {plate.plate_id}.") + return plate + + def create_sample(self, sample: Sample) -> Sample: + """Creates a sample in the database.""" + sample_in_db = self.session.query(Sample).filter(Sample.id == sample.id).one_or_none() + if sample_in_db: + raise SampleExistsError + self.session.add(sample) + self.session.commit() + self.session.refresh(sample) + return sample + + def create_analyses_samples(self, analyses: list[Analysis]) -> list[Sample]: + """creating samples in an analysis if not already in db.""" + return [ + self.create_sample(sample=Sample(id=analysis.sample_id)) + for analysis in analyses + if not self.session.query(Sample).filter(Sample.id == analysis.sample_id).one_or_none() + ] + + def create_user(self, user: UserRequest) -> User: + db_user = User(email=user.email, name=user.name) + self.session.add(db_user) + self.session.commit() + self.session.refresh(db_user) + return db_user + + def create_snps(self, snps: list[SNP]) -> list[SNP]: + self.session.add_all(snps) + self.session.commit() + return snps diff --git a/genotype_api/database/crud/delete.py b/genotype_api/database/crud/delete.py index 45c0aec..0fb6218 100644 --- a/genotype_api/database/crud/delete.py +++ b/genotype_api/database/crud/delete.py @@ -1,38 +1,32 @@ import logging - from sqlalchemy import delete -from sqlmodel import Session -from sqlmodel.sql.expression import Select, SelectOfScalar -from genotype_api.database.models import Analysis, Plate, Sample, User, SNP -SelectOfScalar.inherit_cache = True -Select.inherit_cache = True +from genotype_api.database.base_handler import BaseHandler +from genotype_api.database.models import Analysis, Plate, Sample, User, SNP LOG = logging.getLogger(__name__) -def delete_analysis(session: Session, analysis: Analysis) -> None: - session.delete(analysis) - session.commit() - - -def delete_plate(session: Session, plate: Plate) -> None: - session.delete(plate) - session.commit() - +class DeleteHandler(BaseHandler): -def delete_sample(session: Session, sample: Sample) -> None: - session.delete(sample) - session.commit() + def delete_analysis(self, analysis: Analysis) -> None: + self.session.delete(analysis) + self.session.commit() + def delete_plate(self, plate: Plate) -> None: + self.session.delete(plate) + self.session.commit() -def delete_user(session: Session, user: User) -> None: - session.delete(user) - session.commit() + def delete_sample(self, sample: Sample) -> None: + self.session.delete(sample) + self.session.commit() + def delete_user(self, user: User) -> None: + self.session.delete(user) + self.session.commit() -def delete_snps(session) -> any: - result = session.exec(delete(SNP)) - session.commit() - return result + def delete_snps(self) -> any: + result = self.session.execute(delete(SNP)) + self.session.commit() + return result diff --git a/genotype_api/database/crud/read.py b/genotype_api/database/crud/read.py index 9b643a4..d74a81e 100644 --- a/genotype_api/database/crud/read.py +++ b/genotype_api/database/crud/read.py @@ -1,12 +1,9 @@ import logging from datetime import timedelta, date - from sqlalchemy import func, desc, asc from sqlalchemy.orm import Query -from sqlmodel import Session, select -from sqlmodel.sql.expression import Select, SelectOfScalar - from genotype_api.constants import Types +from genotype_api.database.base_handler import BaseHandler from genotype_api.database.filter_models.plate_models import PlateOrderParams from genotype_api.database.filter_models.sample_models import SampleFilterParams from genotype_api.database.models import ( @@ -17,175 +14,146 @@ SNP, ) -SelectOfScalar.inherit_cache = True -Select.inherit_cache = True - LOG = logging.getLogger(__name__) -def get_analyses_from_plate(plate_id: int, session: Session) -> list[Analysis]: - statement = select(Analysis).where(Analysis.plate_id == plate_id) - return session.exec(statement).all() - - -def get_analysis_by_type_sample( - sample_id: str, analysis_type: str, session: Session -) -> Analysis | None: - statement = select(Analysis).where( - Analysis.sample_id == sample_id, Analysis.type == analysis_type - ) - return session.exec(statement).first() - - -def get_analysis_by_id(session: Session, analysis_id: int) -> Analysis: - """Get analysis""" - statement = select(Analysis).where(Analysis.id == analysis_id) - return session.exec(statement).one() - - -def get_analyses(session: Session) -> list[Analysis]: - statement = select(Analysis) - return session.exec(statement).all() - - -def get_analyses_with_skip_and_limit(session: Session, skip: int, limit: int) -> list[Analysis]: - statement = select(Analysis).offset(skip).limit(limit) - return session.exec(statement).all() +class ReadHandler(BaseHandler): + def get_analyses_from_plate(self, plate_id: int) -> list[Analysis]: + return self.session.query(Analysis).filter(Analysis.plate_id == plate_id).all() -def get_analyses_by_type_between_dates( - session, analysis_type: str, date_min: date, date_max: date -) -> list[Analysis]: - analyses: Query = session.query(Analysis).filter( - Analysis.type == analysis_type, - Analysis.created_at > date_min - timedelta(days=1), - Analysis.created_at < date_max + timedelta(days=1), - ) - return analyses.all() - - -def get_analysis_by_type_and_sample_id( - session: Session, analysis_type: str, sample_id: str -) -> Analysis: - return ( - session.query(Analysis).filter( - Analysis.sample_id == sample_id, Analysis.type == analysis_type + def get_analysis_by_type_sample( + self, + sample_id: str, + analysis_type: str, + ) -> Analysis | None: + return ( + self.session.query(Analysis) + .filter(Analysis.sample_id == sample_id, Analysis.type == analysis_type) + .first() ) - ).one() - - -def get_plate_by_id(session: Session, plate_id: int) -> Plate: - statement = select(Plate).where(Plate.id == plate_id) - return session.exec(statement).one() - - -def get_ordered_plates(session: Session, order_params: PlateOrderParams) -> list[Plate]: - sort_func = desc if order_params.sort_order == "descend" else asc - return session.exec( - select(Plate) - .order_by(sort_func(order_params.order_by)) - .offset(order_params.skip) - .limit(order_params.limit) - ).all() - -def get_incomplete_samples(statement: SelectOfScalar) -> SelectOfScalar: - """Returning sample query statement for samples with less than two analyses.""" - return ( - statement.group_by(Analysis.sample_id) - .order_by(Analysis.created_at) - .having(func.count(Analysis.sample_id) < 2) - ) - - -def get_filtered_samples(session: Session, filter_params: SampleFilterParams) -> list[Sample]: - statement: SelectOfScalar = select(Sample).distinct().join(Analysis) - if filter_params.sample_id: - statement: SelectOfScalar = get_samples( - statement=statement, sample_id=filter_params.sample_id + def get_analysis_by_id(self, analysis_id: int) -> Analysis: + return self.session.query(Analysis).filter(Analysis.id == analysis_id).one() + + def get_analyses(self) -> list[Analysis]: + return self.session.query(Analysis).all() + + def get_analyses_with_skip_and_limit(self, skip: int, limit: int) -> list[Analysis]: + return self.session.query(Analysis).offset(skip).limit(limit).all() + + def get_analyses_by_type_between_dates( + self, analysis_type: Types, date_min: date, date_max: date + ) -> list[Analysis]: + return ( + self.session.query(Analysis) + .filter( + Analysis.type == analysis_type, + Analysis.created_at > date_min - timedelta(days=1), + Analysis.created_at < date_max + timedelta(days=1), + ) + .all() ) - if filter_params.plate_id: - statement: SelectOfScalar = get_plate_samples( - statement=statement, plate_id=filter_params.plate_id - ) - if filter_params.is_incomplete: - statement: SelectOfScalar = get_incomplete_samples(statement=statement) - if filter_params.is_commented: - statement: SelectOfScalar = get_commented_samples(statement=statement) - if filter_params.is_missing: - statement: SelectOfScalar = get_status_missing_samples(statement=statement) - return session.exec( - statement.order_by(Sample.created_at.desc()) - .offset(filter_params.skip) - .limit(filter_params.limit) - ).all() - - -def get_plate_samples(statement: SelectOfScalar, plate_id: str) -> SelectOfScalar: - """Returning sample query statement for samples analysed on a specific plate.""" - return statement.where(Analysis.plate_id == plate_id) - - -def get_commented_samples(statement: SelectOfScalar) -> SelectOfScalar: - """Returning sample query statement for samples with no comment.""" - - return statement.where(Sample.comment != None) - - -def get_status_missing_samples(statement: SelectOfScalar) -> SelectOfScalar: - """Returning sample query statement for samples with no comment.""" - - return statement.where(Sample.status == None) - - -def get_sample(session: Session, sample_id: str) -> Sample: - """Get sample or raise 404.""" - - statement = select(Sample).where(Sample.id == sample_id) - return session.exec(statement).one() - - -def get_samples(statement: SelectOfScalar, sample_id: str) -> SelectOfScalar: - """Returns a query for samples containing the given sample_id.""" - return statement.where(Sample.id.contains(sample_id)) - - -def get_user_by_id(session: Session, user_id: int): - statement = select(User).where(User.id == user_id) - return session.exec(statement).one() - - -def get_user_by_email(session: Session, email: str) -> User | None: - statement = select(User).where(User.email == email) - return session.exec(statement).first() - - -def get_users(session: Session, skip: int = 0, limit: int = 100) -> list[User]: - statement = select(User).offset(skip).limit(limit) - return session.exec(statement).all() + def get_analysis_by_type_and_sample_id(self, analysis_type: str, sample_id: str) -> Analysis: + return ( + self.session.query(Analysis) + .filter(Analysis.sample_id == sample_id, Analysis.type == analysis_type) + .one() + ) -def get_users_with_skip_and_limit(session: Session, skip: int, limit: int) -> list[User]: - return session.exec(select(User).offset(skip).limit(limit)).all() + def get_plate_by_id(self, plate_id: int) -> Plate: + return self.session.query(Plate).filter(Plate.id == plate_id).one() + def get_plate_by_plate_id(self, plate_id: str) -> Plate: + return self.session.query(Plate).filter(Plate.plate_id == plate_id).one() -def check_analyses_objects( - session: Session, analyses: list[Analysis], analysis_type: Types -) -> None: - """Raising 400 if any analysis in the list already exist in the database""" - for analysis_obj in analyses: - existing_analysis: Analysis = get_analysis_by_type_sample( - session=session, - sample_id=analysis_obj.sample_id, - analysis_type=analysis_type, + def get_ordered_plates(self, order_params: PlateOrderParams) -> list[Plate]: + sort_func = desc if order_params.sort_order == "descend" else asc + return ( + self.session.query(Plate) + .order_by(sort_func(order_params.order_by)) + .offset(order_params.skip) + .limit(order_params.limit) + .all() ) - if existing_analysis: - session.delete(existing_analysis) - -def get_snps(session) -> list[SNP]: - return session.exec(select(SNP)).all() + def get_filtered_samples(self, filter_params: SampleFilterParams) -> list[Sample]: + query = self.session.query(Sample).distinct().join(Analysis) + if filter_params.sample_id: + query = self._get_samples(query, filter_params.sample_id) + if filter_params.plate_id: + query = self._get_plate_samples(query, filter_params.plate_id) + if filter_params.is_incomplete: + query = self._get_incomplete_samples(query) + if filter_params.is_commented: + query = self._get_commented_samples(query) + if filter_params.is_missing: + query = self._get_status_missing_samples(query) + return ( + query.order_by(Sample.created_at.desc()) + .offset(filter_params.skip) + .limit(filter_params.limit) + .all() + ) + @staticmethod + def _get_incomplete_samples(query: Query) -> Query: + """Returning sample query statement for samples with less than two analyses.""" + return ( + query.group_by(Analysis.sample_id) + .order_by(Analysis.created_at) + .having(func.count(Analysis.sample_id) < 2) + ) -def get_snps_by_limit_and_skip(session: Session, skip: int, limit: int) -> list[SNP]: - return session.exec(select(SNP).offset(skip).limit(limit)).all() + @staticmethod + def _get_plate_samples(query: Query, plate_id: str) -> Query: + """Returning sample query statement for samples analysed on a specific plate.""" + return query.filter(Analysis.plate_id == plate_id) + + @staticmethod + def _get_commented_samples(query: Query) -> Query: + """Returning sample query statement for samples with no comment.""" + return query.filter(Sample.comment != None) + + @staticmethod + def _get_status_missing_samples(query: Query) -> Query: + """Returning sample query statement for samples with no comment.""" + return query.filter(Sample.status == None) + + @staticmethod + def _get_samples(query: Query, sample_id: str) -> Query: + """Returns a query for samples containing the given sample_id.""" + return query.filter(Sample.id.contains(sample_id)) + + def get_sample(self, sample_id: str) -> Sample: + """Get sample or raise 404.""" + return self.session.query(Sample).filter(Sample.id == sample_id).one() + + def get_user_by_id(self, user_id: int) -> User: + return self.session.query(User).filter(User.id == user_id).one() + + def get_user_by_email(self, email: str) -> User | None: + return self.session.query(User).filter(User.email == email).first() + + def get_users(self, skip: int = 0, limit: int = 100) -> list[User]: + return self.session.query(User).offset(skip).limit(limit).all() + + def get_users_with_skip_and_limit(self, skip: int, limit: int) -> list[User]: + return self.session.query(User).offset(skip).limit(limit).all() + + def check_analyses_objects(self, analyses: list[Analysis], analysis_type: Types) -> None: + """Raising 400 if any analysis in the list already exist in the database""" + for analysis_obj in analyses: + existing_analysis = self.get_analysis_by_type_sample( + sample_id=analysis_obj.sample_id, + analysis_type=analysis_type, + ) + if existing_analysis: + self.session.delete(existing_analysis) + + def get_snps(self) -> list[SNP]: + return self.session.query(SNP).all() + + def get_snps_by_limit_and_skip(self, skip: int, limit: int) -> list[SNP]: + return self.session.query(SNP).offset(skip).limit(limit).all() diff --git a/genotype_api/database/crud/update.py b/genotype_api/database/crud/update.py index e52c8d4..5a43d3a 100644 --- a/genotype_api/database/crud/update.py +++ b/genotype_api/database/crud/update.py @@ -1,92 +1,86 @@ -import types - from pydantic import EmailStr -from sqlmodel import Session + from genotype_api.constants import Types -from genotype_api.database.crud.read import get_sample +from genotype_api.database.base_handler import BaseHandler from genotype_api.database.filter_models.plate_models import PlateSignOff from genotype_api.database.filter_models.sample_models import SampleSexesUpdate from genotype_api.database.models import Sample, Plate, User -from sqlmodel.sql.expression import Select, SelectOfScalar - from genotype_api.exceptions import SampleNotFoundError from genotype_api.services.match_genotype_service.match_genotype import MatchGenotypeService -SelectOfScalar.inherit_cache = True -Select.inherit_cache = True - - -def refresh_sample_status(sample: Sample, session: Session) -> Sample: - if len(sample.analyses) != 2: - sample.status = None - else: - results = MatchGenotypeService.check_sample(sample=sample) - sample.status = "fail" if "fail" in results.dict().values() else "pass" - - session.add(sample) - session.commit() - session.refresh(sample) - return sample - - -def update_sample_comment(session: Session, sample_id: str, comment: str) -> Sample: - sample: Sample = get_sample(session=session, sample_id=sample_id) - if not sample: - raise SampleNotFoundError - sample.comment = comment - session.add(sample) - session.commit() - session.refresh(sample) - return sample - - -def update_sample_status(session: Session, sample_id: str, status: str | None) -> Sample: - sample: Sample = get_sample(session=session, sample_id=sample_id) - if not sample: - raise SampleNotFoundError - sample.status = status - session.add(sample) - session.commit() - session.refresh(sample) - return sample - - -def refresh_plate(session: Session, plate: Plate) -> None: - session.refresh(plate) - - -def update_plate_sign_off(session: Session, plate: Plate, plate_sign_off: PlateSignOff) -> Plate: - plate.signed_by = plate_sign_off.user_id - plate.signed_at = plate_sign_off.signed_at - plate.method_document = plate_sign_off.method_document - plate.method_version = plate_sign_off.method_version - session.commit() - session.refresh(plate) - return plate - - -def update_sample_sex(session: Session, sexes_update: SampleSexesUpdate) -> Sample: - sample: Sample = get_sample(session=session, sample_id=sexes_update.sample_id) - if not sample: - raise SampleNotFoundError - sample.sex = sexes_update.sex - for analysis in sample.analyses: - if sexes_update.genotype_sex and analysis.type == Types.GENOTYPE: - analysis.sex = sexes_update.genotype_sex - elif sexes_update.sequence_sex and analysis.type == Types.SEQUENCE: - analysis.sex = sexes_update.sequence_sex - session.add(analysis) - session.add(sample) - session.commit() - session.refresh(sample) - sample: Sample = refresh_sample_status(session=session, sample=sample) - return sample - -def update_user_email(session: Session, user: User, email: EmailStr) -> User: - user.email = email - session.add(user) - session.commit() - session.refresh(user) - return user +class UpdateHandler(BaseHandler): + + def refresh_sample_status( + self, + sample: Sample, + ) -> Sample: + if len(sample.analyses) != 2: + sample.status = None + else: + results = MatchGenotypeService.check_sample(sample=sample) + sample.status = "fail" if "fail" in results.dict().values() else "pass" + + self.session.add(sample) + self.session.commit() + self.session.refresh(sample) + return sample + + def update_sample_comment(self, sample_id: str, comment: str) -> Sample: + sample: Sample = self.get_sample(sample_id=sample_id) + if not sample: + raise SampleNotFoundError + sample.comment = comment + self.session.add(sample) + self.session.commit() + self.session.refresh(sample) + return sample + + def update_sample_status(self, sample_id: str, status: str | None) -> Sample: + sample: Sample = self.get_sample(sample_id=sample_id) + if not sample: + raise SampleNotFoundError + sample.status = status + self.session.add(sample) + self.session.commit() + self.session.refresh(sample) + return sample + + def refresh_plate(self, plate: Plate) -> None: + self.session.refresh(plate) + + def update_plate_sign_off(self, plate: Plate, plate_sign_off: PlateSignOff) -> Plate: + plate.signed_by = plate_sign_off.user_id + plate.signed_at = plate_sign_off.signed_at + plate.method_document = plate_sign_off.method_document + plate.method_version = plate_sign_off.method_version + self.session.commit() + self.session.refresh(plate) + return plate + + def update_sample_sex(self, sexes_update: SampleSexesUpdate) -> Sample: + sample = ( + self.session.query(Sample).filter(Sample.id == sexes_update.sample_id).one_or_none() + ) + if not sample: + raise SampleNotFoundError + sample.sex = sexes_update.sex + for analysis in sample.analyses: + if sexes_update.genotype_sex and analysis.type == Types.GENOTYPE: + analysis.sex = sexes_update.genotype_sex + elif sexes_update.sequence_sex and analysis.type == Types.SEQUENCE: + analysis.sex = sexes_update.sequence_sex + self.session.add(analysis) + self.session.add(sample) + self.session.commit() + self.session.refresh(sample) + sample = self.refresh_sample_status(sample) + return sample + + def update_user_email(self, user: User, email: EmailStr) -> User: + user.email = email + self.session.add(user) + self.session.commit() + self.session.refresh(user) + return user diff --git a/genotype_api/database/database.py b/genotype_api/database/database.py new file mode 100644 index 0000000..3bd3ebd --- /dev/null +++ b/genotype_api/database/database.py @@ -0,0 +1,64 @@ +"""Hold the database information""" + +from sqlalchemy import create_engine, inspect +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.orm import Session, scoped_session, sessionmaker + +from genotype_api.exceptions import GenotypeDBError +from genotype_api.database.models import Base + +SESSION: scoped_session | None = None +ENGINE: Engine | None = None + + +def initialise_database(db_uri: str) -> None: + """Initialize the SQLAlchemy engine and session for genotype api.""" + global SESSION, ENGINE + + ENGINE = create_engine(db_uri, pool_pre_ping=True) + session_factory = sessionmaker(ENGINE) + SESSION = scoped_session(session_factory) + + +def get_session() -> scoped_session: + """Get a SQLAlchemy session with a connection to genotype api.""" + if not SESSION: + raise GenotypeDBError + return SESSION + + +def get_scoped_session_registry() -> scoped_session | None: + """Get the scoped session registry for genotype api.""" + return SESSION + + +def get_engine() -> Engine: + """Get the SQLAlchemy engine with a connection to genotype api.""" + if not ENGINE: + raise GenotypeDBError + return ENGINE + + +def create_all_tables() -> None: + """Create all tables in genotype api.""" + session: Session = get_session() + Base.metadata.create_all(bind=session.get_bind()) + + +def drop_all_tables() -> None: + """Drop all tables in genotype api.""" + session: Session = get_session() + Base.metadata.drop_all(bind=session.get_bind()) + + +def get_tables() -> list[str]: + """Get a list of all tables in genotype api.""" + engine: Engine = get_engine() + inspector: Inspector = inspect(engine) + return inspector.get_table_names() + + +def close_session(): + """Close the global database session of the genotype api.""" + SESSION.close() diff --git a/genotype_api/database/models.py b/genotype_api/database/models.py index 426b7b2..e7dba9c 100644 --- a/genotype_api/database/models.py +++ b/genotype_api/database/models.py @@ -1,135 +1,115 @@ from collections import Counter from datetime import datetime +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import relationship +from sqlalchemy import ( + Column, + ForeignKey, + String, +) +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy_utils import EmailType -from pydantic import EmailStr, constr -from sqlalchemy import Index -from sqlmodel import Field, Relationship, SQLModel -from genotype_api.constants import Sexes, Status, Types +class Base(DeclarativeBase): + pass -class GenotypeBase(SQLModel): - rsnumber: constr(max_length=10) | None - analysis_id: int | None = Field(default=None, foreign_key="analysis.id") - allele_1: constr(max_length=1) | None - allele_2: constr(max_length=1) | None - - -class Genotype(GenotypeBase, table=True): +class Genotype(Base): __tablename__ = "genotype" - __table_args__ = (Index("_analysis_rsnumber", "analysis_id", "rsnumber", unique=True),) - id: int | None = Field(default=None, primary_key=True) - analysis: "Analysis" = Relationship(back_populates="genotypes") + id = Column(Integer, primary_key=True) + rsnumber = Column(String(length=10)) + analysis_id = Column(Integer, ForeignKey("analysis.id")) + allele_1 = Column(String(length=1)) + allele_2 = Column(String(length=1)) - @property - def alleles(self) -> list[str]: - """Return sorted because we are not dealing with phased data.""" + analysis = relationship("Analysis", back_populates="genotypes") + @property + def alleles(self): return sorted([self.allele_1, self.allele_2]) @property - def is_ok(self) -> bool: - """Check that the allele determination is ok.""" + def is_ok(self): return "0" not in self.alleles -class AnalysisBase(SQLModel): - type: Types - source: str | None - sex: Sexes | None - created_at: datetime | None = datetime.now() - sample_id: constr(max_length=32) | None = Field(default=None, foreign_key="sample.id") - plate_id: str | None = Field(default=None, foreign_key="plate.id") - - -class Analysis(AnalysisBase, table=True): +class Analysis(Base): __tablename__ = "analysis" - __table_args__ = (Index("_sample_type", "sample_id", "type", unique=True),) - id: int | None = Field(default=None, primary_key=True) - sample: "Sample" = Relationship(back_populates="analyses") - plate: list["Plate"] = Relationship(back_populates="analyses") - genotypes: list["Genotype"] = Relationship(back_populates="analysis") + id = Column(Integer, primary_key=True) + type = Column(String) + source = Column(String) + sex = Column(String) + created_at = Column(DateTime, default=datetime.now) + sample_id = Column(String(length=32), ForeignKey("sample.id")) + plate_id = Column(Integer, ForeignKey("plate.id")) - def check_no_calls(self) -> dict[str, int]: - """Check that genotypes look ok.""" + sample = relationship("Sample", back_populates="analyses") + plate = relationship("Plate", back_populates="analyses") + genotypes = relationship("Genotype", back_populates="analysis") + + def check_no_calls(self): calls = ["known" if genotype.is_ok else "unknown" for genotype in self.genotypes] return Counter(calls) -class SampleSlim(SQLModel): - status: Status | None - comment: str | None - - -class SampleBase(SampleSlim): - sex: Sexes | None - created_at: datetime | None = datetime.now() - - -class Sample(SampleBase, table=True): +class Sample(Base): __tablename__ = "sample" - id: constr(max_length=32) | None = Field(default=None, primary_key=True) - analyses: list["Analysis"] = Relationship(back_populates="sample") + id = Column(String(length=32), primary_key=True) + status = Column(String) + comment = Column(String) + sex = Column(String) + created_at = Column(DateTime, default=datetime.now) - @property - def genotype_analysis(self) -> Analysis | None: - """Return genotype analysis.""" + analyses = relationship("Analysis", back_populates="sample") + @property + def genotype_analysis(self): for analysis in self.analyses: if analysis.type == "genotype": return analysis - return None @property - def sequence_analysis(self) -> Analysis | None: - """Return sequence analysis.""" - + def sequence_analysis(self): for analysis in self.analyses: if analysis.type == "sequence": return analysis - return None -class SNPBase(SQLModel): - ref: constr(max_length=1) | None - chrom: constr(max_length=5) | None - pos: int | None - - -class SNP(SNPBase, table=True): +class SNP(Base): __tablename__ = "snp" - """Represent a SNP position under investigation.""" - - id: constr(max_length=32) | None = Field(default=None, primary_key=True) - -class UserBase(SQLModel): - email: EmailStr = Field(index=True, unique=True) - name: str | None = "" + id = Column(String(length=32), primary_key=True) + ref = Column(String(length=1)) + chrom = Column(String(length=5)) + pos = Column(Integer) -class User(UserBase, table=True): +class User(Base): __tablename__ = "user" - id: int | None = Field(default=None, primary_key=True) - plates: list["Plate"] = Relationship(back_populates="user") + id = Column(Integer, primary_key=True) + email = Column(EmailType, unique=True, index=True) + name = Column(String, default="") -class PlateBase(SQLModel): - created_at: datetime | None = datetime.now() - plate_id: constr(max_length=16) = Field(index=True, unique=True) - signed_by: int | None = Field(default=None, foreign_key="user.id") - signed_at: datetime | None - method_document: str | None - method_version: str | None + plates = relationship("Plate", back_populates="user") -class Plate(PlateBase, table=True): +class Plate(Base): __tablename__ = "plate" - id: int | None = Field(default=None, primary_key=True) - user: "User" = Relationship(back_populates="plates") - analyses: list["Analysis"] = Relationship(back_populates="plate") + + id = Column(Integer, primary_key=True) + created_at = Column(DateTime, default=datetime.now) + plate_id = Column(String(length=16), unique=True, index=True) + signed_by = Column(Integer, ForeignKey("user.id")) + signed_at = Column(DateTime) + method_document = Column(String) + method_version = Column(String) + + user = relationship("User", back_populates="plates") + analyses = relationship("Analysis", back_populates="plate") diff --git a/genotype_api/database/session_handler.py b/genotype_api/database/session_handler.py deleted file mode 100644 index c082e69..0000000 --- a/genotype_api/database/session_handler.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Hold the database information""" - -from sqlmodel import Session, SQLModel, create_engine - -from genotype_api.config import settings - -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" - -engine = create_engine(settings.db_uri, pool_pre_ping=True) - - -def get_session(): - with Session(engine) as session: - yield session - - -def create_db_and_tables(): - SQLModel.metadata.create_all(engine) diff --git a/genotype_api/database/store.py b/genotype_api/database/store.py new file mode 100644 index 0000000..4dba733 --- /dev/null +++ b/genotype_api/database/store.py @@ -0,0 +1,28 @@ +"""Module for the store handler.""" + +from sqlalchemy.orm import Session + +from genotype_api.config import DBSettings +from genotype_api.database.crud.create import CreateHandler +from genotype_api.database.crud.delete import DeleteHandler +from genotype_api.database.crud.read import ReadHandler +from genotype_api.database.crud.update import UpdateHandler +from genotype_api.database.database import get_session, initialise_database + + +class Store( + CreateHandler, + DeleteHandler, + ReadHandler, + UpdateHandler, +): + def __init__(self): + self.session: Session = get_session() + DeleteHandler(self.session) + ReadHandler(self.session) + UpdateHandler(self.session) + + +def get_store() -> Store: + """Return a store.""" + return Store() diff --git a/genotype_api/dto/analysis.py b/genotype_api/dto/analysis.py index 3ac0b4b..43797c1 100644 --- a/genotype_api/dto/analysis.py +++ b/genotype_api/dto/analysis.py @@ -9,11 +9,11 @@ class AnalysisResponse(BaseModel): - type: Types | None - source: str | None - sex: Sexes | None - created_at: datetime | None - sample_id: str | None - plate_id: str | None - id: int | None + type: Types | None = None + source: str | None = None + sex: Sexes | None = None + created_at: datetime | None = None + sample_id: str | None = None + plate_id: int | None = None + id: int | None = None genotypes: list[GenotypeResponse] | None = None diff --git a/genotype_api/dto/dto.py b/genotype_api/dto/dto.py deleted file mode 100644 index 86c8b49..0000000 --- a/genotype_api/dto/dto.py +++ /dev/null @@ -1,5 +0,0 @@ -from genotype_api.database import models - - -class PlateCreate(models.PlateBase): - analyses: list[models.Analysis] | None = [] diff --git a/genotype_api/dto/plate.py b/genotype_api/dto/plate.py index 485402f..b2a0d1e 100644 --- a/genotype_api/dto/plate.py +++ b/genotype_api/dto/plate.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, validator, Field, EmailStr from genotype_api.constants import Types, Sexes, Status +from genotype_api.database.models import Analysis class PlateStatusCounts(BaseModel): @@ -17,7 +18,7 @@ class PlateStatusCounts(BaseModel): commented: int = Field(0, nullable=True) class Config: - allow_population_by_field_name = True + populate_by_name = True class UserOnPlate(BaseModel): @@ -32,13 +33,13 @@ class SampleStatus(BaseModel): class AnalysisOnPlate(BaseModel): - type: Types | None - source: str | None - sex: Sexes | None - created_at: datetime | None - sample_id: str | None - plate_id: str | None - id: int | None + type: Types | None = None + source: str | None = None + sex: Sexes | None = None + created_at: datetime | None = None + sample_id: str | None = None + plate_id: int | None = None + id: int | None = None sample: SampleStatus | None = None @@ -49,7 +50,7 @@ class PlateResponse(BaseModel): signed_at: datetime | None = None method_document: str | None = None method_version: str | None = None - id: str | None = None + id: int | None = None user: UserOnPlate | None = None analyses: list[AnalysisOnPlate] | None = None plate_status_counts: PlateStatusCounts | None = None @@ -65,4 +66,4 @@ def check_detail(cls, value, values): return PlateStatusCounts(**status_counts, total=len(analyses), commented=commented) class Config: - validate_all = True + validate_default = True diff --git a/genotype_api/dto/sample.py b/genotype_api/dto/sample.py index 98e4254..3cade4d 100644 --- a/genotype_api/dto/sample.py +++ b/genotype_api/dto/sample.py @@ -1,7 +1,7 @@ """Module for the sample DTOs.""" from datetime import datetime -from pydantic import BaseModel, validator +from pydantic import BaseModel, computed_field from genotype_api.constants import Sexes, Status, Types from genotype_api.dto.genotype import GenotypeResponse @@ -13,40 +13,46 @@ class AnalysisOnSample(BaseModel): type: Types | None = None sex: Sexes | None = None sample_id: str | None = None - plate_id: str | None = None + plate_id: int | None = None id: int | None = None genotypes: list[GenotypeResponse] class SampleResponse(BaseModel): id: str | None = None - status: Status | None - comment: str | None - sex: Sexes | None + status: Status | None = None + comment: str | None = None + sex: Sexes | None = None created_at: datetime | None = datetime.now() - analyses: list[AnalysisOnSample] | None - detail: SampleDetail | None + analyses: list[AnalysisOnSample] | None = None - @validator("detail") - def get_detail(cls, value, values) -> SampleDetail | None: - analyses = values.get("analyses") + @computed_field(alias="detail") + def get_detail(self) -> SampleDetail | None: + analyses = self.analyses if analyses: if len(analyses) != 2: return SampleDetail() - genotype_analysis: list[AnalysisOnSample] = [ + genotype_analysis: AnalysisOnSample = [ analysis for analysis in analyses if analysis.type == "genotype" ][0] - sequence_analysis: list[AnalysisOnSample] = [ + sequence_analysis: AnalysisOnSample = [ analysis for analysis in analyses if analysis.type == "sequence" ][0] status: dict = check_snps( genotype_analysis=genotype_analysis, sequence_analysis=sequence_analysis ) sex: str = check_sex( - sample_sex=values.get("sex"), + sample_sex=self.sex, genotype_analysis=genotype_analysis, sequence_analysis=sequence_analysis, ) - return SampleDetail(**status, sex=sex) return None + + +class SampleCreate(BaseModel): + id: str + status: str | None = None + comment: str | None = None + sex: Sexes + created_at: datetime = datetime.now() diff --git a/genotype_api/dto/snp.py b/genotype_api/dto/snp.py index 4c50c97..bec4972 100644 --- a/genotype_api/dto/snp.py +++ b/genotype_api/dto/snp.py @@ -4,7 +4,7 @@ class SNPResponse(BaseModel): - ref: str - chrom: str - pos: int | None - id: str + ref: str | None = None + chrom: str | None = None + pos: int | None = None + id: str | None = None diff --git a/genotype_api/dto/user.py b/genotype_api/dto/user.py index d9c4221..934c902 100644 --- a/genotype_api/dto/user.py +++ b/genotype_api/dto/user.py @@ -10,7 +10,7 @@ class PlateOnUser(BaseModel): plate_id: str | None = None signed_by: int | None = None signed_at: datetime | None = None - id: str | None = None + id: int | None = None class UserResponse(BaseModel): @@ -23,3 +23,9 @@ class UserResponse(BaseModel): class UserRequest(BaseModel): email: EmailStr name: str + + +class CurrentUser(BaseModel): + id: int + email: EmailStr + name: str diff --git a/genotype_api/exceptions.py b/genotype_api/exceptions.py index 711135e..859c06e 100644 --- a/genotype_api/exceptions.py +++ b/genotype_api/exceptions.py @@ -1,6 +1,10 @@ """Genotype specific exceptions""" +class GenotypeDBError(Exception): + pass + + class SexConflictError(Exception): pass diff --git a/genotype_api/models.py b/genotype_api/models.py index fbcfbda..b1a1f1d 100644 --- a/genotype_api/models.py +++ b/genotype_api/models.py @@ -40,7 +40,7 @@ def validate_status(cls, value, values) -> SampleDetailStatus: return SampleDetailStatus(sex=sex, snps=snps, nocalls=nocalls) class Config: - validate_all = True + validate_default = True class MatchCounts(BaseModel): diff --git a/genotype_api/security.py b/genotype_api/security.py index 86e0111..b30b479 100644 --- a/genotype_api/security.py +++ b/genotype_api/security.py @@ -2,29 +2,29 @@ from fastapi import Depends, HTTPException, Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from jose import jwt -from sqlmodel import Session from starlette import status from starlette.requests import Request from genotype_api.config import security_settings -from genotype_api.database.crud.read import get_user_by_email from genotype_api.database.models import User -from genotype_api.database.session_handler import get_session +from genotype_api.database.store import get_store, Store +from genotype_api.dto.user import CurrentUser def decode_id_token(token: str): - payload = jwt.decode( - token, - key=requests.get(security_settings.jwks_uri).json(), - algorithms=[security_settings.algorithm], - audience=security_settings.client_id, - options={ - "verify_at_hash": False, - }, - ) - if not payload: - return jwt.get_unverified_claims(token) - return payload + try: + payload = jwt.decode( + token, + key=requests.get(security_settings.jwks_uri).json(), + algorithms=[security_settings.algorithm], + audience=security_settings.client_id, + options={ + "verify_at_hash": False, + }, + ) + return payload + except jwt.JWTError: + return None class JWTBearer(HTTPBearer): @@ -43,14 +43,17 @@ async def __call__(self, request: Request): status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication scheme.", ) - self.verify_jwt(credentials.credentials) - - return credentials.credentials + payload = self.verify_jwt(credentials.credentials) + return {"token": credentials.credentials, "payload": payload} def verify_jwt(self, jwtoken: str) -> dict | None: try: - return decode_id_token(jwtoken) - except Exception: + payload = decode_id_token(jwtoken) + if payload and "email" in payload: + return {"email": payload["email"]} + else: + return None + except jwt.JWTError: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token or expired token.", @@ -61,12 +64,16 @@ def verify_jwt(self, jwtoken: str) -> dict | None: async def get_active_user( - token: str = Security(jwt_scheme), - session: Session = Depends(get_session), -): + token_info: dict = Security(jwt_scheme), + store: Store = Depends(get_store), +) -> CurrentUser: """Dependency for secure endpoints""" - user = User.parse_obj(decode_id_token(token)) - db_user: User = get_user_by_email(session=session, email=user.email) + user_email = token_info["payload"]["email"] + db_user: User = store.get_user_by_email(email=user_email) if not db_user: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User not in DB") - return user + return CurrentUser( + id=db_user.id, + email=db_user.email, + name=db_user.name, + ) diff --git a/genotype_api/services/__init__.py b/genotype_api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/genotype_api/services/endpoint_services/__init__.py b/genotype_api/services/endpoint_services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/genotype_api/services/analysis_service/analysis_service.py b/genotype_api/services/endpoint_services/analysis_service.py similarity index 63% rename from genotype_api/services/analysis_service/analysis_service.py rename to genotype_api/services/endpoint_services/analysis_service.py index 4e33bf8..4c8d2d5 100644 --- a/genotype_api/services/analysis_service/analysis_service.py +++ b/genotype_api/services/endpoint_services/analysis_service.py @@ -3,32 +3,22 @@ from pathlib import Path from fastapi import UploadFile -from sqlmodel import Session from genotype_api.constants import Types, FileExtension -from genotype_api.database.crud.create import create_analyses_samples, create_analysis -from genotype_api.database.crud.delete import delete_analysis -from genotype_api.database.crud.read import ( - get_analysis_by_id, - get_analyses_with_skip_and_limit, - check_analyses_objects, -) -from genotype_api.database.crud.update import refresh_sample_status from genotype_api.database.models import Analysis + from genotype_api.dto.analysis import AnalysisResponse from genotype_api.exceptions import AnalysisNotFoundError from genotype_api.file_parsing.files import check_file from genotype_api.file_parsing.vcf import SequenceAnalysis +from genotype_api.services.endpoint_services.base_service import BaseService -class AnalysisService: +class AnalysisService(BaseService): """This service acts as a translational layer between the CRUD and the API.""" - def __init__(self, session: Session): - self.session: Session = session - @staticmethod def _create_analysis_response(analysis: Analysis) -> AnalysisResponse: return AnalysisResponse( @@ -43,14 +33,14 @@ def _create_analysis_response(analysis: Analysis) -> AnalysisResponse: ) def get_analysis(self, analysis_id: int) -> AnalysisResponse: - analysis: Analysis = get_analysis_by_id(session=self.session, analysis_id=analysis_id) + analysis: Analysis = self.store.get_analysis_by_id(analysis_id=analysis_id) if not analysis: raise AnalysisNotFoundError return self._create_analysis_response(analysis) def get_analyses(self, skip: int, limit: int) -> list[AnalysisResponse]: - analyses: list[Analysis] = get_analyses_with_skip_and_limit( - session=self.session, skip=skip, limit=limit + analyses: list[Analysis] = self.store.get_analyses_with_skip_and_limit( + skip=skip, limit=limit ) if not analyses: raise AnalysisNotFoundError @@ -64,18 +54,16 @@ def get_upload_sequence_analyses(self, file: UploadFile) -> list[AnalysisRespons content = file.file.read().decode("utf-8") sequence_analysis = SequenceAnalysis(vcf_file=content, source=str(file_name)) analyses: list[Analysis] = list(sequence_analysis.generate_analyses()) - check_analyses_objects( - session=self.session, analyses=analyses, analysis_type=Types.SEQUENCE - ) - create_analyses_samples(session=self.session, analyses=analyses) + self.store.check_analyses_objects(analyses=analyses, analysis_type=Types.SEQUENCE) + self.store.create_analyses_samples(analyses=analyses) for analysis in analyses: - analysis: Analysis = create_analysis(session=self.session, analysis=analysis) - refresh_sample_status(session=self.session, sample=analysis.sample) + analysis: Analysis = self.store.create_analysis(analysis=analysis) + self.store.refresh_sample_status(sample=analysis.sample) return [self._create_analysis_response(analysis) for analysis in analyses] def delete_analysis(self, analysis_id: int) -> None: - analysis: Analysis = get_analysis_by_id(session=self.session, analysis_id=analysis_id) + analysis: Analysis = self.store.get_analysis_by_id(analysis_id=analysis_id) if not analysis: raise AnalysisNotFoundError - delete_analysis(session=self.session, analysis=analysis) + self.store.delete_analysis(analysis=analysis) diff --git a/genotype_api/services/endpoint_services/base_service.py b/genotype_api/services/endpoint_services/base_service.py new file mode 100644 index 0000000..1a563bf --- /dev/null +++ b/genotype_api/services/endpoint_services/base_service.py @@ -0,0 +1,8 @@ +"""Module for the endpoint service.""" + +from genotype_api.database.store import Store + + +class BaseService: + def __init__(self, store: Store): + self.store: Store = store diff --git a/genotype_api/services/plate_service/plate_service.py b/genotype_api/services/endpoint_services/plate_service.py similarity index 69% rename from genotype_api/services/plate_service/plate_service.py rename to genotype_api/services/endpoint_services/plate_service.py index 53c04dc..c5691a1 100644 --- a/genotype_api/services/plate_service/plate_service.py +++ b/genotype_api/services/endpoint_services/plate_service.py @@ -8,39 +8,18 @@ from fastapi import UploadFile from pydantic import EmailStr -from sqlmodel import Session from starlette import status - - from genotype_api.constants import Types -from genotype_api.database.crud.create import create_analyses_samples, create_plate -from genotype_api.database.crud.delete import delete_analysis, delete_plate -from genotype_api.database.crud.read import ( - check_analyses_objects, - get_plate_by_id, - get_ordered_plates, - get_analyses_from_plate, - get_user_by_id, - get_user_by_email, -) -from genotype_api.database.crud.update import ( - refresh_sample_status, - refresh_plate, - update_plate_sign_off, -) from genotype_api.database.filter_models.plate_models import PlateSignOff, PlateOrderParams from genotype_api.database.models import Plate, Analysis, User -from genotype_api.dto.dto import PlateCreate from genotype_api.dto.plate import PlateResponse, UserOnPlate, AnalysisOnPlate, SampleStatus from genotype_api.exceptions import PlateNotFoundError, UserNotFoundError from genotype_api.file_parsing.excel import GenotypeAnalysis from genotype_api.file_parsing.files import check_file +from genotype_api.services.endpoint_services.base_service import BaseService -class PlateService: - - def __init__(self, session: Session): - self.session: Session = session +class PlateService(BaseService): @staticmethod def _get_analyses_on_plate(plate: Plate) -> list[AnalysisOnPlate] | None: @@ -65,7 +44,7 @@ def _get_analyses_on_plate(plate: Plate) -> list[AnalysisOnPlate] | None: def _get_plate_user(self, plate: Plate) -> UserOnPlate | None: if plate.signed_by: - user: User = get_user_by_id(session=self.session, user_id=plate.signed_by) + user: User = self.store.get_user_by_id(user_id=plate.signed_by) return UserOnPlate(email=user.email, name=user.name, id=user.id) return None @@ -92,7 +71,7 @@ def _get_plate_id_from_file(file_name: Path) -> str: def upload_plate(self, file: UploadFile) -> PlateResponse: file_name: Path = check_file(file_path=file.filename, extension=".xlsx") plate_id: str = self._get_plate_id_from_file(file_name) - db_plate = self.session.get(Plate, plate_id) + db_plate = self.store.get_plate_by_plate_id(plate_id) if db_plate: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -105,26 +84,24 @@ def upload_plate(self, file: UploadFile) -> PlateResponse: include_key="-CG-", ) analyses: list[Analysis] = list(excel_parser.generate_analyses()) - check_analyses_objects( - session=self.session, analyses=analyses, analysis_type=Types.GENOTYPE - ) - create_analyses_samples(session=self.session, analyses=analyses) - plate_obj = PlateCreate(plate_id=plate_id) + self.store.check_analyses_objects(analyses=analyses, analysis_type=Types.GENOTYPE) + self.store.create_analyses_samples(analyses=analyses) + plate_obj = Plate(plate_id=plate_id) plate_obj.analyses = analyses - plate: Plate = create_plate(session=self.session, plate=plate_obj) + plate: Plate = self.store.create_plate(plate=plate_obj) for analysis in plate.analyses: - refresh_sample_status(sample=analysis.sample, session=self.session) - refresh_plate(session=self.session, plate=plate) + self.store.refresh_sample_status(sample=analysis.sample) + self.store.refresh_plate(plate=plate) return self._create_plate_response(plate) def update_plate_sign_off( self, plate_id: int, user_email: EmailStr, method_document: str, method_version: str ) -> PlateResponse: - plate: Plate = get_plate_by_id(session=self.session, plate_id=plate_id) + plate: Plate = self.store.get_plate_by_id(plate_id=plate_id) if not plate: raise PlateNotFoundError - user: User = get_user_by_email(session=self.session, email=user_email) + user: User = self.store.get_user_by_email(email=user_email) if not user: raise UserNotFoundError plate_sign_off = PlateSignOff( @@ -133,29 +110,29 @@ def update_plate_sign_off( method_document=method_document, method_version=method_version, ) - update_plate_sign_off(session=self.session, plate=plate, plate_sign_off=plate_sign_off) + self.store.update_plate_sign_off(plate=plate, plate_sign_off=plate_sign_off) return self._create_plate_response(plate) def get_plate(self, plate_id: int) -> PlateResponse: - plate: Plate = get_plate_by_id(session=self.session, plate_id=plate_id) + plate: Plate = self.store.get_plate_by_id(plate_id=plate_id) if not plate: raise PlateNotFoundError return self._create_plate_response(plate) def get_plates(self, order_params: PlateOrderParams) -> list[PlateResponse]: - plates: list[Plate] = get_ordered_plates(session=self.session, order_params=order_params) + plates: list[Plate] = self.store.get_ordered_plates(order_params=order_params) if not plates: raise PlateNotFoundError return [self._create_plate_response(plate) for plate in plates] def delete_plate(self, plate_id) -> list[int]: """Delete a plate with the given plate id and return associated analysis ids.""" - plate = get_plate_by_id(session=self.session, plate_id=plate_id) + plate = self.store.get_plate_by_id(plate_id=plate_id) if not plate: raise PlateNotFoundError - analyses: list[Analysis] = get_analyses_from_plate(session=self.session, plate_id=plate_id) + analyses: list[Analysis] = self.store.get_analyses_from_plate(plate_id=plate_id) analysis_ids: list[int] = [analyse.id for analyse in analyses] for analysis in analyses: - delete_analysis(session=self.session, analysis=analysis) - delete_plate(session=self.session, plate=plate) + self.store.delete_analysis(analysis=analysis) + self.store.delete_plate(plate=plate) return analysis_ids diff --git a/genotype_api/services/sample_service/sample_service.py b/genotype_api/services/endpoint_services/sample_service.py similarity index 67% rename from genotype_api/services/sample_service/sample_service.py rename to genotype_api/services/endpoint_services/sample_service.py index 02adf42..dfbf26e 100644 --- a/genotype_api/services/sample_service/sample_service.py +++ b/genotype_api/services/endpoint_services/sample_service.py @@ -2,37 +2,19 @@ from datetime import date from typing import Literal - -from sqlmodel import Session - from genotype_api.constants import Types, Sexes -from genotype_api.database.crud.create import create_sample -from genotype_api.database.crud.delete import delete_analysis, delete_sample -from genotype_api.database.crud.read import ( - get_sample, - get_filtered_samples, - get_analysis_by_type_and_sample_id, - get_analyses_by_type_between_dates, -) -from genotype_api.database.crud.update import ( - refresh_sample_status, - update_sample_status, - update_sample_comment, - update_sample_sex, -) from genotype_api.database.filter_models.sample_models import SampleFilterParams, SampleSexesUpdate from genotype_api.database.models import Sample, Analysis + from genotype_api.dto.genotype import GenotypeResponse -from genotype_api.dto.sample import AnalysisOnSample, SampleResponse +from genotype_api.dto.sample import AnalysisOnSample, SampleResponse, SampleCreate from genotype_api.exceptions import SampleNotFoundError from genotype_api.models import SampleDetail, MatchResult +from genotype_api.services.endpoint_services.base_service import BaseService from genotype_api.services.match_genotype_service.match_genotype import MatchGenotypeService -class SampleService: - - def __init__(self, session: Session): - self.session = session +class SampleService(BaseService): @staticmethod def _get_genotype_on_analysis(analysis: Analysis) -> list[GenotypeResponse] | None: @@ -78,30 +60,35 @@ def _get_sample_response(self, sample: Sample) -> SampleResponse: ) def get_sample(self, sample_id: str) -> SampleResponse: - sample: Sample = get_sample(session=self.session, sample_id=sample_id) + sample: Sample = self.store.get_sample(sample_id=sample_id) if not sample: raise SampleNotFoundError if len(sample.analyses) == 2 and not sample.status: - sample: Sample = refresh_sample_status(session=self.session, sample=sample) + sample: Sample = self.store.refresh_sample_status(sample=sample) return self._get_sample_response(sample) def get_samples(self, filter_params: SampleFilterParams) -> list[SampleResponse]: - samples: list[Sample] = get_filtered_samples( - session=self.session, filter_params=filter_params - ) + samples: list[Sample] = self.store.get_filtered_samples(filter_params=filter_params) return [self._get_sample_response(sample) for sample in samples] - def create_sample(self, sample: Sample) -> None: - create_sample(session=self.session, sample=sample) + def create_sample(self, sample_create: SampleCreate) -> None: + sample = Sample( + id=sample_create.id, + status=sample_create.status, + comment=sample_create.comment, + sex=sample_create.sex, + created_at=sample_create.created_at, + ) + self.store.create_sample(sample=sample) def delete_sample(self, sample_id: str) -> None: - sample: Sample = get_sample(session=self.session, sample_id=sample_id) + sample: Sample = self.store.get_sample(sample_id=sample_id) for analysis in sample.analyses: - delete_analysis(session=self.session, analysis=analysis) - delete_sample(session=self.session, sample=sample) + self.store.delete_analysis(analysis=analysis) + self.store.delete_sample(sample=sample) def get_status_detail(self, sample_id: str) -> SampleDetail: - sample: Sample = get_sample(session=self.session, sample_id=sample_id) + sample: Sample = self.store.get_sample(sample_id=sample_id) if len(sample.analyses) != 2: return SampleDetail() return MatchGenotypeService.check_sample(sample=sample) @@ -115,11 +102,11 @@ def get_match_results( date_max: date, ) -> list[MatchResult]: """Get the match results for an analysis type and the comparison type in a given time frame.""" - analyses: list[Analysis] = get_analyses_by_type_between_dates( - session=self.session, analysis_type=comparison_set, date_max=date_max, date_min=date_min + analyses: list[Analysis] = self.store.get_analyses_by_type_between_dates( + analysis_type=comparison_set, date_max=date_max, date_min=date_min ) - sample_analysis: Analysis = get_analysis_by_type_and_sample_id( - session=self.session, analysis_type=analysis_type, sample_id=sample_id + sample_analysis: Analysis = self.store.get_analysis_by_type_and_sample_id( + analysis_type=analysis_type, sample_id=sample_id ) matches: list[MatchResult] = MatchGenotypeService.get_matches( analyses=analyses, sample_analysis=sample_analysis @@ -129,19 +116,15 @@ def get_match_results( def set_sample_status( self, sample_id: str, status: Literal["pass", "fail", "cancel"] | None ) -> SampleResponse: - sample: Sample = update_sample_status( - session=self.session, sample_id=sample_id, status=status - ) + sample: Sample = self.store.update_sample_status(sample_id=sample_id, status=status) return self._get_sample_response(sample) def set_sample_comment(self, sample_id: str, comment: str) -> SampleResponse: - sample: Sample = update_sample_comment( - session=self.session, sample_id=sample_id, comment=comment - ) + sample: Sample = self.store.update_sample_comment(sample_id=sample_id, comment=comment) return self._get_sample_response(sample) def set_sex(self, sample_id: str, sex: Sexes, genotype_sex: Sexes, sequence_sex: Sexes) -> None: sexes_update = SampleSexesUpdate( sample_id=sample_id, sex=sex, genotype_sex=genotype_sex, sequence_sex=sequence_sex ) - update_sample_sex(session=self.session, sexes_update=sexes_update) + self.store.update_sample_sex(sexes_update=sexes_update) diff --git a/genotype_api/services/snp_service/snp_service.py b/genotype_api/services/endpoint_services/snp_service.py similarity index 64% rename from genotype_api/services/snp_service/snp_service.py rename to genotype_api/services/endpoint_services/snp_service.py index 5573796..e19328a 100644 --- a/genotype_api/services/snp_service/snp_service.py +++ b/genotype_api/services/endpoint_services/snp_service.py @@ -2,37 +2,33 @@ from fastapi import UploadFile -from genotype_api.database.crud.create import create_snps -from genotype_api.database.crud.delete import delete_snps -from genotype_api.database.crud.read import get_snps_by_limit_and_skip, get_snps from genotype_api.database.models import SNP + from genotype_api.dto.snp import SNPResponse from genotype_api.exceptions import SNPExistsError +from genotype_api.services.endpoint_services.base_service import BaseService from genotype_api.services.snp_reader_service.snp_reader import SNPReaderService -class SNPService: - - def __init__(self, session): - self.session = session +class SNPService(BaseService): @staticmethod def _get_snp_response(snp: SNP) -> SNPResponse: return SNPResponse(ref=snp.ref, chrom=snp.chrom, pos=snp.pos, id=snp.id) def get_snps(self, skip: int, limit: int) -> list[SNPResponse]: - snps: list[SNP] = get_snps_by_limit_and_skip(session=self.session, skip=skip, limit=limit) + snps: list[SNP] = self.store.get_snps_by_limit_and_skip(skip=skip, limit=limit) return [self._get_snp_response(snp) for snp in snps] def upload_snps(self, snps_file: UploadFile) -> list[SNPResponse]: """Upload snps to the database, raises an error when SNPs already exist.""" - existing_snps: list[SNP] = get_snps(self.session) + existing_snps: list[SNP] = self.store.get_snps(self.session) if existing_snps: raise SNPExistsError snps: list[SNP] = SNPReaderService.read_snps_from_file(snps_file) - new_snps: list[SNP] = create_snps(session=self.session, snps=snps) + new_snps: list[SNP] = self.store.create_snps(snps=snps) return [self._get_snp_response(new_snp) for new_snp in new_snps] def delete_all_snps(self) -> int: - result = delete_snps(self.session) + result = self.store.delete_snps(self.session) return result.rowcount diff --git a/genotype_api/services/user_service/user_service.py b/genotype_api/services/endpoint_services/user_service.py similarity index 63% rename from genotype_api/services/user_service/user_service.py rename to genotype_api/services/endpoint_services/user_service.py index 7c46afe..0ace577 100644 --- a/genotype_api/services/user_service/user_service.py +++ b/genotype_api/services/endpoint_services/user_service.py @@ -1,24 +1,14 @@ """Module to holds the user service.""" from pydantic import EmailStr -from sqlmodel import Session -from genotype_api.database.crud.create import create_user -from genotype_api.database.crud.delete import delete_user -from genotype_api.database.crud.read import ( - get_user_by_id, - get_users_with_skip_and_limit, - get_user_by_email, -) -from genotype_api.database.crud.update import update_user_email from genotype_api.database.models import User +from genotype_api.database.store import Store from genotype_api.dto.user import UserResponse, UserRequest, PlateOnUser from genotype_api.exceptions import UserNotFoundError, UserArchiveError, UserExistsError +from genotype_api.services.endpoint_services.base_service import BaseService -class UserService: - - def __init__(self, session: Session): - self.session: Session = session +class UserService(BaseService): @staticmethod def _get_plates_on_user(user: User) -> list[PlateOnUser] | None: @@ -41,35 +31,33 @@ def _create_user_response(self, user: User) -> UserResponse: return UserResponse(email=user.email, name=user.name, id=user.id, plates=plates) def create_user(self, user: UserRequest): - existing_user: User = get_user_by_email(session=self.session, email=user.email) + existing_user: User = self.store.get_user_by_email(email=user.email) if existing_user: raise UserExistsError - new_user: User = create_user(session=self.session, user=user) + new_user: User = self.store.create_user(user=user) return self._create_user_response(new_user) def get_users(self, skip: int, limit: int) -> list[UserResponse]: - users: list[User] = get_users_with_skip_and_limit( - session=self.session, skip=skip, limit=limit - ) + users: list[User] = self.store.get_users_with_skip_and_limit(skip=skip, limit=limit) return [self._create_user_response(user) for user in users] def get_user(self, user_id: int) -> UserResponse: - user: User = get_user_by_id(session=self.session, user_id=user_id) + user: User = self.store.get_user_by_id(user_id=user_id) if not user: raise UserNotFoundError return self._create_user_response(user) def delete_user(self, user_id: int): - user: User = get_user_by_id(session=self.session, user_id=user_id) + user: User = self.store.get_user_by_id(user_id=user_id) if not user: raise UserNotFoundError if user.plates: raise UserArchiveError - delete_user(session=self.session, user=user) + self.store.delete_user(user=user) def update_user_email(self, user_id: int, email: EmailStr): - user: User = get_user_by_id(session=self.session, user_id=user_id) + user: User = self.store.get_user_by_id(user_id=user_id) if not user: raise UserNotFoundError - user: User = update_user_email(session=self.session, user=user, email=email) + user: User = self.store.update_user_email(user=user, email=email) return self._create_user_response(user) diff --git a/requirements.txt b/requirements.txt index c32a2ea..e799d1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,24 +1,26 @@ -SQLAlchemy==1.4.30 +SQLAlchemy aiofiles bcrypt bump2version click +starlette coloredlogs email-validator -fastapi==0.75.0 +fastapi>=0.109.1 google-auth gunicorn httptools numpy openpyxl passlib -pydantic==1.10.14 +pydantic pymysql python-dotenv python-jose[cryptography] python-multipart pyyaml requests -sqlmodel uvicorn uvloop +sqlalchemy_utils +pydantic_settings \ No newline at end of file