diff --git a/genotype_api/database/base_handler.py b/genotype_api/database/base_handler.py index 1d58c4b..6fd06e5 100644 --- a/genotype_api/database/base_handler.py +++ b/genotype_api/database/base_handler.py @@ -1,8 +1,6 @@ from dataclasses import dataclass from typing import Type - from sqlalchemy.orm import Session, Query, DeclarativeBase - from genotype_api.database.models import Analysis, Sample diff --git a/genotype_api/database/crud/create.py b/genotype_api/database/crud/create.py index 2579d59..9cbd56a 100644 --- a/genotype_api/database/crud/create.py +++ b/genotype_api/database/crud/create.py @@ -2,7 +2,7 @@ from genotype_api.database.base_handler import BaseHandler -from genotype_api.database.models import Analysis, Plate, Sample, User, SNP +from genotype_api.database.models import Analysis, Plate, Sample, User, SNP, Genotype from genotype_api.dto.user import UserRequest from genotype_api.exceptions import SampleExistsError @@ -42,14 +42,18 @@ def create_analyses_samples(self, analyses: list[Analysis]) -> list[Sample]: if not self.session.query(Sample).filter(Sample.id == analysis.sample_id).one_or_none() ] - def create_user(self, user: UserRequest) -> User: - db_user = User(email=user.email, name=user.name) - self.session.add(db_user) + def create_user(self, user: User) -> User: + self.session.add(user) self.session.commit() - self.session.refresh(db_user) - return db_user + self.session.refresh(user) + return user def create_snps(self, snps: list[SNP]) -> list[SNP]: self.session.add_all(snps) self.session.commit() return snps + + def create_genotype(self, genotype: Genotype) -> Genotype: + self.session.add(genotype) + self.session.commit() + return genotype diff --git a/genotype_api/database/crud/delete.py b/genotype_api/database/crud/delete.py index 0fb6218..8cc642f 100644 --- a/genotype_api/database/crud/delete.py +++ b/genotype_api/database/crud/delete.py @@ -26,7 +26,10 @@ def delete_user(self, user: User) -> None: self.session.delete(user) self.session.commit() - def delete_snps(self) -> any: - result = self.session.execute(delete(SNP)) + def delete_snps(self) -> int: + snps: list[SNP] = self._get_query(SNP).all() + count: int = len(snps) + for snp in snps: + self.session.delete(snp) self.session.commit() - return result + return count diff --git a/genotype_api/database/crud/read.py b/genotype_api/database/crud/read.py index b92a508..c0e7fb4 100644 --- a/genotype_api/database/crud/read.py +++ b/genotype_api/database/crud/read.py @@ -23,17 +23,6 @@ class ReadHandler(BaseHandler): def get_analyses_from_plate(self, plate_id: int) -> list[Analysis]: return self.session.query(Analysis).filter(Analysis.plate_id == plate_id).all() - def get_analysis_by_type_sample( - self, - sample_id: str, - analysis_type: str, - ) -> Analysis | None: - return ( - self.session.query(Analysis) - .filter(Analysis.sample_id == sample_id, Analysis.type == analysis_type) - .first() - ) - def get_analysis_by_id(self, analysis_id: int) -> Analysis: return self.session.query(Analysis).filter(Analysis.id == analysis_id).first() @@ -130,7 +119,7 @@ def _get_samples(query: Query, sample_id: str) -> Query: """Returns a query for samples containing the given sample_id.""" return query.filter(Sample.id.contains(sample_id)) - def get_sample(self, sample_id: str) -> Sample: + def get_sample_by_id(self, sample_id: str) -> Sample: return self.session.query(Sample).filter(Sample.id == sample_id).first() def get_user_by_id(self, user_id: int) -> User: @@ -145,7 +134,7 @@ def get_users_with_skip_and_limit(self, skip: int, limit: int) -> list[User]: def check_analyses_objects(self, analyses: list[Analysis], analysis_type: Types) -> None: """Raising 400 if any analysis in the list already exist in the database""" for analysis_obj in analyses: - existing_analysis = self.get_analysis_by_type_sample( + existing_analysis = self.get_analysis_by_type_and_sample_id( sample_id=analysis_obj.sample_id, analysis_type=analysis_type, ) diff --git a/genotype_api/database/crud/update.py b/genotype_api/database/crud/update.py index 5a43d3a..df36d30 100644 --- a/genotype_api/database/crud/update.py +++ b/genotype_api/database/crud/update.py @@ -28,7 +28,7 @@ def refresh_sample_status( return sample def update_sample_comment(self, sample_id: str, comment: str) -> Sample: - sample: Sample = self.get_sample(sample_id=sample_id) + sample: Sample = self.get_sample_by_id(sample_id=sample_id) if not sample: raise SampleNotFoundError sample.comment = comment @@ -38,7 +38,7 @@ def update_sample_comment(self, sample_id: str, comment: str) -> Sample: return sample def update_sample_status(self, sample_id: str, status: str | None) -> Sample: - sample: Sample = self.get_sample(sample_id=sample_id) + sample: Sample = self.get_sample_by_id(sample_id=sample_id) if not sample: raise SampleNotFoundError sample.status = status diff --git a/genotype_api/services/endpoint_services/sample_service.py b/genotype_api/services/endpoint_services/sample_service.py index dfbf26e..c81a286 100644 --- a/genotype_api/services/endpoint_services/sample_service.py +++ b/genotype_api/services/endpoint_services/sample_service.py @@ -60,7 +60,7 @@ def _get_sample_response(self, sample: Sample) -> SampleResponse: ) def get_sample(self, sample_id: str) -> SampleResponse: - sample: Sample = self.store.get_sample(sample_id=sample_id) + sample: Sample = self.store.get_sample_by_id(sample_id=sample_id) if not sample: raise SampleNotFoundError if len(sample.analyses) == 2 and not sample.status: @@ -82,13 +82,13 @@ def create_sample(self, sample_create: SampleCreate) -> None: self.store.create_sample(sample=sample) def delete_sample(self, sample_id: str) -> None: - sample: Sample = self.store.get_sample(sample_id=sample_id) + sample: Sample = self.store.get_sample_by_id(sample_id=sample_id) for analysis in sample.analyses: self.store.delete_analysis(analysis=analysis) self.store.delete_sample(sample=sample) def get_status_detail(self, sample_id: str) -> SampleDetail: - sample: Sample = self.store.get_sample(sample_id=sample_id) + sample: Sample = self.store.get_sample_by_id(sample_id=sample_id) if len(sample.analyses) != 2: return SampleDetail() return MatchGenotypeService.check_sample(sample=sample) diff --git a/genotype_api/services/endpoint_services/snp_service.py b/genotype_api/services/endpoint_services/snp_service.py index e19328a..dacf4b9 100644 --- a/genotype_api/services/endpoint_services/snp_service.py +++ b/genotype_api/services/endpoint_services/snp_service.py @@ -22,7 +22,7 @@ def get_snps(self, skip: int, limit: int) -> list[SNPResponse]: def upload_snps(self, snps_file: UploadFile) -> list[SNPResponse]: """Upload snps to the database, raises an error when SNPs already exist.""" - existing_snps: list[SNP] = self.store.get_snps(self.session) + existing_snps: list[SNP] = self.store.get_snps() if existing_snps: raise SNPExistsError snps: list[SNP] = SNPReaderService.read_snps_from_file(snps_file) @@ -30,5 +30,5 @@ def upload_snps(self, snps_file: UploadFile) -> list[SNPResponse]: return [self._get_snp_response(new_snp) for new_snp in new_snps] def delete_all_snps(self) -> int: - result = self.store.delete_snps(self.session) + result = self.store.delete_snps() return result.rowcount diff --git a/genotype_api/services/endpoint_services/user_service.py b/genotype_api/services/endpoint_services/user_service.py index 0ace577..e335ab4 100644 --- a/genotype_api/services/endpoint_services/user_service.py +++ b/genotype_api/services/endpoint_services/user_service.py @@ -34,7 +34,8 @@ def create_user(self, user: UserRequest): existing_user: User = self.store.get_user_by_email(email=user.email) if existing_user: raise UserExistsError - new_user: User = self.store.create_user(user=user) + db_user = User(email=user.email, name=user.name) + new_user: User = self.store.create_user(user=db_user) return self._create_user_response(new_user) def get_users(self, skip: int, limit: int) -> list[UserResponse]: diff --git a/tests/conftest.py b/tests/conftest.py index fc1f524..491d756 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ import pytest from genotype_api.database.database import initialise_database, create_all_tables, drop_all_tables +from genotype_api.database.filter_models.plate_models import PlateSignOff +from genotype_api.database.filter_models.sample_models import SampleSexesUpdate from genotype_api.database.models import User, Plate, SNP, Sample, Genotype, Analysis from genotype_api.database.store import Store from tests.store_helpers import StoreHelpers @@ -17,6 +19,21 @@ def timestamp_now() -> datetime: return datetime.datetime.now() +@pytest.fixture +def date_two_weeks_future() -> datetime.date: + return datetime.date.today() + datetime.timedelta(days=14) + + +@pytest.fixture +def date_yesterday() -> datetime: + return datetime.date.today() - datetime.timedelta(days=1) + + +@pytest.fixture +def date_tomorrow() -> datetime: + return datetime.date.today() + datetime.timedelta(days=1) + + @pytest.fixture def store() -> Generator[Store, None, None]: """Return a CG store.""" @@ -232,6 +249,39 @@ def base_store( return store +@pytest.fixture +def unsigned_plate() -> Plate: + return Plate( + id=1, + plate_id="ID_1", + signed_by=None, + method_document=None, + method_version=None, + created_at=datetime.datetime.now(), + signed_at=None, + ) + + +@pytest.fixture +def plate_sign_off() -> PlateSignOff: + return PlateSignOff( + user_id=1, + signed_at=datetime.datetime.now(), + method_document="mdoc", + method_version="mdoc_ver", + ) + + +@pytest.fixture +def sample_sex_update(test_sample_id) -> SampleSexesUpdate: + return SampleSexesUpdate( + sample_id=test_sample_id, + sex="female", + genotype_sex="female", + sequence_sex="female", + ) + + @pytest.fixture(name="fixtures_dir") def fixture_fixtures_dir() -> Path: """Return the path to fixtures dir.""" diff --git a/tests/database/crud/test_create.py b/tests/database/crud/test_create.py new file mode 100644 index 0000000..a4336fd --- /dev/null +++ b/tests/database/crud/test_create.py @@ -0,0 +1,84 @@ +"""Module to test the create functionality of the genotype API CRUD.""" + +from genotype_api.database.models import Analysis, SNP, User, Genotype, Sample, Plate +from genotype_api.database.store import Store + + +def test_create_analysis(store: Store, test_analysis: Analysis): + # GIVEN an analysis and an empty store + assert not store._get_query(Analysis).all() + # WHEN creating the analysis + store.create_analysis(analysis=test_analysis) + + # THEN the analysis is created + assert store._get_query(Analysis).all()[0].id == test_analysis.id + + +def test_create_genotype(store: Store, test_genotype: Genotype): + # GIVEN a genotype and an empty store + assert not store._get_query(Genotype).all() + + # WHEN creating the genotype + store.create_genotype(genotype=test_genotype) + + # THEN the genotype is created + assert store._get_query(Genotype).all()[0].id == test_genotype.id + + +def test_create_snp(store: Store, test_snp: SNP): + # GIVEN a SNP and an empty store + assert not store._get_query(SNP).all() + + # WHEN creating the SNP + store.create_snps(snps=[test_snp]) + + # THEN the SNP is created + assert store._get_query(SNP).all()[0].id == test_snp.id + + +def test_create_user(store: Store, test_user: User): + # GIVEN a user and an empty store + assert not store._get_query(User).all() + + # WHEN creating the user + store.create_user(user=test_user) + + # THEN the user is created + assert store._get_query(User).all()[0].id == test_user.id + + +def test_create_sample(store: Store, test_sample: Sample): + # GIVEN a sample and an empty store + assert not store._get_query(Sample).all() + + # WHEN creating the sample + store.create_sample(sample=test_sample) + + # THEN the sample is created + assert store._get_query(Sample).all()[0].id == test_sample.id + + +def test_create_plate(store: Store, test_plate: Plate): + # GIVEN a plate and an empty store + assert not store._get_query(Plate).all() + + # WHEN creating the plate + store.create_plate(plate=test_plate) + + # THEN the plate is created + assert store._get_query(Plate).all()[0].id == test_plate.id + + +def test_create_analyses_samples(store: Store, test_analysis: Analysis): + # GIVEN an analysis in a store + assert not store._get_query(Sample).all() + assert not store._get_query(Analysis).all() + store.create_analysis(test_analysis) + + # WHEN creating the analyses + store.create_analyses_samples(analyses=[test_analysis]) + + # THEN the samples are created + sample: Sample = store._get_query(Sample).all()[0] + assert sample + assert sample.id == test_analysis.sample_id diff --git a/tests/database/crud/test_delete.py b/tests/database/crud/test_delete.py new file mode 100644 index 0000000..ebfc32f --- /dev/null +++ b/tests/database/crud/test_delete.py @@ -0,0 +1,59 @@ +"""Module to test the delete functionality of the genotype API CRUD.""" + +from genotype_api.database.models import Analysis, Sample, User, Plate, SNP +from genotype_api.database.store import Store + + +def test_delete_analysis(base_store: Store, test_analysis: Analysis): + # GIVEN an analysis and a store with the analysis + assert test_analysis in base_store._get_query(Analysis).all() + + # WHEN deleting the analysis + base_store.delete_analysis(analysis=test_analysis) + + # THEN the analysis is deleted + assert test_analysis not in base_store._get_query(Analysis).all() + + +def test_delete_sample(base_store: Store, test_sample: Sample): + # GIVEN a sample and a store with the sample + assert test_sample in base_store._get_query(Sample).all() + + # WHEN deleting the sample + base_store.delete_sample(sample=test_sample) + + # THEN the sample is deleted + assert test_sample not in base_store._get_query(Sample).all() + + +def test_delete_plate(base_store: Store, test_plate: Plate): + # GIVEN a plate and a store with the plate + assert test_plate in base_store._get_query(Plate).all() + + # WHEN deleting the plate + base_store.delete_plate(plate=test_plate) + + # THEN the plate is deleted + assert test_plate not in base_store._get_query(Plate).all() + + +def test_delete_user(base_store: Store, test_user: User): + # GIVEN a user and a store with the user + assert test_user in base_store._get_query(User).all() + + # WHEN deleting the user + base_store.delete_user(user=test_user) + + # THEN the user is deleted + assert test_user not in base_store._get_query(User).all() + + +def test_delete_snps(base_store: Store, test_snp: SNP): + # GIVEN an SNP and a store with the SNP + assert base_store._get_query(SNP).all() + + # WHEN deleting the SNP + base_store.delete_snps() + + # THEN all SNPs are deleted + assert not base_store._get_query(SNP).all() diff --git a/tests/database/crud/test_read.py b/tests/database/crud/test_read.py new file mode 100644 index 0000000..5905919 --- /dev/null +++ b/tests/database/crud/test_read.py @@ -0,0 +1,188 @@ +"""Module to test the read functionality of the genotype API CRUD.""" + +from datetime import date + +from astroid import helpers + +from genotype_api.database.filter_models.plate_models import PlateOrderParams +from genotype_api.database.models import Analysis, Plate, SNP, User, Genotype +from genotype_api.database.store import Store +from tests.store_helpers import StoreHelpers + + +def test_get_analysis_by_plate_id(base_store: Store, test_analysis: Analysis): + # GIVEN an analysis and a store with the analysis + + # WHEN getting the analysis by plate id + analyses: list[Analysis] = base_store.get_analyses_from_plate(plate_id=test_analysis.plate_id) + + # THEN the analysis is returned + for analysis in analyses: + assert analysis.plate_id == test_analysis.plate_id + + +def test_get_analysis_by_type_and_sample_id(base_store: Store, test_analysis: Analysis): + # GIVEN an analysis and a store with the analysis + + # WHEN getting the analysis by type and sample id + analysis: Analysis = base_store.get_analysis_by_type_and_sample_id( + analysis_type=test_analysis.type, sample_id=test_analysis.sample_id + ) + + # THEN the analysis is returned + assert analysis.sample_id == test_analysis.sample_id + assert analysis.type == test_analysis.type + + +def test_get_analysis_by_id(base_store: Store, test_analysis: Analysis): + # GIVEN an analysis and a store with the analysis + + # WHEN getting the analysis by id + analysis: Analysis = base_store.get_analysis_by_id(analysis_id=test_analysis.id) + + # THEN the analysis is returned + assert analysis.id == test_analysis.id + + +def test_get_analyses(base_store: Store, test_analyses: list[Analysis]): + # GIVEN an analysis and a store with the analysis + + # WHEN getting the analyses + analyses: list[Analysis] = base_store.get_analyses() + + # THEN the analyses are returned + assert analyses == test_analyses + + +def test_get_analyses_with_skip_and_limit(base_store: Store, test_analyses: list[Analysis]): + # GIVEN an analysis and a store with the analysis + + # WHEN getting the analyses with skip and limit + analyses: list[Analysis] = base_store.get_analyses_with_skip_and_limit(skip=0, limit=2) + + # THEN the analyses are returned + assert analyses == test_analyses[:2] + + +def test_get_analyses_by_type_between_dates( + base_store: Store, + test_analysis: Analysis, + date_tomorrow: date, + date_yesterday: date, + date_two_weeks_future: date, + helpers: StoreHelpers, +): + # GIVEN a store with two analyses of the same type but different dates + future_analysis: Analysis = test_analysis + future_analysis.created_at = date_two_weeks_future + helpers.ensure_analysis(store=base_store, analysis=future_analysis) + + # WHEN getting the analyses by type between dates excluding one of the analyses + analyses: list[Analysis] = base_store.get_analyses_by_type_between_dates( + analysis_type=test_analysis.type, date_min=date_yesterday, date_max=date_tomorrow + ) + + # THEN the analyses are returned + for analysis in analyses: + assert analysis.type == test_analysis.type + assert analysis.created_at != date_two_weeks_future + + +def test_get_plate_by_id(base_store: Store, test_plate: Plate): + # GIVEN a store with a plate + + # WHEN getting the plate by id + plate: Plate = base_store.get_plate_by_id(plate_id=test_plate.id) + + # THEN the plate is returned + assert plate.id == test_plate.id + + +def test_get_plate_by_plate_id(base_store: Store, test_plate: Plate): + # GIVEN a store with a plate + + # WHEN getting the plate by plate id + plate: Plate = base_store.get_plate_by_plate_id(plate_id=test_plate.plate_id) + + # THEN the plate is returned + assert plate.plate_id == test_plate.plate_id + + +def get_user_by_id(base_store: Store, test_user: User): + # GIVEN a store with a user + + # WHEN getting the user by id + user: User = base_store.get_user_by_id(user_id=test_user.id) + + # THEN the user is returned + assert user.id == test_user.id + + +def get_user_by_email(base_store: Store, test_user: User): + # GIVEN a store with a user + + # WHEN getting the user by email + user: User = base_store.get_user_by_email(email=test_user.email) + + # THEN the user is returned + assert user.email == test_user.email + + +def get_user_with_skip_and_limit(base_store: Store, test_users: list[User], helpers: StoreHelpers): + # GIVEN store with a user + out_of_limit_user: User = test_users[0] + out_of_limit_user.id = 3 + helpers.ensure_user(store=base_store, user=out_of_limit_user) + + # WHEN getting the user with skip and limit + users: list[User] = base_store.get_users_with_skip_and_limit(skip=0, limit=2) + + # THEN the user is returned + assert users == test_users + + +def test_get_genotype_by_id(base_store: Store, test_genotype: Genotype): + # GIVEN store with a genotype + + # WHEN getting the genotype by id + genotype: Genotype = base_store.get_genotype_by_id(entry_id=test_genotype.id) + + # THEN the genotype is returned + assert genotype.id == test_genotype.id + + +def test_get_snps(base_store: Store, test_snps: list[SNP]): + # GIVEN a store with a SNP + + # WHEN getting the SNPs + snps: list[SNP] = base_store.get_snps() + + # THEN the SNPs are returned + assert len(snps) == len(test_snps) + + +def test_get_snps_by_limit_and_skip(base_store: Store, test_snps: list[SNP]): + # GIVEN store with SNPs + out_of_limit_snp: SNP = test_snps[0] + out_of_limit_snp.id = 3 + base_store.create_snps(snps=[out_of_limit_snp]) + # WHEN getting the SNPs + snps: list[SNP] = base_store.get_snps_by_limit_and_skip(skip=0, limit=2) + + # THEN the SNPs are returned + assert len(snps) == len(test_snps) + + +def test_get_ordered_plates(base_store: Store, test_plates: list[Plate], helpers: StoreHelpers): + # GIVEN a store with the plates and plate not fulfilling the limit + plate_order_params = PlateOrderParams(sort_order="acs", order_by="plate_id", skip=0, limit=2) + out_of_limit_plate: Plate = test_plates[0] + out_of_limit_plate.plate_id = "ID3" + out_of_limit_plate.id = 3 + helpers.ensure_plate(store=base_store, plate=out_of_limit_plate) + + # WHEN getting the ordered plates + plates: list[Plate] = base_store.get_ordered_plates(order_params=plate_order_params) + + # THEN the plates are returned + assert len(plates) == len(test_plates) diff --git a/tests/database/crud/test_update.py b/tests/database/crud/test_update.py new file mode 100644 index 0000000..84fc000 --- /dev/null +++ b/tests/database/crud/test_update.py @@ -0,0 +1,96 @@ +"""Module to test the update functionality of the genotype API CRUD.""" + +from genotype_api.database.filter_models.plate_models import PlateSignOff +from genotype_api.database.filter_models.sample_models import SampleSexesUpdate +from genotype_api.database.models import Sample, Plate +from genotype_api.database.store import Store +from tests.store_helpers import StoreHelpers + + +def test_refresh_sample_status(store: Store, test_sample: Sample, helpers: StoreHelpers): + # GIVEN a store with a sample with an initial status + initial_status: str = "initial_status" + test_sample.status = initial_status + helpers.ensure_sample(store=store, sample=test_sample) + + # WHEN updating the sample status + store.refresh_sample_status(sample=test_sample) + + # THEN the sample status is updated + updated_sample = store.get_sample_by_id(sample_id=test_sample.id) + assert updated_sample.status != initial_status + + +def test_update_sample_comment(store: Store, test_sample: Sample, helpers: StoreHelpers): + # GIVEN a sample and a store with the sample + initial_comment: str = "initial_comment" + test_sample.comment = initial_comment + helpers.ensure_sample(store=store, sample=test_sample) + + # WHEN updating the sample comment + new_comment: str = "new_comment" + store.update_sample_comment(sample_id=test_sample.id, comment=new_comment) + + # THEN the sample comment is updated + updated_sample = store.get_sample_by_id(sample_id=test_sample.id) + assert updated_sample.comment == new_comment + + +def test_update_sample_status(store: Store, test_sample: Sample, helpers: StoreHelpers): + # GIVEN a sample and a store with the sample + initial_status: str = "initial_status" + test_sample.status = initial_status + helpers.ensure_sample(store=store, sample=test_sample) + + # WHEN updating the sample status + new_status: str = "new_status" + store.update_sample_status(sample_id=test_sample.id, status=new_status) + + # THEN the sample status is updated + updated_sample = store.get_sample_by_id(sample_id=test_sample.id) + assert updated_sample.status == new_status + + +def test_update_user_email(store: Store, test_user, helpers: StoreHelpers): + # GIVEN a user and a store with the user + initial_email: str = "initial_email" + test_user.email = initial_email + helpers.ensure_user(store=store, user=test_user) + + # WHEN updating the user email + new_email: str = "new_email" + store.update_user_email(user=test_user, email=new_email) + + # THEN the user email is updated + updated_user = store.get_user_by_id(user_id=test_user.id) + assert updated_user.email == new_email + + +def test_update_plate_sign_off( + store: Store, unsigned_plate: Plate, plate_sign_off: PlateSignOff, helpers: StoreHelpers +): + # GIVEN a plate and a store with the plate + helpers.ensure_plate(store=store, plate=unsigned_plate) + + # WHEN updating the plate sign off + store.update_plate_sign_off(plate=unsigned_plate, plate_sign_off=plate_sign_off) + + # THEN the plate sign off is updated + updated_plate = store.get_plate_by_id(plate_id=unsigned_plate.id) + assert updated_plate.signed_by == plate_sign_off.user_id + assert updated_plate.signed_at == plate_sign_off.signed_at + assert updated_plate.method_document == plate_sign_off.method_document + assert updated_plate.method_version == plate_sign_off.method_version + + +def test_update_sample_sex(base_store: Store, sample_sex_update: SampleSexesUpdate): + # GIVEN a store with a sample, analysis + + # WHEN updating the sex of the sample + base_store.update_sample_sex(sample_sex_update) + + # THEN the sex of the sample and analysis + updated_sample = base_store.get_sample_by_id(sample_id=sample_sex_update.sample_id) + assert updated_sample.sex == sample_sex_update.sex + for analysis in updated_sample.analyses: + assert analysis.sex == sample_sex_update.genotype_sex diff --git a/tests/store_helpers.py b/tests/store_helpers.py index 3398156..bc19e2d 100644 --- a/tests/store_helpers.py +++ b/tests/store_helpers.py @@ -55,7 +55,7 @@ def ensure_analysis( genotypes: list[Genotype] = None, ): """Add an analysis to the store and ensure the associated sample, plate and genotypes are present.""" - if sample and not store.get_sample(sample.id): + if sample and not store.get_sample_by_id(sample.id): self.add_entity(store=store, entity=sample) if plate and not store.get_plate_by_id(plate.id): self.add_entity(store=store, entity=plate) diff --git a/tests/test_store_helpers.py b/tests/test_store_helpers.py index 0176e18..2cb654f 100644 --- a/tests/test_store_helpers.py +++ b/tests/test_store_helpers.py @@ -95,7 +95,7 @@ def test_ensure_analysis( added_analysis: Analysis = store.get_analysis_by_id(test_analysis.id) assert added_analysis - added_sample: Sample = store.get_sample(test_sample.id) + added_sample: Sample = store.get_sample_by_id(test_sample.id) assert added_sample added_plate: Plate = store.get_plate_by_id(test_plate.id) @@ -112,5 +112,5 @@ def test_ensure_sample(store: Store, test_sample: Sample, helpers: StoreHelpers) helpers.ensure_sample(store=store, sample=test_sample) # THEN a sample is added - added_sample: Sample = store.get_sample(test_sample.id) + added_sample: Sample = store.get_sample_by_id(test_sample.id) assert added_sample