Skip to content

Commit

Permalink
feat: added prediction class-metrics support for binary and multiclass models
Browse files Browse the repository at this point in the history
  • Loading branch information
SteZamboni committed Jul 4, 2024
1 parent 02659e7 commit 08c4a02
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 12 deletions.
2 changes: 2 additions & 0 deletions spark/jobs/models/data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,14 @@ class ClassMetrics(BaseModel):
class BinaryClassDataQuality(BaseModel):
    """Data-quality report for a binary classification model.

    Aggregates per-class distributions for both the ground-truth target
    column and the prediction column, alongside per-feature statistics.
    """

    # Number of rows in the profiled dataframe (set from the dataset's
    # row count by the callers in current_binary.py / reference_binary.py).
    n_observations: int
    # Per-class count/percentage computed on the target column
    # (callers pass model.target.name).
    class_metrics: List[ClassMetrics]
    # Per-class count/percentage computed on the prediction column
    # (callers pass model.outputs.prediction.name).
    class_metrics_prediction: List[ClassMetrics]
    # Per-feature data-quality statistics (numerical and categorical).
    feature_metrics: List[FeatureMetrics]


class MultiClassDataQuality(BaseModel):
    """Data-quality report for a multiclass classification model.

    Aggregates per-class distributions for both the ground-truth target
    column and the prediction column, alongside per-feature statistics.
    """

    # Number of rows in the profiled dataframe (set from the dataset's
    # row count by the callers in current_multiclass.py /
    # reference_multiclass.py).
    n_observations: int
    # Per-class count/percentage computed on the target column
    # (callers pass model.target.name).
    class_metrics: List[ClassMetrics]
    # Per-class count/percentage computed on the prediction column
    # (callers pass model.outputs.prediction.name).
    class_metrics_prediction: List[ClassMetrics]
    # Per-feature data-quality statistics (numerical and categorical).
    feature_metrics: List[FeatureMetrics]


Expand Down
9 changes: 6 additions & 3 deletions spark/jobs/utils/current_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]:
dataframe_count=self.current.current_count,
)

def calculate_class_metrics(self) -> List[ClassMetrics]:
def calculate_class_metrics(self, column) -> List[ClassMetrics]:
metrics = DataQualityCalculator.class_metrics(
class_column=self.current.model.target.name,
class_column=column,
dataframe=self.current.current,
dataframe_count=self.current.current_count,
)
Expand Down Expand Up @@ -118,7 +118,10 @@ def calculate_data_quality(self) -> BinaryClassDataQuality:
feature_metrics.extend(self.calculate_data_quality_categorical())
return BinaryClassDataQuality(
n_observations=self.current.current_count,
class_metrics=self.calculate_class_metrics(),
class_metrics=self.calculate_class_metrics(self.current.model.target.name),
class_metrics_prediction=self.calculate_class_metrics(
self.current.model.outputs.prediction.name
),
feature_metrics=feature_metrics,
)

Expand Down
9 changes: 6 additions & 3 deletions spark/jobs/utils/current_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]:
dataframe_count=self.current.current_count,
)

def calculate_class_metrics(self) -> List[ClassMetrics]:
def calculate_class_metrics(self, column) -> List[ClassMetrics]:
return DataQualityCalculator.class_metrics(
class_column=self.current.model.target.name,
class_column=column,
dataframe=self.current.current,
dataframe_count=self.current.current_count,
)
Expand Down Expand Up @@ -217,7 +217,10 @@ def calculate_data_quality(self) -> MultiClassDataQuality:
feature_metrics.extend(self.calculate_data_quality_categorical())
return MultiClassDataQuality(
n_observations=self.current.current_count,
class_metrics=self.calculate_class_metrics(),
class_metrics=self.calculate_class_metrics(self.current.model.target.name),
class_metrics_prediction=self.calculate_class_metrics(
self.current.model.outputs.prediction.name
),
feature_metrics=feature_metrics,
)

Expand Down
11 changes: 8 additions & 3 deletions spark/jobs/utils/reference_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]:
dataframe_count=self.reference.reference_count,
)

def calculate_class_metrics(self) -> List[ClassMetrics]:
def calculate_class_metrics(self, column) -> List[ClassMetrics]:
metrics = DataQualityCalculator.class_metrics(
class_column=self.reference.model.target.name,
class_column=column,
dataframe=self.reference.reference,
dataframe_count=self.reference.reference_count,
)
Expand Down Expand Up @@ -200,6 +200,11 @@ def calculate_data_quality(self) -> BinaryClassDataQuality:
feature_metrics.extend(self.calculate_data_quality_categorical())
return BinaryClassDataQuality(
n_observations=self.reference.reference_count,
class_metrics=self.calculate_class_metrics(),
class_metrics=self.calculate_class_metrics(
self.reference.model.target.name
),
class_metrics_prediction=self.calculate_class_metrics(
self.reference.model.outputs.prediction.name
),
feature_metrics=feature_metrics,
)
11 changes: 8 additions & 3 deletions spark/jobs/utils/reference_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]:
dataframe_count=self.reference.reference_count,
)

def calculate_class_metrics(self) -> List[ClassMetrics]:
def calculate_class_metrics(self, column) -> List[ClassMetrics]:
return DataQualityCalculator.class_metrics(
class_column=self.reference.model.target.name,
class_column=column,
dataframe=self.reference.reference,
dataframe_count=self.reference.reference_count,
)
Expand All @@ -159,6 +159,11 @@ def calculate_data_quality(self) -> MultiClassDataQuality:
feature_metrics.extend(self.calculate_data_quality_categorical())
return MultiClassDataQuality(
n_observations=self.reference.reference_count,
class_metrics=self.calculate_class_metrics(),
class_metrics=self.calculate_class_metrics(
self.reference.model.target.name
),
class_metrics_prediction=self.calculate_class_metrics(
self.reference.model.outputs.prediction.name
),
feature_metrics=feature_metrics,
)
47 changes: 47 additions & 0 deletions spark/tests/binary_current_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def test_calculation(spark_fixture, dataset):
ignore_order=True,
significant_digits=6,
)
print(
data_quality.model_dump(serialize_as_any=True, exclude_none=True),
)

assert not deepdiff.DeepDiff(
data_quality.model_dump(serialize_as_any=True, exclude_none=True),
Expand All @@ -293,6 +296,10 @@ def test_calculation(spark_fixture, dataset):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -463,6 +470,10 @@ def test_calculation_current_joined(spark_fixture, current_joined):
{"name": "1.0", "count": 131, "percentage": 55.04201680672269},
{"name": "0.0", "count": 107, "percentage": 44.957983193277315},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 133, "percentage": 55.88235294117647},
{"name": "0.0", "count": 105, "percentage": 44.11764705882353},
],
"feature_metrics": [
{
"feature_name": "age",
Expand Down Expand Up @@ -855,6 +866,10 @@ def test_calculation_complete(spark_fixture, complete_dataset):
{"name": "1.0", "count": 7, "percentage": 100.0},
{"name": "0.0", "count": 0, "percentage": 0.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 7, "percentage": 100.0},
{"name": "0.0", "count": 0, "percentage": 0.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1003,6 +1018,10 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset):
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1151,6 +1170,10 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1314,6 +1337,10 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime)
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1477,6 +1504,10 @@ def test_calculation_easy_dataset_bucket_test(spark_fixture, easy_dataset_bucket
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1782,6 +1813,10 @@ def test_calculation_for_hour(spark_fixture, dataset_for_hour):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -2060,6 +2095,10 @@ def test_calculation_for_day(spark_fixture, dataset_for_day):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -2338,6 +2377,10 @@ def test_calculation_for_week(spark_fixture, dataset_for_week):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -2602,6 +2645,10 @@ def test_calculation_for_month(spark_fixture, dataset_for_month):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down
32 changes: 32 additions & 0 deletions spark/tests/binary_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def test_calculation(spark_fixture, dataset):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -351,6 +355,10 @@ def test_calculation_reference_joined(spark_fixture, reference_joined):
{"name": "1.0", "count": 131, "percentage": 55.04201680672269},
{"name": "0.0", "count": 107, "percentage": 44.957983193277315},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 133, "percentage": 55.88235294117647},
{"name": "0.0", "count": 105, "percentage": 44.11764705882353},
],
"feature_metrics": [
{
"feature_name": "age",
Expand Down Expand Up @@ -748,6 +756,10 @@ def test_calculation_complete(spark_fixture, complete_dataset):
{"name": "1.0", "count": 7, "percentage": 100.0},
{"name": "0.0", "count": 0, "percentage": 0.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 7, "percentage": 100.0},
{"name": "0.0", "count": 0, "percentage": 0.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -902,6 +914,10 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset):
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 6, "percentage": 85.71428571428571},
{"name": "0.0", "count": 1, "percentage": 14.285714285714285},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1056,6 +1072,10 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1233,6 +1253,10 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime)
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1416,6 +1440,10 @@ def test_calculation_enhanced_data(spark_fixture, enhanced_data):
{"name": "1.0", "count": 14962, "percentage": 49.87333333333333},
{"name": "0.0", "count": 15038, "percentage": 50.126666666666665},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 29967, "percentage": 99.89},
{"name": "0.0", "count": 33, "percentage": 0.11},
],
"feature_metrics": [
{
"feature_name": "feature_0",
Expand Down Expand Up @@ -1956,6 +1984,10 @@ def test_calculation_dataset_bool_missing(spark_fixture, dataset_bool_missing):
{"name": "1.0", "count": 6, "percentage": 60.0},
{"name": "0.0", "count": 4, "percentage": 40.0},
],
"class_metrics_prediction": [
{"name": "1.0", "count": 5, "percentage": 50.0},
{"name": "0.0", "count": 5, "percentage": 50.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down
23 changes: 23 additions & 0 deletions spark/tests/multiclass_current_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def test_calculation_dataset_target_int(spark_fixture, dataset_target_int):
{"name": "2", "count": 3, "percentage": 30.0},
{"name": "0", "count": 3, "percentage": 30.0},
],
"class_metrics_prediction": [
{"name": "3", "count": 2, "percentage": 20.0},
{"name": "0", "count": 2, "percentage": 20.0},
{"name": "1", "count": 4, "percentage": 40.0},
{"name": "2", "count": 2, "percentage": 20.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -474,6 +480,12 @@ def test_calculation_dataset_target_string(spark_fixture, dataset_target_string)
{"name": "HEALTHY", "count": 3, "percentage": 30.0},
{"name": "UNKNOWN", "count": 3, "percentage": 30.0},
],
"class_metrics_prediction": [
{"name": "ORPHAN", "count": 2, "percentage": 20.0},
{"name": "UNHEALTHY", "count": 2, "percentage": 20.0},
{"name": "HEALTHY", "count": 4, "percentage": 40.0},
{"name": "UNKNOWN", "count": 2, "percentage": 20.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -787,6 +799,12 @@ def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_clas
{"name": "HEALTHY", "count": 4, "percentage": 40.0},
{"name": "UNKNOWN", "count": 2, "percentage": 20.0},
],
"class_metrics_prediction": [
{"name": "ORPHAN", "count": 2, "percentage": 20.0},
{"name": "UNHEALTHY", "count": 2, "percentage": 20.0},
{"name": "HEALTHY", "count": 4, "percentage": 40.0},
{"name": "UNKNOWN", "count": 2, "percentage": 20.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down Expand Up @@ -1073,6 +1091,11 @@ def test_calculation_dataset_for_hour(spark_fixture, dataset_for_hour):
{"name": "COW", "count": 3, "percentage": 30.0},
{"name": "CAT", "count": 5, "percentage": 50.0},
],
"class_metrics_prediction": [
{"name": "DOG", "count": 3, "percentage": 30.0},
{"name": "COW", "count": 3, "percentage": 30.0},
{"name": "CAT", "count": 4, "percentage": 40.0},
],
"feature_metrics": [
{
"feature_name": "num1",
Expand Down
Loading

0 comments on commit 08c4a02

Please sign in to comment.