From f9a4f57665464023f59a0a34b54f6cf3cd8c6890 Mon Sep 17 00:00:00 2001
From: Henrik Stranneheim
Date: Wed, 16 Oct 2024 08:51:49 +0200
Subject: [PATCH] Update pydantic mip (#3848)

### Changed

- Update Pydantic models for MIP
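For reviewers who have not worked with the Pydantic v2 API yet, here is a minimal sketch of the migration pattern applied throughout this PR: `@validator(..., always=True)` becomes `@field_validator` on a field annotated with `Field(validate_default=True)`, and the old `values: dict` argument is replaced by `ValidationInfo.data`. The `Metric` and `Summary` models below are simplified stand-ins for illustration, not the real `cg` models:

```python
from typing import Annotated, Any

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Metric(BaseModel):
    id: str
    value: Any


class Summary(BaseModel):
    # Pydantic v2 needs an explicit default on optional fields; the default
    # is only routed through validators when validate_default=True is set.
    metrics_: list[Metric] = Field(..., alias="metrics")
    sample_ids: Annotated[set | None, Field(validate_default=True)] = None

    @field_validator("sample_ids")
    @classmethod
    def set_sample_ids(cls, _, info: ValidationInfo) -> set:
        # v1 validators received values["metrics_"]; in v2 the previously
        # validated fields are exposed on info.data instead.
        return {metric.id for metric in info.data.get("metrics_")}


summary = Summary(metrics=[{"id": "sample_1", "value": 1}, {"id": "sample_2", "value": 2}])
assert summary.sample_ids == {"sample_1", "sample_2"}
```

This is the same pattern that drives `MetricsDeliverables.set_sample_ids` and the other derived-field validators in the diff below.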
---
 cg/meta/workflow/mip.py                       |  2 +-
 cg/models/deliverables/metric_deliverables.py | 73 +++++++++---------
 cg/models/mip/mip_config.py                   | 44 ++++++-----
 cg/models/mip/mip_metrics_deliverables.py     | 76 +++++++++++--------
 tests/models/mip/test_mip_analysis.py         | 10 +--
 tests/models/mip/test_mip_config.py           | 12 +--
 .../mip/test_mip_metrics_deliverables.py      | 60 +++++----------
 7 files changed, 136 insertions(+), 141 deletions(-)

diff --git a/cg/meta/workflow/mip.py b/cg/meta/workflow/mip.py
index 568081de1b..4b21ea1abf 100644
--- a/cg/meta/workflow/mip.py
+++ b/cg/meta/workflow/mip.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 from typing import Any
 
-from pydantic.v1 import ValidationError
+from pydantic import ValidationError
 
 from cg.apps.mip.confighandler import ConfigHandler
 from cg.constants import FileExtensions, Workflow
diff --git a/cg/models/deliverables/metric_deliverables.py b/cg/models/deliverables/metric_deliverables.py
index 8f26dd4fa5..43e60b561a 100644
--- a/cg/models/deliverables/metric_deliverables.py
+++ b/cg/models/deliverables/metric_deliverables.py
@@ -1,7 +1,8 @@
 import operator
-from typing import Any, Callable
+from typing import Annotated, Any, Callable
 
-from pydantic.v1 import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
+from pydantic_core.core_schema import ValidationInfo
 
 from cg.constants import PRECISION
 from cg.exc import CgError, MetricsQCError
@@ -15,28 +16,26 @@ def _get_metric_per_sample_id(sample_id: str, metric_objs: list) -> Any:
     return metric
 
 
-def add_metric(name: str, values: dict) -> list[Any]:
-    """Add metric to list of objects"""
-    found_metrics: list = []
-    raw_metrics: list = values.get("metrics_")
-    metrics_validator: dict[str, Any] = values.get("metric_to_get_")
-    for metric in raw_metrics:
-        if name == metric.name and metric.name in metrics_validator:
-            found_metrics.append(
-                metrics_validator[metric.name](
-                    sample_id=metric.id, step=metric.step, value=metric.value
-                )
-            )
+def add_metric(name: str, info: ValidationInfo) -> list[Any]:
+    """Add metric to a list of objects."""
+    raw_metrics: list = info.data.get("metrics_")
+    metrics_validator: dict[str, Any] = info.data.get("metric_to_get_")
+    found_metrics: list = [
+        metrics_validator[metric.name](sample_id=metric.id, step=metric.step, value=metric.value)
+        for metric in raw_metrics
+        if name == metric.name and metric.name in metrics_validator
+    ]
     return found_metrics
 
 
-def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
+def add_sample_id_metrics(parsed_metric: Any, info: ValidationInfo) -> list[Any]:
     """Add parsed sample_id metrics gathered from all metrics to list"""
-    sample_ids: set = values.get("sample_ids")
+    sample_ids: set = info.data.get("sample_ids")
     sample_id_metrics: list = []
-    metric_per_sample_id_map: dict = {}
-    for metric_name in values.get("sample_metric_to_parse"):
-        metric_per_sample_id_map.update({metric_name: values.get(metric_name)})
+    metric_per_sample_id_map: dict = {
+        metric_name: info.data.get(metric_name)
+        for metric_name in info.data.get("sample_metric_to_parse")
+    }
     for sample_id in sample_ids:
         metric_per_sample_id: dict = {"sample_id": sample_id}
         for metric_name, metric_objs in metric_per_sample_id_map.items():
@@ -45,7 +44,7 @@ def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
             )
         if sample_metric.value:
             metric_per_sample_id[metric_name]: Any = sample_metric.value
-            metric_per_sample_id[metric_name + "_step"]: str = sample_metric.step
+            metric_per_sample_id[f"{metric_name}_step"]: str = sample_metric.step
         sample_id_metrics.append(parsed_metric(**metric_per_sample_id))
     return sample_id_metrics
@@ -61,7 +60,8 @@ class MetricCondition(BaseModel):
     norm: str
     threshold: float | str
 
-    @validator("norm")
+    @field_validator("norm")
+    @classmethod
    def validate_operator(cls, norm: str) -> str:
         """Validate that an operator is accepted."""
         try:
@@ -76,13 +76,13 @@ def validate_operator(cls, norm: str) -> str:
 class MetricsBase(BaseModel):
     """Definition for elements in deliverables metrics file."""
 
-    header: str | None
+    header: str | None = None
     id: str
     input: str
     name: str
     step: str
     value: Any
-    condition: MetricCondition | None
+    condition: MetricCondition | None = None
 
 
 class SampleMetric(BaseModel):
@@ -97,7 +97,8 @@ class MeanInsertSize(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_mean_insert_size(cls, value) -> int:
         """Convert raw value from float to int"""
         return int(value)
@@ -122,16 +123,15 @@ class ParsedMetrics(QCMetrics):
 class MetricsDeliverables(BaseModel):
     """Specification for a metric general deliverables file"""
 
     metrics_: list[MetricsBase] = Field(..., alias="metrics")
-    sample_ids: set | None
+    sample_ids: Annotated[set | None, Field(validate_default=True)] = None
 
-    @validator("sample_ids", always=True)
-    def set_sample_ids(cls, _, values: dict) -> set:
+    @field_validator("sample_ids")
+    @classmethod
+    def set_sample_ids(cls, _, info: ValidationInfo) -> set:
         """Set sample_ids gathered from all metrics"""
-        sample_ids: list = []
-        raw_metrics: list = values.get("metrics_")
-        for metric in raw_metrics:
-            sample_ids.append(metric.id)
+        raw_metrics: list = info.data.get("metrics_")
+        sample_ids: list = [metric.id for metric in raw_metrics]
         return set(sample_ids)
 
 
@@ -140,7 +140,8 @@ class MetricsDeliverablesCondition(BaseModel):
 
     metrics: list[MetricsBase]
 
-    @validator("metrics")
+    @field_validator("metrics")
+    @classmethod
     def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
         """Verify that metrics met QC conditions."""
         failed_metrics = []
@@ -164,6 +165,6 @@ def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
 class MultiqcDataJson(BaseModel):
     """Multiqc data json model."""
 
-    report_general_stats_data: list[dict] | None
-    report_data_sources: dict | None
-    report_saved_raw_data: dict[str, dict] | None
+    report_general_stats_data: list[dict] | None = None
+    report_data_sources: dict | None = None
+    report_saved_raw_data: dict[str, dict] | None = None
diff --git a/cg/models/mip/mip_config.py b/cg/models/mip/mip_config.py
index 11f63e23e3..d57dff2994 100644
--- a/cg/models/mip/mip_config.py
+++ b/cg/models/mip/mip_config.py
@@ -1,6 +1,8 @@
 """Model MIP config"""
 
-from pydantic.v1 import BaseModel, EmailStr, Field, validator
+from typing import Annotated
+
+from pydantic import BaseModel, EmailStr, Field, ValidationInfo, field_validator
 
 from cg.constants.priority import SlurmQos
 
@@ -13,29 +15,31 @@ class AnalysisType(BaseModel):
 class MipBaseConfig(BaseModel):
     """This model is used when validating the mip analysis config"""
 
-    family_id_: str = Field(None, alias="family_id")
-    case_id: str = None
+    family_id_: str | None = Field(None, alias="family_id")
+    case_id: Annotated[str | None, Field(validate_default=True)] = None
     analysis_type_: dict = Field(..., alias="analysis_type")
-    samples: 
list[AnalysisType] = None
+    samples: Annotated[list[AnalysisType] | None, Field(validate_default=True)] = None
     config_path: str = Field(..., alias="config_file_analysis")
     deliverables_file_path: str = Field(..., alias="store_file")
     email: EmailStr
     is_dry_run: bool = Field(False, alias="dry_run_all")
     log_path: str = Field(..., alias="log_file")
     out_dir: str = Field(..., alias="outdata_dir")
     priority: SlurmQos = Field(..., alias="slurm_quality_of_service")
     sample_info_path: str = Field(..., alias="sample_info_file")
     sample_ids: list[str]
 
-    @validator("case_id", always=True, pre=True)
-    def set_case_id(cls, value, values: dict) -> str:
-        """Set case_id. Family_id is used for older versions of MIP analysis"""
-        return value or values.get("family_id_")
-
-    @validator("samples", always=True, pre=True)
-    def set_samples(cls, _, values: dict) -> list[AnalysisType]:
+    @field_validator("case_id")
+    @classmethod
+    def set_case_id(cls, value: str, info: ValidationInfo) -> str:
+        """Set case id. Family id is used for older versions of MIP analysis"""
+        return value or info.data.get("family_id_")
+
+    @field_validator("samples")
+    @classmethod
+    def set_samples(cls, _, info: ValidationInfo) -> list[AnalysisType]:
         """Set samples analysis type"""
-        raw_samples: dict = values.get("analysis_type_")
+        raw_samples: dict = info.data.get("analysis_type_")
         return [
             AnalysisType(sample_id=sample_id, analysis_type=analysis_type)
             for sample_id, analysis_type in raw_samples.items()
diff --git a/cg/models/mip/mip_metrics_deliverables.py b/cg/models/mip/mip_metrics_deliverables.py
index a0713917f9..3454fe0aae 100644
--- a/cg/models/mip/mip_metrics_deliverables.py
+++ b/cg/models/mip/mip_metrics_deliverables.py
@@ -1,6 +1,6 @@
-from typing import Any
+from typing import Annotated, Any
 
-from pydantic.v1 import validator
+from pydantic import Field, ValidationInfo, field_validator
 
 from cg.constants.subject import Sex
 from cg.models.deliverables.metric_deliverables import (
@@ -27,14 +27,15 @@ class DuplicateReads(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_duplicate_read(cls, value) -> float:
-        """Convert raw value from fraction to percent"""
+        """Convert raw value from fraction to percent."""
         return value * 100
 
 
 class SexCheck(SampleMetric):
-    """Definition of sex check metric"""
+    """Definition of sex check metric."""
 
     value: str
 
@@ -44,7 +45,8 @@ class MIPMappedReads(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_mapped_read(cls, value) -> float:
         """Convert raw value from fraction to percent"""
         return value * 100
@@ -70,34 +72,38 @@ class MIPMetricsDeliverables(MetricsDeliverables):
         "MEDIAN_TARGET_COVERAGE": MedianTargetCoverage,
         "gender": SexCheck,
     }
-    duplicate_reads: list[DuplicateReads] | None
-    mapped_reads: list[MIPMappedReads] | None
-    mean_insert_size: list[MeanInsertSize] | None
-    median_target_coverage: 
list[MedianTargetCoverage] | None - predicted_sex: list[SexCheck] | None + duplicate_reads: Annotated[list[DuplicateReads] | None, Field(validate_default=True)] = None + mapped_reads: Annotated[list[MIPMappedReads] | None, Field(validate_default=True)] = None + mean_insert_size: Annotated[list[MeanInsertSize] | None, Field(validate_default=True)] = None + median_target_coverage: Annotated[ + list[MedianTargetCoverage] | None, Field(validate_default=True) + ] = None + predicted_sex: Annotated[list[SexCheck] | None, Field(validate_default=True)] = None sample_metric_to_parse: list[str] = SAMPLE_METRICS_TO_PARSE - sample_id_metrics: list[MIPParsedMetrics] | None + sample_id_metrics: Annotated[list[MIPParsedMetrics] | None, Field(validate_default=True)] = None - @validator("duplicate_reads", always=True) - def set_duplicate_reads(cls, _, values: dict) -> list[DuplicateReads]: + @field_validator("duplicate_reads") + @classmethod + def set_duplicate_reads(cls, _, info: ValidationInfo) -> list[DuplicateReads]: """Set duplicate_reads""" - return add_metric(name="fraction_duplicates", values=values) + return add_metric(name="fraction_duplicates", info=info) - @validator("mapped_reads", always=True) - def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]: + @field_validator("mapped_reads") + @classmethod + def set_mapped_reads(cls, _, info: ValidationInfo) -> list[MIPMappedReads]: """Set mapped reads""" - sample_ids: set = values.get("sample_ids") + sample_ids: set = info.data.get("sample_ids") mapped_reads: list = [] total_sequences: dict = {} reads_mapped: dict = {} - raw_metrics: list = values.get("metrics_") + raw_metrics: list = info.data.get("metrics_") metric_step: str = "" for metric in raw_metrics: if metric.name == "raw_total_sequences": raw_total_sequences = total_sequences.get(metric.id, 0) total_sequences[metric.id] = int(metric.value) + raw_total_sequences metric_step: str = metric.step - if metric.name == "reads_mapped": + elif metric.name == "reads_mapped": raw_reads_mapped = reads_mapped.get(metric.id, 0) reads_mapped[metric.id] = int(metric.value) + raw_reads_mapped for sample_id in sample_ids: @@ -107,31 +113,35 @@ def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]: ) return mapped_reads - @validator("mean_insert_size", always=True) - def set_mean_insert_size(cls, _, values: dict) -> list[MeanInsertSize]: + @field_validator("mean_insert_size") + @classmethod + def set_mean_insert_size(cls, _, info: ValidationInfo) -> list[MeanInsertSize]: """Set mean insert size""" - return add_metric(name="MEAN_INSERT_SIZE", values=values) + return add_metric(name="MEAN_INSERT_SIZE", info=info) - @validator("median_target_coverage", always=True) - def set_median_target_coverage(cls, _, values: dict) -> list[MedianTargetCoverage]: + @field_validator("median_target_coverage") + @classmethod + def set_median_target_coverage(cls, _, info: ValidationInfo) -> list[MedianTargetCoverage]: """Set median target coverage""" - return add_metric(name="MEDIAN_TARGET_COVERAGE", values=values) + return add_metric(name="MEDIAN_TARGET_COVERAGE", info=info) - @validator("predicted_sex", always=True) - def set_predicted_sex(cls, _, values: dict) -> list[SexCheck]: + @field_validator("predicted_sex") + @classmethod + def set_predicted_sex(cls, _, info: ValidationInfo) -> list[SexCheck]: """Set predicted sex""" - return add_metric(name="gender", values=values) + return add_metric(name="gender", info=info) - @validator("sample_id_metrics", always=True) - def set_sample_id_metrics(cls, _, 
values: dict) -> list[MIPParsedMetrics]:
+    @field_validator("sample_id_metrics")
+    @classmethod
+    def set_sample_id_metrics(cls, _, info: ValidationInfo) -> list[MIPParsedMetrics]:
         """Set parsed sample_id metrics gathered from all metrics"""
-        return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, values=values)
+        return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, info=info)
 
 
 def get_sample_id_metric(
     sample_id_metrics: list[MIPParsedMetrics], sample_id: str
 ) -> MIPParsedMetrics:
-    """Get parsed metrics for an sample_id"""
+    """Get parsed metrics for a sample ID."""
     for sample_id_metric in sample_id_metrics:
         if sample_id == sample_id_metric.sample_id:
             return sample_id_metric
diff --git a/tests/models/mip/test_mip_analysis.py b/tests/models/mip/test_mip_analysis.py
index 8fd94c2f11..052c21be76 100644
--- a/tests/models/mip/test_mip_analysis.py
+++ b/tests/models/mip/test_mip_analysis.py
@@ -4,10 +4,8 @@
 
 
 def test_instantiate_mip_analysis(mip_analysis_raw: dict):
-    """
-    Tests raw mip analysis against a pydantic MipAnalysis
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests raw mip analysis against a pydantic MipAnalysis."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MipAnalysis object
     mip_dna_analysis = MipAnalysis(**mip_analysis_raw)
@@ -22,9 +20,7 @@ def test_instantiate_parse_mip_analysis(
     mip_metrics_deliverables_raw: dict,
     sample_info_dna_raw: dict,
 ):
-    """
-    Tests parse_analysis
-    """
+    """Tests parse_analysis."""
     # GIVEN a dictionary with some metrics and a MIP analysis API
     mip_analysis_api = MipAnalysisAPI(cg_context, Workflow.MIP_DNA)
diff --git a/tests/models/mip/test_mip_config.py b/tests/models/mip/test_mip_config.py
index 512770ebc9..2917ba5d20 100644
--- a/tests/models/mip/test_mip_config.py
+++ b/tests/models/mip/test_mip_config.py
@@ -4,7 +4,7 @@
 
 from cg.constants.constants import FileFormat
 from cg.io.controller import ReadFile
-from cg.models.mip.mip_config import MipBaseConfig
+from cg.models.mip.mip_config import AnalysisType, MipBaseConfig
 
 
 def test_instantiate_mip_config(mip_analysis_config_dna_raw: dict):
@@ -57,16 +57,16 @@ def test_mip_config_case_id_with_family_id(mip_analysis_config_dna_raw: dict):
     assert config_object.case_id == "a_family_id"
 
 
-def test_mip_config_case_id(mip_analysis_config_dna_raw: dict):
-    """Test case_id validator"""
+def test_mip_config_analysis_type(mip_analysis_config_dna_raw: dict):
+    """Test analysis type validator."""
     # GIVEN a MIP config
 
-    # WHEN instantiating a MipBaseSampleInfo object
+    # WHEN instantiating a MipBaseConfig object
     config_object = MipBaseConfig(**mip_analysis_config_dna_raw)
 
-    # THEN assert that samples was set
+    # THEN assert that samples were set
     assert config_object.samples
 
-    # THEN assert that dict is set
-    analysis_type: dict = config_object.samples.pop()
-    assert analysis_type == {"analysis_type": "wgs", "sample_id": "sample_id"}
+    # THEN assert that an analysis type is set
+    analysis_type: AnalysisType = config_object.samples.pop()
+    assert analysis_type.model_dump() == {"analysis_type": "wgs", "sample_id": "sample_id"}
diff --git a/tests/models/mip/test_mip_metrics_deliverables.py b/tests/models/mip/test_mip_metrics_deliverables.py
index 19e9f60087..bf78a3186d 100644
--- a/tests/models/mip/test_mip_metrics_deliverables.py
+++ b/tests/models/mip/test_mip_metrics_deliverables.py
@@ -1,4 +1,4 @@
-"""Test MIP metrics deliverables"""
+"""Test MIP metrics deliverables."""
 
 from cg.models.mip.mip_metrics_deliverables import (
     DuplicateReads,
@@ -13,10 +13,8 @@
 
 
 def test_instantiate_mip_metrics_deliverables(mip_metrics_deliverables_raw: dict):
-    """
-    Tests raw data 
deliverable against a pydantic MIPMetricsDeliverables
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests raw data deliverable against a pydantic MIPMetricsDeliverables."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
@@ -26,10 +24,8 @@
 
 
 def test_instantiate_mip_metrics_sample_ids(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set sample_ids
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set sample ids."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
@@ -39,15 +35,13 @@
 
 
 def test_mip_metrics_set_duplicate_reads(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set duplicates read
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set duplicate reads."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
 
-    # THEN assert that read duplicates was set
+    # THEN assert that read duplicates were set
     assert metrics_object.duplicate_reads
 
     duplicate_read: DuplicateReads = metrics_object.duplicate_reads.pop()
@@ -67,15 +61,13 @@
 
 
 def test_mip_metrics_set_mapped_reads(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set mapped reads
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set mapped reads."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
 
-    # THEN assert that mapped reads was set
+    # THEN assert that mapped reads were set
     assert metrics_object.mapped_reads
 
     mapped_reads: MIPMappedReads = metrics_object.mapped_reads.pop()
@@ -85,10 +77,8 @@
 
 
 def test_mip_metrics_set_mean_insert_size(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set mean insert size
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set mean insert size."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
@@ -103,10 +93,8 @@
 
 
-def test_mip_metrics_set_meadian_target_coverage(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set median target coverage
-    """
-    # GIVEN a dictionary with the some metrics
+def test_mip_metrics_set_median_target_coverage(mip_metrics_deliverables_raw: dict):
+    """Tests set median target coverage."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
@@ -121,10 +109,8 @@
 
 
 def test_mip_metrics_set_predicted_sex(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set predicted sex
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set predicted sex."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
@@ -139,10 +125,8 @@ def 
test_mip_metrics_set_predicted_sex(mip_metrics_deliverables_raw: dict):
 
 
 def test_instantiate_mip_metrics_set_sample_id_metrics(mip_metrics_deliverables_raw: dict):
-    """
-    Tests set sample_id metrics
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests set sample id metrics."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)
 
@@ -168,10 +152,8 @@ def test_instantiate_mip_metrics_set_sample_id_metrics(mip_metrics_deliverables_
 
 
 def test_get_sample_id_metric(mip_metrics_deliverables_raw: dict):
-    """
-    Tests get sample_id metrics
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests get sample id metric."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MIPMetricsDeliverables object
     metrics_object = MIPMetricsDeliverables(**mip_metrics_deliverables_raw)