Skip to content

Commit f9a4f57

Browse files
Update pydantic mip (#3848)
### Changed

- Update Pydantic models for MIP
1 parent 4b5c6da commit f9a4f57

File tree

7 files changed

+136
-141
lines changed

7 files changed

+136
-141
lines changed

cg/meta/workflow/mip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33
from typing import Any
44

5-
from pydantic.v1 import ValidationError
5+
from pydantic import ValidationError
66

77
from cg.apps.mip.confighandler import ConfigHandler
88
from cg.constants import FileExtensions, Workflow

cg/models/deliverables/metric_deliverables.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import operator
2-
from typing import Any, Callable
2+
from typing import Annotated, Any, Callable
33

4-
from pydantic.v1 import BaseModel, Field, validator
4+
from pydantic import BaseModel, Field, field_validator
5+
from pydantic_core.core_schema import ValidationInfo
56

67
from cg.constants import PRECISION
78
from cg.exc import CgError, MetricsQCError
@@ -15,28 +16,26 @@ def _get_metric_per_sample_id(sample_id: str, metric_objs: list) -> Any:
1516
return metric
1617

1718

18-
def add_metric(name: str, values: dict) -> list[Any]:
19-
"""Add metric to list of objects"""
20-
found_metrics: list = []
21-
raw_metrics: list = values.get("metrics_")
22-
metrics_validator: dict[str, Any] = values.get("metric_to_get_")
23-
for metric in raw_metrics:
24-
if name == metric.name and metric.name in metrics_validator:
25-
found_metrics.append(
26-
metrics_validator[metric.name](
27-
sample_id=metric.id, step=metric.step, value=metric.value
28-
)
29-
)
19+
def add_metric(name: str, info: ValidationInfo) -> list[Any]:
20+
"""Add metric to a list of objects."""
21+
raw_metrics: list = info.data.get("metrics_")
22+
metrics_validator: dict[str, Any] = info.data.get("metric_to_get_")
23+
found_metrics: list = [
24+
metrics_validator[metric.name](sample_id=metric.id, step=metric.step, value=metric.value)
25+
for metric in raw_metrics
26+
if name == metric.name and metric.name in metrics_validator
27+
]
3028
return found_metrics
3129

3230

33-
def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
31+
def add_sample_id_metrics(parsed_metric: Any, info: ValidationInfo) -> list[Any]:
3432
"""Add parsed sample_id metrics gathered from all metrics to list"""
35-
sample_ids: set = values.get("sample_ids")
33+
sample_ids: set = info.data.get("sample_ids")
3634
sample_id_metrics: list = []
37-
metric_per_sample_id_map: dict = {}
38-
for metric_name in values.get("sample_metric_to_parse"):
39-
metric_per_sample_id_map.update({metric_name: values.get(metric_name)})
35+
metric_per_sample_id_map: dict = {
36+
metric_name: info.data.get(metric_name)
37+
for metric_name in info.data.get("sample_metric_to_parse")
38+
}
4039
for sample_id in sample_ids:
4140
metric_per_sample_id: dict = {"sample_id": sample_id}
4241
for metric_name, metric_objs in metric_per_sample_id_map.items():
@@ -45,7 +44,7 @@ def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
4544
)
4645
if sample_metric.value:
4746
metric_per_sample_id[metric_name]: Any = sample_metric.value
48-
metric_per_sample_id[metric_name + "_step"]: str = sample_metric.step
47+
metric_per_sample_id[f"{metric_name}_step"]: str = sample_metric.step
4948
sample_id_metrics.append(parsed_metric(**metric_per_sample_id))
5049
return sample_id_metrics
5150

@@ -61,7 +60,8 @@ class MetricCondition(BaseModel):
6160
norm: str
6261
threshold: float | str
6362

64-
@validator("norm")
63+
@field_validator("norm")
64+
@classmethod
6565
def validate_operator(cls, norm: str) -> str:
6666
"""Validate that an operator is accepted."""
6767
try:
@@ -76,13 +76,13 @@ def validate_operator(cls, norm: str) -> str:
7676
class MetricsBase(BaseModel):
7777
"""Definition for elements in deliverables metrics file."""
7878

79-
header: str | None
79+
header: str | None = None
8080
id: str
8181
input: str
8282
name: str
8383
step: str
8484
value: Any
85-
condition: MetricCondition | None
85+
condition: MetricCondition | None = None
8686

8787

8888
class SampleMetric(BaseModel):
@@ -97,7 +97,8 @@ class MeanInsertSize(SampleMetric):
9797

9898
value: float
9999

100-
@validator("value", always=True)
100+
@field_validator("value")
101+
@classmethod
101102
def convert_mean_insert_size(cls, value) -> int:
102103
"""Convert raw value from float to int"""
103104
return int(value)
@@ -122,16 +123,15 @@ class ParsedMetrics(QCMetrics):
122123
class MetricsDeliverables(BaseModel):
123124
"""Specification for a metric general deliverables file"""
124125

125-
metrics_: list[MetricsBase] = Field(..., alias="metrics")
126-
sample_ids: set | None
126+
metrics_: list[MetricsBase] = Field(list[MetricsBase], alias="metrics")
127+
sample_ids: Annotated[set | None, Field(validate_default=True)] = None
127128

128-
@validator("sample_ids", always=True)
129-
def set_sample_ids(cls, _, values: dict) -> set:
129+
@field_validator("sample_ids")
130+
@classmethod
131+
def set_sample_ids(cls, _, info: ValidationInfo) -> set:
130132
"""Set sample_ids gathered from all metrics"""
131-
sample_ids: list = []
132-
raw_metrics: list = values.get("metrics_")
133-
for metric in raw_metrics:
134-
sample_ids.append(metric.id)
133+
raw_metrics: list = info.data.get("metrics_")
134+
sample_ids: list = [metric.id for metric in raw_metrics]
135135
return set(sample_ids)
136136

137137

@@ -140,7 +140,8 @@ class MetricsDeliverablesCondition(BaseModel):
140140

141141
metrics: list[MetricsBase]
142142

143-
@validator("metrics")
143+
@field_validator("metrics")
144+
@classmethod
144145
def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
145146
"""Verify that metrics met QC conditions."""
146147
failed_metrics = []
@@ -164,6 +165,6 @@ def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
164165
class MultiqcDataJson(BaseModel):
165166
"""Multiqc data json model."""
166167

167-
report_general_stats_data: list[dict] | None
168-
report_data_sources: dict | None
169-
report_saved_raw_data: dict[str, dict] | None
168+
report_general_stats_data: list[dict] | None = None
169+
report_data_sources: dict | None = None
170+
report_saved_raw_data: dict[str, dict] | None = None

cg/models/mip/mip_config.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Model MIP config"""
22

3-
from pydantic.v1 import BaseModel, EmailStr, Field, validator
3+
from typing import Annotated
4+
5+
from pydantic import BaseModel, EmailStr, Field, ValidationInfo, field_validator
46

57
from cg.constants.priority import SlurmQos
68

@@ -13,29 +15,33 @@ class AnalysisType(BaseModel):
1315
class MipBaseConfig(BaseModel):
1416
"""This model is used when validating the mip analysis config"""
1517

16-
family_id_: str = Field(None, alias="family_id")
17-
case_id: str = None
18-
analysis_type_: dict = Field(..., alias="analysis_type")
19-
samples: list[AnalysisType] = None
20-
config_path: str = Field(..., alias="config_file_analysis")
21-
deliverables_file_path: str = Field(..., alias="store_file")
18+
family_id_: str | None = Field(None, alias="family_id")
19+
case_id: Annotated[str | None, Field(validate_default=True)] = None
20+
analysis_type_: dict = Field(dict, alias="analysis_type")
21+
samples: Annotated[list[AnalysisType] | None, Field(validate_default=True)] = None
22+
config_path: str = Field(str, alias="config_file_analysis")
23+
deliverables_file_path: str = Field(str, alias="store_file")
2224
email: EmailStr
2325
is_dry_run: bool = Field(False, alias="dry_run_all")
24-
log_path: str = Field(..., alias="log_file")
25-
out_dir: str = Field(..., alias="outdata_dir")
26-
priority: SlurmQos = Field(..., alias="slurm_quality_of_service")
27-
sample_info_path: str = Field(..., alias="sample_info_file")
26+
log_path: str = Field(str, alias="log_file")
27+
out_dir: str = Field(str, alias="outdata_dir")
28+
priority: SlurmQos = Field(SlurmQos, alias="slurm_quality_of_service")
29+
sample_info_path: str = Field(str, alias="sample_info_file")
2830
sample_ids: list[str]
2931

30-
@validator("case_id", always=True, pre=True)
31-
def set_case_id(cls, value, values: dict) -> str:
32-
"""Set case_id. Family_id is used for older versions of MIP analysis"""
33-
return value or values.get("family_id_")
34-
35-
@validator("samples", always=True, pre=True)
36-
def set_samples(cls, _, values: dict) -> list[AnalysisType]:
32+
@field_validator("case_id")
33+
@classmethod
34+
def set_case_id(cls, value: str, info: ValidationInfo) -> str:
35+
"""Set case id. Family id is used for older versions of MIP analysis"""
36+
return value or info.data.get("family_id_")
37+
38+
@field_validator(
39+
"samples",
40+
)
41+
@classmethod
42+
def set_samples(cls, _, info: ValidationInfo) -> list[AnalysisType]:
3743
"""Set samples analysis type"""
38-
raw_samples: dict = values.get("analysis_type_")
44+
raw_samples: dict = info.data.get("analysis_type_")
3945
return [
4046
AnalysisType(sample_id=sample_id, analysis_type=analysis_type)
4147
for sample_id, analysis_type in raw_samples.items()

cg/models/mip/mip_metrics_deliverables.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Any
1+
from typing import Annotated, Any
22

3-
from pydantic.v1 import validator
3+
from pydantic import Field, ValidationInfo, field_validator
44

55
from cg.constants.subject import Sex
66
from cg.models.deliverables.metric_deliverables import (
@@ -27,14 +27,15 @@ class DuplicateReads(SampleMetric):
2727

2828
value: float
2929

30-
@validator("value", always=True)
30+
@field_validator("value")
31+
@classmethod
3132
def convert_duplicate_read(cls, value) -> float:
32-
"""Convert raw value from fraction to percent"""
33+
"""Convert raw value from fraction to percent."""
3334
return value * 100
3435

3536

3637
class SexCheck(SampleMetric):
37-
"""Definition of sex check metric"""
38+
"""Definition of sex check metric."""
3839

3940
value: str
4041

@@ -44,7 +45,8 @@ class MIPMappedReads(SampleMetric):
4445

4546
value: float
4647

47-
@validator("value", always=True)
48+
@field_validator("value")
49+
@classmethod
4850
def convert_mapped_read(cls, value) -> float:
4951
"""Convert raw value from fraction to percent"""
5052
return value * 100
@@ -70,34 +72,38 @@ class MIPMetricsDeliverables(MetricsDeliverables):
7072
"MEDIAN_TARGET_COVERAGE": MedianTargetCoverage,
7173
"gender": SexCheck,
7274
}
73-
duplicate_reads: list[DuplicateReads] | None
74-
mapped_reads: list[MIPMappedReads] | None
75-
mean_insert_size: list[MeanInsertSize] | None
76-
median_target_coverage: list[MedianTargetCoverage] | None
77-
predicted_sex: list[SexCheck] | None
75+
duplicate_reads: Annotated[list[DuplicateReads] | None, Field(validate_default=True)] = None
76+
mapped_reads: Annotated[list[MIPMappedReads] | None, Field(validate_default=True)] = None
77+
mean_insert_size: Annotated[list[MeanInsertSize] | None, Field(validate_default=True)] = None
78+
median_target_coverage: Annotated[
79+
list[MedianTargetCoverage] | None, Field(validate_default=True)
80+
] = None
81+
predicted_sex: Annotated[list[SexCheck] | None, Field(validate_default=True)] = None
7882
sample_metric_to_parse: list[str] = SAMPLE_METRICS_TO_PARSE
79-
sample_id_metrics: list[MIPParsedMetrics] | None
83+
sample_id_metrics: Annotated[list[MIPParsedMetrics] | None, Field(validate_default=True)] = None
8084

81-
@validator("duplicate_reads", always=True)
82-
def set_duplicate_reads(cls, _, values: dict) -> list[DuplicateReads]:
85+
@field_validator("duplicate_reads")
86+
@classmethod
87+
def set_duplicate_reads(cls, _, info: ValidationInfo) -> list[DuplicateReads]:
8388
"""Set duplicate_reads"""
84-
return add_metric(name="fraction_duplicates", values=values)
89+
return add_metric(name="fraction_duplicates", info=info)
8590

86-
@validator("mapped_reads", always=True)
87-
def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]:
91+
@field_validator("mapped_reads")
92+
@classmethod
93+
def set_mapped_reads(cls, _, info: ValidationInfo) -> list[MIPMappedReads]:
8894
"""Set mapped reads"""
89-
sample_ids: set = values.get("sample_ids")
95+
sample_ids: set = info.data.get("sample_ids")
9096
mapped_reads: list = []
9197
total_sequences: dict = {}
9298
reads_mapped: dict = {}
93-
raw_metrics: list = values.get("metrics_")
99+
raw_metrics: list = info.data.get("metrics_")
94100
metric_step: str = ""
95101
for metric in raw_metrics:
96102
if metric.name == "raw_total_sequences":
97103
raw_total_sequences = total_sequences.get(metric.id, 0)
98104
total_sequences[metric.id] = int(metric.value) + raw_total_sequences
99105
metric_step: str = metric.step
100-
if metric.name == "reads_mapped":
106+
elif metric.name == "reads_mapped":
101107
raw_reads_mapped = reads_mapped.get(metric.id, 0)
102108
reads_mapped[metric.id] = int(metric.value) + raw_reads_mapped
103109
for sample_id in sample_ids:
@@ -107,31 +113,35 @@ def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]:
107113
)
108114
return mapped_reads
109115

110-
@validator("mean_insert_size", always=True)
111-
def set_mean_insert_size(cls, _, values: dict) -> list[MeanInsertSize]:
116+
@field_validator("mean_insert_size")
117+
@classmethod
118+
def set_mean_insert_size(cls, _, info: ValidationInfo) -> list[MeanInsertSize]:
112119
"""Set mean insert size"""
113-
return add_metric(name="MEAN_INSERT_SIZE", values=values)
120+
return add_metric(name="MEAN_INSERT_SIZE", info=info)
114121

115-
@validator("median_target_coverage", always=True)
116-
def set_median_target_coverage(cls, _, values: dict) -> list[MedianTargetCoverage]:
122+
@field_validator("median_target_coverage")
123+
@classmethod
124+
def set_median_target_coverage(cls, _, info: ValidationInfo) -> list[MedianTargetCoverage]:
117125
"""Set median target coverage"""
118-
return add_metric(name="MEDIAN_TARGET_COVERAGE", values=values)
126+
return add_metric(name="MEDIAN_TARGET_COVERAGE", info=info)
119127

120-
@validator("predicted_sex", always=True)
121-
def set_predicted_sex(cls, _, values: dict) -> list[SexCheck]:
128+
@field_validator("predicted_sex")
129+
@classmethod
130+
def set_predicted_sex(cls, _, info: ValidationInfo) -> list[SexCheck]:
122131
"""Set predicted sex"""
123-
return add_metric(name="gender", values=values)
132+
return add_metric(name="gender", info=info)
124133

125-
@validator("sample_id_metrics", always=True)
126-
def set_sample_id_metrics(cls, _, values: dict) -> list[MIPParsedMetrics]:
134+
@field_validator("sample_id_metrics")
135+
@classmethod
136+
def set_sample_id_metrics(cls, _, info: ValidationInfo) -> list[MIPParsedMetrics]:
127137
"""Set parsed sample_id metrics gathered from all metrics"""
128-
return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, values=values)
138+
return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, info=info)
129139

130140

131141
def get_sample_id_metric(
132142
sample_id_metrics: list[MIPParsedMetrics], sample_id: str
133143
) -> MIPParsedMetrics:
134-
"""Get parsed metrics for an sample_id"""
144+
"""Get parsed metrics for a sample id"""
135145
for sample_id_metric in sample_id_metrics:
136146
if sample_id == sample_id_metric.sample_id:
137147
return sample_id_metric

tests/models/mip/test_mip_analysis.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44

55

66
def test_instantiate_mip_analysis(mip_analysis_raw: dict):
7-
"""
8-
Tests raw mip analysis against a pydantic MipAnalysis
9-
"""
10-
# GIVEN a dictionary with the some metrics
7+
"""Tests raw mip analysis against a pydantic MipAnalysis."""
8+
# GIVEN a dictionary with some metrics
119

1210
# WHEN instantiating a MipAnalysis object
1311
mip_dna_analysis = MipAnalysis(**mip_analysis_raw)
@@ -22,9 +20,7 @@ def test_instantiate_parse_mip_analysis(
2220
mip_metrics_deliverables_raw: dict,
2321
sample_info_dna_raw: dict,
2422
):
25-
"""
26-
Tests parse_analysis
27-
"""
23+
"""Tests parse_analysis."""
2824
# GIVEN a dictionary with some metrics and a MIP analysis API
2925
mip_analysis_api = MipAnalysisAPI(cg_context, Workflow.MIP_DNA)
3026

0 commit comments

Comments
 (0)