Skip to content

Commit

Permalink
fix: define single structure for drift metrics (#67)
Browse files Browse the repository at this point in the history
* feat: define single structure for drift metrics

* fix: remove import
  • Loading branch information
dtria91 authored Jul 2, 2024
1 parent 9cd152f commit aed740b
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 167 deletions.
43 changes: 9 additions & 34 deletions api/app/models/metrics/drift_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from app.models.exceptions import MetricsInternalError
from app.models.job_status import JobStatus
from app.models.model_dto import ModelType


class DriftAlgorithm(str, Enum):
Expand All @@ -29,23 +27,15 @@ class FeatureMetrics(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class BinaryClassDrift(BaseModel):
class Drift(BaseModel):
feature_metrics: List[FeatureMetrics]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MultiClassDrift(BaseModel):
pass


class RegressionDrift(BaseModel):
pass


class DriftDTO(BaseModel):
job_status: JobStatus
drift: Optional[BinaryClassDrift | MultiClassDrift | RegressionDrift]
drift: Optional[Drift]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -55,21 +45,11 @@ class DriftDTO(BaseModel):

@staticmethod
def from_dict(
model_type: ModelType,
job_status: JobStatus,
drift_data: Optional[Dict],
) -> 'DriftDTO':
"""Create a DriftDTO from a dictionary of data."""
if not drift_data:
return DriftDTO(
job_status=job_status,
drift=None,
)

drift = DriftDTO._create_drift(
model_type=model_type,
drift_data=drift_data,
)
drift = DriftDTO._create_drift(drift_data=drift_data)

return DriftDTO(
job_status=job_status,
Expand All @@ -78,14 +58,9 @@ def from_dict(

@staticmethod
def _create_drift(
model_type: ModelType,
drift_data: Dict,
) -> BinaryClassDrift | MultiClassDrift | RegressionDrift:
"""Create a specific drift instance based on the model type."""
if model_type == ModelType.BINARY:
return BinaryClassDrift(**drift_data)
if model_type == ModelType.MULTI_CLASS:
return MultiClassDrift(**drift_data)
if model_type == ModelType.REGRESSION:
return RegressionDrift(**drift_data)
raise MetricsInternalError(f'Invalid model type {model_type}')
drift_data: Optional[Dict],
) -> Optional[Drift]:
"""Create a specific drift instance from a dictionary of data."""
if not drift_data:
return None
return Drift(**drift_data)
6 changes: 0 additions & 6 deletions api/app/services/metrics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,8 @@ def _get_drift_by_model_uuid(
missing_status,
) -> DriftDTO:
"""Retrieve drift for a model by its UUID."""
model = self.model_service.get_model_by_uuid(model_uuid)
dataset, metrics = dataset_and_metrics_getter(model_uuid)
return self._create_drift_dto(
model_type=model.model_type,
dataset=dataset,
metrics=metrics,
missing_status=missing_status,
Expand Down Expand Up @@ -321,26 +319,22 @@ def _create_data_quality_dto(

@staticmethod
def _create_drift_dto(
model_type: ModelType,
dataset: Optional[ReferenceDataset | CurrentDataset],
metrics: Optional[ReferenceDatasetMetrics | CurrentDatasetMetrics],
missing_status,
) -> DriftDTO:
"""Create a DriftDTO from the provided dataset and metrics."""
if not dataset:
return DriftDTO.from_dict(
model_type=model_type,
job_status=missing_status,
drift_data=None,
)
if not metrics:
return DriftDTO.from_dict(
model_type=model_type,
job_status=dataset.status,
drift_data=None,
)
return DriftDTO.from_dict(
model_type=model_type,
job_status=dataset.status,
drift_data=metrics.drift,
)
1 change: 0 additions & 1 deletion api/tests/routes/metrics_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def test_get_current_drift(self):
current_metrics = db_mock.get_sample_current_metrics()
drift = DriftDTO.from_dict(
job_status=JobStatus.SUCCEEDED,
model_type=model.model_type,
drift_data=current_metrics.drift,
)
self.metrics_service.get_current_drift = MagicMock(return_value=drift)
Expand Down
15 changes: 2 additions & 13 deletions api/tests/services/metrics_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,13 @@ def test_get_current_drift(self):
model = db_mock.get_sample_model()
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
current_metrics = db_mock.get_sample_current_metrics()
self.model_service.get_model_by_uuid = MagicMock(return_value=model)
self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
return_value=current_dataset
)
self.current_metrics_dao.get_current_metrics_by_model_uuid = MagicMock(
return_value=current_metrics
)
res = self.metrics_service.get_current_drift(model.uuid, current_dataset.uuid)
self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid)
self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with(
model.uuid, current_dataset.uuid
)
Expand All @@ -362,46 +360,37 @@ def test_get_current_drift(self):

assert res == DriftDTO.from_dict(
job_status=status,
model_type=model.model_type,
drift_data=current_metrics.drift,
)

def test_get_empty_current_drift(self):
status = JobStatus.IMPORTING
model = db_mock.get_sample_model()
current_dataset = db_mock.get_sample_current_dataset(status=status.value)
self.model_service.get_model_by_uuid = MagicMock(return_value=model)
self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
return_value=current_dataset
)
res = self.metrics_service.get_current_drift(model.uuid, current_dataset.uuid)
self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid)
res = self.metrics_service.get_current_drift(model_uuid, current_dataset.uuid)
self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with(
model.uuid, current_dataset.uuid
model_uuid, current_dataset.uuid
)

assert res == DriftDTO.from_dict(
job_status=status,
model_type=model.model_type,
drift_data=None,
)

def test_get_missing_current_drift(self):
status = JobStatus.MISSING_CURRENT
model = db_mock.get_sample_model()
self.model_service.get_model_by_uuid = MagicMock(return_value=model)
self.current_dataset_dao.get_current_dataset_by_model_uuid = MagicMock(
return_value=None
)
res = self.metrics_service.get_current_drift(model_uuid, current_uuid)
self.model_service.get_model_by_uuid.assert_called_once_with(model.uuid)
self.current_dataset_dao.get_current_dataset_by_model_uuid.assert_called_once_with(
model_uuid, current_uuid
)

assert res == DriftDTO.from_dict(
job_status=status,
model_type=model.model_type,
drift_data=None,
)

Expand Down
27 changes: 4 additions & 23 deletions sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
Expand All @@ -17,10 +16,8 @@
JobStatus,
ModelQuality,
ModelType,
MultiClassDrift,
MultiClassificationModelQuality,
RegressionDataQuality,
RegressionDrift,
RegressionModelQuality,
)

Expand Down Expand Up @@ -122,26 +119,10 @@ def __callback(
response_json = response.json()
job_status = JobStatus(response_json['jobStatus'])
if 'drift' in response_json:
match self.__model_type:
case ModelType.BINARY:
return (
job_status,
BinaryClassDrift.model_validate(response_json['drift']),
)
case ModelType.MULTI_CLASS:
return (
job_status,
MultiClassDrift.model_validate(response_json['drift']),
)
case ModelType.REGRESSION:
return (
job_status,
RegressionDrift.model_validate(response_json['drift']),
)
case _:
raise ClientError(
'Unable to parse metrics because of not managed model type'
) from None
return (
job_status,
Drift.model_validate(response_json['drift']),
)
except KeyError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
except ValidationError as e:
Expand Down
6 changes: 0 additions & 6 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
RegressionDataQuality,
)
from .dataset_drift import (
BinaryClassDrift,
Drift,
DriftAlgorithm,
FeatureDrift,
FeatureDriftCalculation,
MultiClassDrift,
RegressionDrift,
)
from .dataset_model_quality import (
BinaryClassificationModelQuality,
Expand Down Expand Up @@ -71,9 +68,6 @@
'FeatureDriftCalculation',
'FeatureDrift',
'Drift',
'BinaryClassDrift',
'MultiClassDrift',
'RegressionDrift',
'ReferenceFileUpload',
'CurrentFileUpload',
'FileReference',
Expand Down
11 changes: 0 additions & 11 deletions sdk/radicalbit_platform_sdk/models/dataset_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,7 @@ class FeatureDrift(BaseModel):


class Drift(BaseModel):
pass


class BinaryClassDrift(Drift):
feature_metrics: List[FeatureDrift]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MultiClassDrift(Drift):
pass


class RegressionDrift(BaseModel):
pass
76 changes: 3 additions & 73 deletions sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@
from radicalbit_platform_sdk.apis import ModelCurrentDataset
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
BinaryClassDrift,
ClassificationDataQuality,
CurrentBinaryClassificationModelQuality,
CurrentFileUpload,
Drift,
DriftAlgorithm,
JobStatus,
ModelType,
MultiClassDrift,
MultiClassificationModelQuality,
RegressionDataQuality,
RegressionDrift,
RegressionModelQuality,
)

Expand Down Expand Up @@ -141,7 +139,7 @@ def test_statistics_key_error(self):
model_current_dataset.statistics()

@responses.activate
def test_binary_class_drift_ok(self):
def test_drift_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
Expand Down Expand Up @@ -185,7 +183,7 @@ def test_binary_class_drift_ok(self):

drift = model_current_dataset.drift()

assert isinstance(drift, BinaryClassDrift)
assert isinstance(drift, Drift)

assert len(drift.feature_metrics) == 3
assert drift.feature_metrics[1].feature_name == 'city'
Expand All @@ -198,74 +196,6 @@ def test_binary_class_drift_ok(self):
assert drift.feature_metrics[2].drift_calc.has_drift is True
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_multi_class_drift_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.MULTI_CLASS,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift',
status=200,
body="""{
"jobStatus": "SUCCEEDED",
"drift": {}
}""",
)

drift = model_current_dataset.drift()

assert isinstance(drift, MultiClassDrift)
# TODO: add asserts to properties
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_regression_drift_ok(self):
base_url = 'http://api:9000'
model_id = uuid.uuid4()
import_uuid = uuid.uuid4()
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
ModelType.REGRESSION,
CurrentFileUpload(
uuid=import_uuid,
path='s3://bucket/file.csv',
date='2014',
correlation_id_column='column',
status=JobStatus.IMPORTING,
),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift',
status=200,
body="""{
"jobStatus": "SUCCEEDED",
"drift": {}
}""",
)

drift = model_current_dataset.drift()

assert isinstance(drift, RegressionDrift)
# TODO: add asserts to properties
assert model_current_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
def test_drift_validation_error(self):
base_url = 'http://api:9000'
Expand Down

0 comments on commit aed740b

Please sign in to comment.