diff --git a/genotype_api/database/crud/read.py b/genotype_api/database/crud/read.py index f2fc595..fd6d32a 100644 --- a/genotype_api/database/crud/read.py +++ b/genotype_api/database/crud/read.py @@ -12,18 +12,31 @@ from genotype_api.database.filters.analysis_filter import ( AnalysisFilter, apply_analysis_filter, + filter_analyses_by_id, + filter_analyses_by_plate_id, ) from genotype_api.database.filters.genotype_filters import ( GenotypeFilter, apply_genotype_filter, ) -from genotype_api.database.filters.plate_filters import PlateFilter, apply_plate_filter +from genotype_api.database.filters.plate_filters import ( + PlateFilter, + apply_plate_filter, + filter_plates_by_id, + filter_plates_by_plate_id, +) from genotype_api.database.filters.sample_filters import ( SampleFilter, apply_sample_filter, + filter_samples_by_id, ) from genotype_api.database.filters.snp_filters import SNPFilter, apply_snp_filter -from genotype_api.database.filters.user_filters import UserFilter, apply_user_filter +from genotype_api.database.filters.user_filters import ( + UserFilter, + apply_user_filter, + filter_users_by_email, + filter_users_by_id, +) from genotype_api.database.models import SNP, Analysis, Genotype, Plate, Sample, User LOG = logging.getLogger(__name__) @@ -32,40 +45,31 @@ class ReadHandler(BaseHandler): async def get_analyses_by_plate_id(self, plate_id: int) -> list[Analysis]: - analyses: Query = self._get_query(Analysis).options( - selectinload(Analysis.genotypes), selectinload(Analysis.sample) - ) - filter_functions = [AnalysisFilter.BY_PLATE_ID] - filtered_query = apply_analysis_filter( - analyses=analyses, filter_functions=filter_functions, plate_id=plate_id - ) - return (await self.session.execute(filtered_query)).scalars().all() + analyses: Query = self._get_analysis_with_genotypes() + filtered_query = filter_analyses_by_plate_id(plate_id=plate_id, analyses=analyses) + return await self.fetch_all_rows(filtered_query) async def get_analysis_by_id(self, analysis_id: int) -> Analysis: - analyses: Query = self._get_query(Analysis) - filter_functions = [AnalysisFilter.BY_ID] - filtered_query = apply_analysis_filter( - analyses=analyses, filter_functions=filter_functions, analysis_id=analysis_id - ) - return (await self.session.execute(filtered_query)).scalars().first() + analyses: Query = select(Analysis) + filtered_query = filter_analyses_by_id(analysis_id=analysis_id, analyses=analyses) + return await self.fetch_first_row(filtered_query) async def get_analyses(self) -> list[Analysis]: - filtered_query = self._get_query(Analysis) - return (await self.session.execute(filtered_query)).scalars().all() + filtered_query = select(Analysis) + return await self.fetch_all_rows(filtered_query) async def get_analyses_with_skip_and_limit(self, skip: int, limit: int) -> list[Analysis]: - analyses: Query = self._get_query(Analysis) + analyses: Query = select(Analysis) filter_functions = [AnalysisFilter.SKIP_AND_LIMIT] filtered_query = apply_analysis_filter( analyses=analyses, filter_functions=filter_functions, skip=skip, limit=limit ) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) async def get_analyses_by_type_between_dates( self, analysis_type: Types, date_min: date, date_max: date ) -> list[Analysis]: - analyses: Query = self._get_query(Analysis) - + analyses: Query = select(Analysis) filter_functions = [AnalysisFilter.BY_TYPE, AnalysisFilter.BETWEEN_DATES] filtered_query = apply_analysis_filter( analyses=analyses, @@ -76,12 +80,12 @@ async def get_analyses_by_type_between_dates( ) filtered_query = filtered_query.options(selectinload(Analysis.genotypes)) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) async def get_analysis_by_type_and_sample_id( self, sample_id: str, analysis_type: Types ) -> Analysis: - analyses: Query = self._get_query(Analysis) + analyses: Query = select(Analysis) filter_functions = [AnalysisFilter.BY_TYPE, AnalysisFilter.BY_SAMPLE_ID] filtered_query = apply_analysis_filter( analyses=analyses, @@ -89,33 +93,22 @@ async def get_analysis_by_type_and_sample_id( sample_id=sample_id, type=analysis_type, ) - filtered_query = filtered_query.options(selectinload(Analysis.genotypes)) - return (await self.session.execute(filtered_query)).scalars().first() + return await self.fetch_first_row(filtered_query) async def get_plate_by_id(self, plate_id: int) -> Plate: - plates: Query = self._get_query(Plate).options( - selectinload(Plate.analyses).selectinload(Analysis.sample) - ) - filter_functions = [PlateFilter.BY_ID] - filtered_query = apply_plate_filter( - plates=plates, filter_functions=filter_functions, entry_id=plate_id - ) - return (await self.session.execute(filtered_query)).scalars().first() + plates: Query = self._get_plate_with_analyses_and_samples() + filtered_query = filter_plates_by_id(entry_id=plate_id, plates=plates) + return await self.fetch_first_row(filtered_query) async def get_plate_by_plate_id(self, plate_id: str) -> Plate: - plates: Query = self._get_query(Plate).options(selectinload(Plate.analyses)) - filter_functions = [PlateFilter.BY_PLATE_ID] - filtered_query = apply_plate_filter( - plates=plates, filter_functions=filter_functions, plate_id=plate_id - ) - return (await self.session.execute(filtered_query)).scalars().first() + plates: Query = self._get_plate_with_analyses + filtered_query = filter_plates_by_plate_id(plate_id=plate_id, plates=plates) + return await self.fetch_first_row(filtered_query) async def get_ordered_plates(self, order_params: PlateOrderParams) -> list[Plate]: sort_func = desc if order_params.sort_order == "descend" else asc - plates: Query = self._get_query(Plate).options( - selectinload(Plate.analyses).selectinload(Analysis.sample) - ) + plates: Query = self._get_plate_with_analyses_and_samples() filter_functions = [PlateFilter.ORDER, PlateFilter.SKIP_AND_LIMIT] filtered_query = apply_plate_filter( plates=plates, @@ -125,23 +118,18 @@ async def get_ordered_plates(self, order_params: PlateOrderParams) -> list[Plate limit=order_params.limit, sort_func=sort_func, ) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) async def get_genotype_by_id(self, entry_id: int) -> Genotype: - genotypes: Query = self._get_query(Genotype).options(selectinload(Genotype.analysis)) + genotypes: Query = self._get_genotype_with_analysis() filter_functions = [GenotypeFilter.BY_ID] filtered_query = apply_genotype_filter( genotypes=genotypes, filter_functions=filter_functions, entry_id=entry_id ) - return (await self.session.execute(filtered_query)).scalars().first() + return await self.fetch_first_row(filtered_query) async def get_filtered_samples(self, filter_params: SampleFilterParams) -> list[Sample]: - query = ( - select(Sample) - .distinct() - .options(selectinload(Sample.analyses).selectinload(Analysis.genotypes)) - .join(Analysis, Analysis.sample_id == Sample.id) - ) + query = self._get_samples_with_analyses_and_genotypes() if filter_params.sample_id: query = self._get_samples(query, filter_params.sample_id) if filter_params.plate_id: @@ -157,8 +145,9 @@ async def get_filtered_samples(self, filter_params: SampleFilterParams) -> list[ .offset(filter_params.skip) .limit(filter_params.limit) ) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) + # pylint: disable=E1102 @staticmethod def _get_incomplete_samples(query: Query) -> Query: """Returning sample query statement for samples with less than two analyses.""" @@ -176,12 +165,12 @@ def _get_plate_samples(query: Query, plate_id: str) -> Query: @staticmethod def _get_commented_samples(query: Query) -> Query: """Returning sample query statement for samples with no comment.""" - return query.filter(Sample.comment != None) + return query.filter(Sample.comment.is_not(None)) @staticmethod def _get_status_missing_samples(query: Query) -> Query: """Returning sample query statement for samples with no comment.""" - return query.filter(Sample.status == None) + return query.filter(Sample.status.is_(None)) @staticmethod def _get_samples(query: Query, sample_id: str) -> Query: @@ -189,39 +178,27 @@ def _get_samples(query: Query, sample_id: str) -> Query: return query.filter(Sample.id.contains(sample_id)) async def get_sample_by_id(self, sample_id: str) -> Sample: - samples: Query = self._get_query(Sample) - filter_functions = [SampleFilter.BY_ID] - filtered_query = apply_sample_filter( - samples=samples, filter_functions=filter_functions, sample_id=sample_id - ) - filtered_query = filtered_query.options( - selectinload(Sample.analyses).selectinload(Analysis.genotypes) - ) - return (await self.session.execute(filtered_query)).scalars().first() + samples: Query = self._get_samples_with_analyses_and_genotypes() + filtered_query = filter_samples_by_id(sample_id=sample_id, samples=samples) + return await self.fetch_first_row(filtered_query) async def get_user_by_id(self, user_id: int) -> User: - users: Query = self._get_query(User).options(selectinload(User.plates)) - filter_functions = [UserFilter.BY_ID] - filtered_query = apply_user_filter( - users=users, filter_functions=filter_functions, user_id=user_id - ) - return (await self.session.execute(filtered_query)).scalars().first() + users: Query = self._get_user_with_plates() + filtered_query = filter_users_by_id(user_id=user_id, users=users) + return await self.fetch_first_row(filtered_query) async def get_user_by_email(self, email: str) -> User | None: - users: Query = self._get_query(User) - filter_functions = [UserFilter.BY_EMAIL] - filtered_query = apply_user_filter( - users=users, filter_functions=filter_functions, email=email - ) - return (await self.session.execute(filtered_query)).scalars().first() + users: Query = select(User) + filtered_query = filter_users_by_email(email=email, users=users) + return await self.fetch_first_row(filtered_query) async def get_users_with_skip_and_limit(self, skip: int, limit: int) -> list[User]: - users: Query = self._get_query(User).options(selectinload(User.plates)) + users: Query = self._get_user_with_plates() filter_functions = [UserFilter.SKIP_AND_LIMIT] filtered_query = apply_user_filter( users=users, filter_functions=filter_functions, skip=skip, limit=limit ) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) async def check_analyses_objects(self, analyses: list[Analysis], analysis_type: Types) -> None: """Raising 400 if any analysis in the list already exist in the database""" @@ -235,13 +212,49 @@ async def check_analyses_objects(self, analyses: list[Analysis], analysis_type: await self.session.commit() async def get_snps(self) -> list[SNP]: - filtered_query = self._get_query(SNP) - return (await self.session.execute(filtered_query)).scalars().all() + filtered_query = select(SNP) + return await self.fetch_all_rows(filtered_query) async def get_snps_by_limit_and_skip(self, skip: int, limit: int) -> list[SNP]: - snps: Query = self._get_query(SNP) + snps: Query = select(SNP) filter_functions = [SNPFilter.SKIP_AND_LIMIT] filtered_query = apply_snp_filter( snps=snps, filter_functions=filter_functions, skip=skip, limit=limit ) - return (await self.session.execute(filtered_query)).scalars().all() + return await self.fetch_all_rows(filtered_query) + + @staticmethod + def _get_analysis_with_genotypes_and_sample() -> Query: + return select(Analysis).options( + selectinload(Analysis.genotypes), + selectinload(Analysis.sample), + ) + + @staticmethod + def _get_analysis_with_genotypes() -> Query: + return select(Analysis).options(selectinload(Analysis.genotypes)) + + @staticmethod + def _get_plate_with_analyses() -> Query: + return select(Plate).options(selectinload(Plate.analyses)) + + @staticmethod + def _get_plate_with_analyses_and_samples() -> Query: + return select(Plate).options(selectinload(Plate.analyses).selectinload(Analysis.sample)) + + @staticmethod + def _get_genotype_with_analysis() -> Query: + return select(Genotype).options(selectinload(Genotype.analysis)) + + @staticmethod + def _get_samples_with_analyses_and_genotypes() -> Query: + return ( + select(Sample) + .distinct() + .options(selectinload(Sample.analyses).selectinload(Analysis.genotypes)) + .join(Analysis, Analysis.sample_id == Sample.id) + ) + + @staticmethod + def _get_user_with_plates() -> Query: + return select(User).options(selectinload(User.plates))