# Update pydantic mip (#3848)
### Changed

- Update Pydantic models for MIP
henrikstranneheim authored Oct 16, 2024 · 1 parent 4b5c6da · commit f9a4f57
Showing 7 changed files with 136 additions and 141 deletions.
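For readers new to the Pydantic v2 migration, the pattern repeated throughout this commit is mechanical: v1's `@validator(..., always=True)` becomes `@field_validator` stacked on `@classmethod`, `Field(validate_default=True)` takes over the job of `always=True`, and the old `values: dict` argument becomes `info: ValidationInfo`, with previously validated fields available under `info.data`. A minimal sketch of the v2 side (the `Example` model and its fields are illustrative, not taken from the repository):

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Example(BaseModel):
    # Field order matters: "items" is validated first, so it is visible
    # in info.data inside the "count" validator below.
    items: list[int]
    # Field(validate_default=True) replaces v1's always=True: the
    # validator runs even when "count" is left at its None default.
    count: Annotated[int | None, Field(validate_default=True)] = None

    @field_validator("count")
    @classmethod
    def set_count(cls, value: int | None, info: ValidationInfo) -> int:
        # v1 passed a plain "values" dict; v2 exposes already-validated
        # fields as info.data.
        return value if value is not None else len(info.data["items"])


print(Example(items=[1, 2, 3]).count)  # 3
```

Keep this shape in mind when reading the diffs below; nearly every hunk is an instance of it.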
2 changes: 1 addition & 1 deletion cg/meta/workflow/mip.py
```diff
@@ -2,7 +2,7 @@
 from pathlib import Path
 from typing import Any
 
-from pydantic.v1 import ValidationError
+from pydantic import ValidationError
 
 from cg.apps.mip.confighandler import ConfigHandler
 from cg.constants import FileExtensions, Workflow
```
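The only change in this file drops the `pydantic.v1` compatibility shim. Pydantic 2.x bundles the old API under that namespace so unmigrated modules keep working; once a module is ported, the plain import refers to the native v2 class of the same name. A before/after sketch:

```python
# Before this commit: the v1 API through the compatibility shim that
# Pydantic 2.x still ships.
# from pydantic.v1 import ValidationError

# After this commit: the native v2 exception of the same name.
from pydantic import ValidationError
```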
73 changes: 37 additions & 36 deletions cg/models/deliverables/metric_deliverables.py
```diff
@@ -1,7 +1,8 @@
 import operator
-from typing import Any, Callable
+from typing import Annotated, Any, Callable
 
-from pydantic.v1 import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
+from pydantic_core.core_schema import ValidationInfo
 
 from cg.constants import PRECISION
 from cg.exc import CgError, MetricsQCError
@@ -15,28 +16,26 @@ def _get_metric_per_sample_id(sample_id: str, metric_objs: list) -> Any:
             return metric
 
 
-def add_metric(name: str, values: dict) -> list[Any]:
-    """Add metric to list of objects"""
-    found_metrics: list = []
-    raw_metrics: list = values.get("metrics_")
-    metrics_validator: dict[str, Any] = values.get("metric_to_get_")
-    for metric in raw_metrics:
-        if name == metric.name and metric.name in metrics_validator:
-            found_metrics.append(
-                metrics_validator[metric.name](
-                    sample_id=metric.id, step=metric.step, value=metric.value
-                )
-            )
+def add_metric(name: str, info: ValidationInfo) -> list[Any]:
+    """Add metric to a list of objects."""
+    raw_metrics: list = info.data.get("metrics_")
+    metrics_validator: dict[str, Any] = info.data.get("metric_to_get_")
+    found_metrics: list = [
+        metrics_validator[metric.name](sample_id=metric.id, step=metric.step, value=metric.value)
+        for metric in raw_metrics
+        if name == metric.name and metric.name in metrics_validator
+    ]
     return found_metrics
 
 
-def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
+def add_sample_id_metrics(parsed_metric: Any, info: ValidationInfo) -> list[Any]:
     """Add parsed sample_id metrics gathered from all metrics to list"""
-    sample_ids: set = values.get("sample_ids")
+    sample_ids: set = info.data.get("sample_ids")
     sample_id_metrics: list = []
-    metric_per_sample_id_map: dict = {}
-    for metric_name in values.get("sample_metric_to_parse"):
-        metric_per_sample_id_map.update({metric_name: values.get(metric_name)})
+    metric_per_sample_id_map: dict = {
+        metric_name: info.data.get(metric_name)
+        for metric_name in info.data.get("sample_metric_to_parse")
+    }
     for sample_id in sample_ids:
         metric_per_sample_id: dict = {"sample_id": sample_id}
         for metric_name, metric_objs in metric_per_sample_id_map.items():
@@ -45,7 +44,7 @@ def add_sample_id_metrics(parsed_metric: Any, values: dict) -> list[Any]:
             )
             if sample_metric.value:
                 metric_per_sample_id[metric_name]: Any = sample_metric.value
-                metric_per_sample_id[metric_name + "_step"]: str = sample_metric.step
+                metric_per_sample_id[f"{metric_name}_step"]: str = sample_metric.step
         sample_id_metrics.append(parsed_metric(**metric_per_sample_id))
     return sample_id_metrics
 
@@ -61,7 +60,8 @@ class MetricCondition(BaseModel):
     norm: str
     threshold: float | str
 
-    @validator("norm")
+    @field_validator("norm")
+    @classmethod
    def validate_operator(cls, norm: str) -> str:
         """Validate that an operator is accepted."""
         try:
@@ -76,13 +76,13 @@ def validate_operator(cls, norm: str) -> str:
 class MetricsBase(BaseModel):
     """Definition for elements in deliverables metrics file."""
 
-    header: str | None
+    header: str | None = None
     id: str
     input: str
     name: str
     step: str
     value: Any
-    condition: MetricCondition | None
+    condition: MetricCondition | None = None
 
 
 class SampleMetric(BaseModel):
@@ -97,7 +97,8 @@ class MeanInsertSize(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_mean_insert_size(cls, value) -> int:
         """Convert raw value from float to int"""
         return int(value)
@@ -122,16 +123,15 @@ class ParsedMetrics(QCMetrics):
 class MetricsDeliverables(BaseModel):
     """Specification for a metric general deliverables file"""
 
-    metrics_: list[MetricsBase] = Field(..., alias="metrics")
-    sample_ids: set | None
+    metrics_: list[MetricsBase] = Field(list[MetricsBase], alias="metrics")
+    sample_ids: Annotated[set | None, Field(validate_default=True)] = None
 
-    @validator("sample_ids", always=True)
-    def set_sample_ids(cls, _, values: dict) -> set:
+    @field_validator("sample_ids")
+    @classmethod
+    def set_sample_ids(cls, _, info: ValidationInfo) -> set:
         """Set sample_ids gathered from all metrics"""
-        sample_ids: list = []
-        raw_metrics: list = values.get("metrics_")
-        for metric in raw_metrics:
-            sample_ids.append(metric.id)
+        raw_metrics: list = info.data.get("metrics_")
+        sample_ids: list = [metric.id for metric in raw_metrics]
         return set(sample_ids)
 
 
@@ -140,7 +140,8 @@ class MetricsDeliverablesCondition(BaseModel):
 
     metrics: list[MetricsBase]
 
-    @validator("metrics")
+    @field_validator("metrics")
+    @classmethod
     def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
         """Verify that metrics met QC conditions."""
         failed_metrics = []
@@ -164,6 +165,6 @@ def validate_metrics(cls, metrics: list[MetricsBase]) -> list[MetricsBase]:
 class MultiqcDataJson(BaseModel):
     """Multiqc data json model."""
 
-    report_general_stats_data: list[dict] | None
-    report_data_sources: dict | None
-    report_saved_raw_data: dict[str, dict] | None
+    report_general_stats_data: list[dict] | None = None
+    report_data_sources: dict | None = None
+    report_saved_raw_data: dict[str, dict] | None = None
```
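Two v2 behavior changes drive most of the edits in this file: an annotation like `header: str | None` no longer implies an optional field, so every such field gains an explicit `= None`; and computed fields such as `sample_ids` now need `Field(validate_default=True)` for their validator to fire when the caller omits them. A self-contained sketch of both changes together (the `Report` model is hypothetical, not from the repository):

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Report(BaseModel):
    metrics: list[str]
    # In v2 the annotation alone would make this field required; the
    # explicit None default keeps it optional, as in v1.
    header: str | None = None
    # validate_default=True makes the validator below run on the None
    # default, mirroring v1's always=True.
    sample_ids: Annotated[set[str] | None, Field(validate_default=True)] = None

    @field_validator("sample_ids")
    @classmethod
    def set_sample_ids(cls, _, info: ValidationInfo) -> set[str]:
        # Derive the value from an earlier field, as set_sample_ids does
        # in the real module.
        return set(info.data["metrics"])


print(Report(metrics=["m1", "m2"]).sample_ids)  # {'m1', 'm2'}
```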
44 changes: 25 additions & 19 deletions cg/models/mip/mip_config.py
```diff
@@ -1,6 +1,8 @@
 """Model MIP config"""
 
-from pydantic.v1 import BaseModel, EmailStr, Field, validator
+from typing import Annotated
+
+from pydantic import BaseModel, EmailStr, Field, ValidationInfo, field_validator
 
 from cg.constants.priority import SlurmQos
 
@@ -13,29 +15,33 @@ class AnalysisType(BaseModel):
 class MipBaseConfig(BaseModel):
     """This model is used when validating the mip analysis config"""
 
-    family_id_: str = Field(None, alias="family_id")
-    case_id: str = None
-    analysis_type_: dict = Field(..., alias="analysis_type")
-    samples: list[AnalysisType] = None
-    config_path: str = Field(..., alias="config_file_analysis")
-    deliverables_file_path: str = Field(..., alias="store_file")
+    family_id_: str | None = Field(None, alias="family_id")
+    case_id: Annotated[str | None, Field(validate_default=True)] = None
+    analysis_type_: dict = Field(dict, alias="analysis_type")
+    samples: Annotated[list[AnalysisType] | None, Field(validate_default=True)] = None
+    config_path: str = Field(str, alias="config_file_analysis")
+    deliverables_file_path: str = Field(str, alias="store_file")
     email: EmailStr
     is_dry_run: bool = Field(False, alias="dry_run_all")
-    log_path: str = Field(..., alias="log_file")
-    out_dir: str = Field(..., alias="outdata_dir")
-    priority: SlurmQos = Field(..., alias="slurm_quality_of_service")
-    sample_info_path: str = Field(..., alias="sample_info_file")
+    log_path: str = Field(str, alias="log_file")
+    out_dir: str = Field(str, alias="outdata_dir")
+    priority: SlurmQos = Field(SlurmQos, alias="slurm_quality_of_service")
+    sample_info_path: str = Field(str, alias="sample_info_file")
     sample_ids: list[str]
 
-    @validator("case_id", always=True, pre=True)
-    def set_case_id(cls, value, values: dict) -> str:
-        """Set case_id. Family_id is used for older versions of MIP analysis"""
-        return value or values.get("family_id_")
-
-    @validator("samples", always=True, pre=True)
-    def set_samples(cls, _, values: dict) -> list[AnalysisType]:
+    @field_validator("case_id")
+    @classmethod
+    def set_case_id(cls, value: str, info: ValidationInfo) -> str:
+        """Set case id. Family id is used for older versions of MIP analysis"""
+        return value or info.data.get("family_id_")
+
+    @field_validator(
+        "samples",
+    )
+    @classmethod
+    def set_samples(cls, _, info: ValidationInfo) -> list[AnalysisType]:
         """Set samples analysis type"""
-        raw_samples: dict = values.get("analysis_type_")
+        raw_samples: dict = info.data.get("analysis_type_")
         return [
             AnalysisType(sample_id=sample_id, analysis_type=analysis_type)
             for sample_id, analysis_type in raw_samples.items()
```
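A subtlety worth noting here: in v2, `info.data` only contains fields declared *before* the one being validated, so `family_id_` and `analysis_type_` must precede `case_id` and `samples` for the fall-back logic to see them. (The v1 validators also ran with `pre=True`; the ports rely on the default "after" mode, which suffices since the values involved are plain strings and dicts.) A made-up sketch of the ordering rule, not the repository's `MipBaseConfig`:

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class ConfigSketch(BaseModel):
    # Declared first, so it is already validated and present in info.data
    # when the case_id validator runs.
    family_id: str | None = None
    case_id: Annotated[str | None, Field(validate_default=True)] = None

    @field_validator("case_id")
    @classmethod
    def fall_back_to_family_id(cls, value: str | None, info: ValidationInfo) -> str | None:
        # Same shape as set_case_id above: prefer the new key, fall back
        # to the legacy one.
        return value or info.data.get("family_id")


print(ConfigSketch(family_id="fam1").case_id)  # fam1
```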
76 changes: 43 additions & 33 deletions cg/models/mip/mip_metrics_deliverables.py
```diff
@@ -1,6 +1,6 @@
-from typing import Any
+from typing import Annotated, Any
 
-from pydantic.v1 import validator
+from pydantic import Field, ValidationInfo, field_validator
 
 from cg.constants.subject import Sex
 from cg.models.deliverables.metric_deliverables import (
@@ -27,14 +27,15 @@ class DuplicateReads(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_duplicate_read(cls, value) -> float:
-        """Convert raw value from fraction to percent"""
+        """Convert raw value from fraction to percent."""
         return value * 100
 
 
 class SexCheck(SampleMetric):
-    """Definition of sex check metric"""
+    """Definition of sex check metric."""
 
     value: str
 
@@ -44,7 +45,8 @@ class MIPMappedReads(SampleMetric):
 
     value: float
 
-    @validator("value", always=True)
+    @field_validator("value")
+    @classmethod
     def convert_mapped_read(cls, value) -> float:
         """Convert raw value from fraction to percent"""
         return value * 100
@@ -70,34 +72,38 @@ class MIPMetricsDeliverables(MetricsDeliverables):
         "MEDIAN_TARGET_COVERAGE": MedianTargetCoverage,
         "gender": SexCheck,
     }
-    duplicate_reads: list[DuplicateReads] | None
-    mapped_reads: list[MIPMappedReads] | None
-    mean_insert_size: list[MeanInsertSize] | None
-    median_target_coverage: list[MedianTargetCoverage] | None
-    predicted_sex: list[SexCheck] | None
+    duplicate_reads: Annotated[list[DuplicateReads] | None, Field(validate_default=True)] = None
+    mapped_reads: Annotated[list[MIPMappedReads] | None, Field(validate_default=True)] = None
+    mean_insert_size: Annotated[list[MeanInsertSize] | None, Field(validate_default=True)] = None
+    median_target_coverage: Annotated[
+        list[MedianTargetCoverage] | None, Field(validate_default=True)
+    ] = None
+    predicted_sex: Annotated[list[SexCheck] | None, Field(validate_default=True)] = None
     sample_metric_to_parse: list[str] = SAMPLE_METRICS_TO_PARSE
-    sample_id_metrics: list[MIPParsedMetrics] | None
+    sample_id_metrics: Annotated[list[MIPParsedMetrics] | None, Field(validate_default=True)] = None
 
-    @validator("duplicate_reads", always=True)
-    def set_duplicate_reads(cls, _, values: dict) -> list[DuplicateReads]:
+    @field_validator("duplicate_reads")
+    @classmethod
+    def set_duplicate_reads(cls, _, info: ValidationInfo) -> list[DuplicateReads]:
         """Set duplicate_reads"""
-        return add_metric(name="fraction_duplicates", values=values)
+        return add_metric(name="fraction_duplicates", info=info)
 
-    @validator("mapped_reads", always=True)
-    def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]:
+    @field_validator("mapped_reads")
+    @classmethod
+    def set_mapped_reads(cls, _, info: ValidationInfo) -> list[MIPMappedReads]:
         """Set mapped reads"""
-        sample_ids: set = values.get("sample_ids")
+        sample_ids: set = info.data.get("sample_ids")
         mapped_reads: list = []
         total_sequences: dict = {}
         reads_mapped: dict = {}
-        raw_metrics: list = values.get("metrics_")
+        raw_metrics: list = info.data.get("metrics_")
         metric_step: str = ""
         for metric in raw_metrics:
             if metric.name == "raw_total_sequences":
                 raw_total_sequences = total_sequences.get(metric.id, 0)
                 total_sequences[metric.id] = int(metric.value) + raw_total_sequences
                 metric_step: str = metric.step
-            if metric.name == "reads_mapped":
+            elif metric.name == "reads_mapped":
                 raw_reads_mapped = reads_mapped.get(metric.id, 0)
                 reads_mapped[metric.id] = int(metric.value) + raw_reads_mapped
         for sample_id in sample_ids:
@@ -107,31 +113,35 @@ def set_mapped_reads(cls, _, values: dict) -> list[MIPMappedReads]:
             )
         return mapped_reads
 
-    @validator("mean_insert_size", always=True)
-    def set_mean_insert_size(cls, _, values: dict) -> list[MeanInsertSize]:
+    @field_validator("mean_insert_size")
+    @classmethod
+    def set_mean_insert_size(cls, _, info: ValidationInfo) -> list[MeanInsertSize]:
         """Set mean insert size"""
-        return add_metric(name="MEAN_INSERT_SIZE", values=values)
+        return add_metric(name="MEAN_INSERT_SIZE", info=info)
 
-    @validator("median_target_coverage", always=True)
-    def set_median_target_coverage(cls, _, values: dict) -> list[MedianTargetCoverage]:
+    @field_validator("median_target_coverage")
+    @classmethod
+    def set_median_target_coverage(cls, _, info: ValidationInfo) -> list[MedianTargetCoverage]:
         """Set median target coverage"""
-        return add_metric(name="MEDIAN_TARGET_COVERAGE", values=values)
+        return add_metric(name="MEDIAN_TARGET_COVERAGE", info=info)
 
-    @validator("predicted_sex", always=True)
-    def set_predicted_sex(cls, _, values: dict) -> list[SexCheck]:
+    @field_validator("predicted_sex")
+    @classmethod
+    def set_predicted_sex(cls, _, info: ValidationInfo) -> list[SexCheck]:
         """Set predicted sex"""
-        return add_metric(name="gender", values=values)
+        return add_metric(name="gender", info=info)
 
-    @validator("sample_id_metrics", always=True)
-    def set_sample_id_metrics(cls, _, values: dict) -> list[MIPParsedMetrics]:
+    @field_validator("sample_id_metrics")
+    @classmethod
+    def set_sample_id_metrics(cls, _, info: ValidationInfo) -> list[MIPParsedMetrics]:
         """Set parsed sample_id metrics gathered from all metrics"""
-        return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, values=values)
+        return add_sample_id_metrics(parsed_metric=MIPParsedMetrics, info=info)
 
 
 def get_sample_id_metric(
     sample_id_metrics: list[MIPParsedMetrics], sample_id: str
 ) -> MIPParsedMetrics:
-    """Get parsed metrics for an sample_id"""
+    """Get parsed metrics for a sample id"""
     for sample_id_metric in sample_id_metrics:
         if sample_id == sample_id_metric.sample_id:
             return sample_id_metric
```
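The validators in this file are not really checks; they are derived fields, computed from the raw `metrics_` list that was validated earlier. A stripped-down sketch of that pattern (hypothetical `Metric`/`Deliverables` models, with the filtering inlined instead of delegated to `add_metric`):

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Metric(BaseModel):
    name: str
    value: float


class Deliverables(BaseModel):
    metrics: list[Metric]
    # Derived field: callers never supply it; validate_default=True makes
    # the validator fire on the None default and fill it in.
    duplicate_reads: Annotated[list[float] | None, Field(validate_default=True)] = None

    @field_validator("duplicate_reads")
    @classmethod
    def set_duplicate_reads(cls, _, info: ValidationInfo) -> list[float]:
        # Same filtering shape as add_metric(...) in the real module.
        return [m.value for m in info.data["metrics"] if m.name == "fraction_duplicates"]


print(Deliverables(metrics=[Metric(name="fraction_duplicates", value=0.13)]).duplicate_reads)
```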
10 changes: 3 additions & 7 deletions tests/models/mip/test_mip_analysis.py
```diff
@@ -4,10 +4,8 @@
 
 
 def test_instantiate_mip_analysis(mip_analysis_raw: dict):
-    """
-    Tests raw mip analysis against a pydantic MipAnalysis
-    """
-    # GIVEN a dictionary with the some metrics
+    """Tests raw mip analysis against a pydantic MipAnalysis."""
+    # GIVEN a dictionary with some metrics
 
     # WHEN instantiating a MipAnalysis object
     mip_dna_analysis = MipAnalysis(**mip_analysis_raw)
@@ -22,9 +20,7 @@ def test_instantiate_parse_mip_analysis(
     mip_metrics_deliverables_raw: dict,
     sample_info_dna_raw: dict,
 ):
-    """
-    Tests parse_analysis
-    """
+    """Tests parse_analysis."""
     # GIVEN a dictionary with some metrics and a MIP analysis API
     mip_analysis_api = MipAnalysisAPI(cg_context, Workflow.MIP_DNA)
 
```
(Diffs for the two remaining changed files are not shown.)