From 08c4a0292bb5f1d3a3fd7015116b7f511b9cc0b7 Mon Sep 17 00:00:00 2001 From: Stefano Zamboni Date: Thu, 4 Jul 2024 17:47:02 +0200 Subject: [PATCH] feat: added support to binary and multiclass for predictions --- spark/jobs/models/data_quality.py | 2 + spark/jobs/utils/current_binary.py | 9 +++-- spark/jobs/utils/current_multiclass.py | 9 +++-- spark/jobs/utils/reference_binary.py | 11 ++++-- spark/jobs/utils/reference_multiclass.py | 11 ++++-- spark/tests/binary_current_test.py | 47 ++++++++++++++++++++++++ spark/tests/binary_reference_test.py | 32 ++++++++++++++++ spark/tests/multiclass_current_test.py | 23 ++++++++++++ spark/tests/multiclass_reference_test.py | 18 +++++++++ 9 files changed, 150 insertions(+), 12 deletions(-) diff --git a/spark/jobs/models/data_quality.py b/spark/jobs/models/data_quality.py index 5ee325b6..4eb792c5 100644 --- a/spark/jobs/models/data_quality.py +++ b/spark/jobs/models/data_quality.py @@ -165,12 +165,14 @@ class ClassMetrics(BaseModel): class BinaryClassDataQuality(BaseModel): n_observations: int class_metrics: List[ClassMetrics] + class_metrics_prediction: List[ClassMetrics] feature_metrics: List[FeatureMetrics] class MultiClassDataQuality(BaseModel): n_observations: int class_metrics: List[ClassMetrics] + class_metrics_prediction: List[ClassMetrics] feature_metrics: List[FeatureMetrics] diff --git a/spark/jobs/utils/current_binary.py b/spark/jobs/utils/current_binary.py index 56167e57..2a11e57e 100644 --- a/spark/jobs/utils/current_binary.py +++ b/spark/jobs/utils/current_binary.py @@ -81,9 +81,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]: dataframe_count=self.current.current_count, ) - def calculate_class_metrics(self) -> List[ClassMetrics]: + def calculate_class_metrics(self, column) -> List[ClassMetrics]: metrics = DataQualityCalculator.class_metrics( - class_column=self.current.model.target.name, + class_column=column, dataframe=self.current.current, dataframe_count=self.current.current_count, ) @@ -118,7 +118,10 @@ def calculate_data_quality(self) -> BinaryClassDataQuality: feature_metrics.extend(self.calculate_data_quality_categorical()) return BinaryClassDataQuality( n_observations=self.current.current_count, - class_metrics=self.calculate_class_metrics(), + class_metrics=self.calculate_class_metrics(self.current.model.target.name), + class_metrics_prediction=self.calculate_class_metrics( + self.current.model.outputs.prediction.name + ), feature_metrics=feature_metrics, ) diff --git a/spark/jobs/utils/current_multiclass.py b/spark/jobs/utils/current_multiclass.py index 22738723..d33763ef 100644 --- a/spark/jobs/utils/current_multiclass.py +++ b/spark/jobs/utils/current_multiclass.py @@ -68,9 +68,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]: dataframe_count=self.current.current_count, ) - def calculate_class_metrics(self) -> List[ClassMetrics]: + def calculate_class_metrics(self, column) -> List[ClassMetrics]: return DataQualityCalculator.class_metrics( - class_column=self.current.model.target.name, + class_column=column, dataframe=self.current.current, dataframe_count=self.current.current_count, ) @@ -217,7 +217,10 @@ def calculate_data_quality(self) -> MultiClassDataQuality: feature_metrics.extend(self.calculate_data_quality_categorical()) return MultiClassDataQuality( n_observations=self.current.current_count, - class_metrics=self.calculate_class_metrics(), + class_metrics=self.calculate_class_metrics(self.current.model.target.name), + class_metrics_prediction=self.calculate_class_metrics( + self.current.model.outputs.prediction.name + ), feature_metrics=feature_metrics, ) diff --git a/spark/jobs/utils/reference_binary.py b/spark/jobs/utils/reference_binary.py index 10d557b6..ae356052 100644 --- a/spark/jobs/utils/reference_binary.py +++ b/spark/jobs/utils/reference_binary.py @@ -163,9 +163,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]: dataframe_count=self.reference.reference_count, ) - def calculate_class_metrics(self) -> List[ClassMetrics]: + def calculate_class_metrics(self, column) -> List[ClassMetrics]: metrics = DataQualityCalculator.class_metrics( - class_column=self.reference.model.target.name, + class_column=column, dataframe=self.reference.reference, dataframe_count=self.reference.reference_count, ) @@ -200,6 +200,11 @@ def calculate_data_quality(self) -> BinaryClassDataQuality: feature_metrics.extend(self.calculate_data_quality_categorical()) return BinaryClassDataQuality( n_observations=self.reference.reference_count, - class_metrics=self.calculate_class_metrics(), + class_metrics=self.calculate_class_metrics( + self.reference.model.target.name + ), + class_metrics_prediction=self.calculate_class_metrics( + self.reference.model.outputs.prediction.name + ), feature_metrics=feature_metrics, ) diff --git a/spark/jobs/utils/reference_multiclass.py b/spark/jobs/utils/reference_multiclass.py index a7310213..573ea7d6 100644 --- a/spark/jobs/utils/reference_multiclass.py +++ b/spark/jobs/utils/reference_multiclass.py @@ -144,9 +144,9 @@ def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]: dataframe_count=self.reference.reference_count, ) - def calculate_class_metrics(self) -> List[ClassMetrics]: + def calculate_class_metrics(self, column) -> List[ClassMetrics]: return DataQualityCalculator.class_metrics( - class_column=self.reference.model.target.name, + class_column=column, dataframe=self.reference.reference, dataframe_count=self.reference.reference_count, ) @@ -159,6 +159,11 @@ def calculate_data_quality(self) -> MultiClassDataQuality: feature_metrics.extend(self.calculate_data_quality_categorical()) return MultiClassDataQuality( n_observations=self.reference.reference_count, - class_metrics=self.calculate_class_metrics(), + class_metrics=self.calculate_class_metrics( + self.reference.model.target.name + ), + class_metrics_prediction=self.calculate_class_metrics( + self.reference.model.outputs.prediction.name + ), feature_metrics=feature_metrics, ) diff --git a/spark/tests/binary_current_test.py b/spark/tests/binary_current_test.py index 04e4634c..c97f3368 100644 --- a/spark/tests/binary_current_test.py +++ b/spark/tests/binary_current_test.py @@ -284,6 +284,9 @@ def test_calculation(spark_fixture, dataset): ignore_order=True, significant_digits=6, ) + print( + data_quality.model_dump(serialize_as_any=True, exclude_none=True), + ) assert not deepdiff.DeepDiff( data_quality.model_dump(serialize_as_any=True, exclude_none=True), @@ -293,6 +296,10 @@ def test_calculation(spark_fixture, dataset): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -463,6 +470,10 @@ def test_calculation_current_joined(spark_fixture, current_joined): {"name": "1.0", "count": 131, "percentage": 55.04201680672269}, {"name": "0.0", "count": 107, "percentage": 44.957983193277315}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 133, "percentage": 55.88235294117647}, + {"name": "0.0", "count": 105, "percentage": 44.11764705882353}, + ], "feature_metrics": [ { "feature_name": "age", @@ -855,6 +866,10 @@ def test_calculation_complete(spark_fixture, complete_dataset): {"name": "1.0", "count": 7, "percentage": 100.0}, {"name": "0.0", "count": 0, "percentage": 0.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 7, "percentage": 100.0}, + {"name": "0.0", "count": 0, "percentage": 0.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1003,6 +1018,10 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset): {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, + {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1151,6 +1170,10 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1314,6 +1337,10 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime) {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1477,6 +1504,10 @@ def test_calculation_easy_dataset_bucket_test(spark_fixture, easy_dataset_bucket {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, + {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1782,6 +1813,10 @@ def test_calculation_for_hour(spark_fixture, dataset_for_hour): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -2060,6 +2095,10 @@ def test_calculation_for_day(spark_fixture, dataset_for_day): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -2338,6 +2377,10 @@ def test_calculation_for_week(spark_fixture, dataset_for_week): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -2602,6 +2645,10 @@ def test_calculation_for_month(spark_fixture, dataset_for_month): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", diff --git a/spark/tests/binary_reference_test.py b/spark/tests/binary_reference_test.py index 828cb20c..cbbda105 100644 --- a/spark/tests/binary_reference_test.py +++ b/spark/tests/binary_reference_test.py @@ -169,6 +169,10 @@ def test_calculation(spark_fixture, dataset): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -351,6 +355,10 @@ def test_calculation_reference_joined(spark_fixture, reference_joined): {"name": "1.0", "count": 131, "percentage": 55.04201680672269}, {"name": "0.0", "count": 107, "percentage": 44.957983193277315}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 133, "percentage": 55.88235294117647}, + {"name": "0.0", "count": 105, "percentage": 44.11764705882353}, + ], "feature_metrics": [ { "feature_name": "age", @@ -748,6 +756,10 @@ def test_calculation_complete(spark_fixture, complete_dataset): {"name": "1.0", "count": 7, "percentage": 100.0}, {"name": "0.0", "count": 0, "percentage": 0.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 7, "percentage": 100.0}, + {"name": "0.0", "count": 0, "percentage": 0.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -902,6 +914,10 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset): {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 6, "percentage": 85.71428571428571}, + {"name": "0.0", "count": 1, "percentage": 14.285714285714285}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1056,6 +1072,10 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1233,6 +1253,10 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime) {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1416,6 +1440,10 @@ def test_calculation_enhanced_data(spark_fixture, enhanced_data): {"name": "1.0", "count": 14962, "percentage": 49.87333333333333}, {"name": "0.0", "count": 15038, "percentage": 50.126666666666665}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 29967, "percentage": 99.89}, + {"name": "0.0", "count": 33, "percentage": 0.11}, + ], "feature_metrics": [ { "feature_name": "feature_0", @@ -1956,6 +1984,10 @@ def test_calculation_dataset_bool_missing(spark_fixture, dataset_bool_missing): {"name": "1.0", "count": 6, "percentage": 60.0}, {"name": "0.0", "count": 4, "percentage": 40.0}, ], + "class_metrics_prediction": [ + {"name": "1.0", "count": 5, "percentage": 50.0}, + {"name": "0.0", "count": 5, "percentage": 50.0}, + ], "feature_metrics": [ { "feature_name": "num1", diff --git a/spark/tests/multiclass_current_test.py b/spark/tests/multiclass_current_test.py index ce5a99c1..7578404e 100644 --- a/spark/tests/multiclass_current_test.py +++ b/spark/tests/multiclass_current_test.py @@ -161,6 +161,12 @@ def test_calculation_dataset_target_int(spark_fixture, dataset_target_int): {"name": "2", "count": 3, "percentage": 30.0}, {"name": "0", "count": 3, "percentage": 30.0}, ], + "class_metrics_prediction": [ + {"name": "3", "count": 2, "percentage": 20.0}, + {"name": "0", "count": 2, "percentage": 20.0}, + {"name": "1", "count": 4, "percentage": 40.0}, + {"name": "2", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -474,6 +480,12 @@ def test_calculation_dataset_target_string(spark_fixture, dataset_target_string) {"name": "HEALTHY", "count": 3, "percentage": 30.0}, {"name": "UNKNOWN", "count": 3, "percentage": 30.0}, ], + "class_metrics_prediction": [ + {"name": "ORPHAN", "count": 2, "percentage": 20.0}, + {"name": "UNHEALTHY", "count": 2, "percentage": 20.0}, + {"name": "HEALTHY", "count": 4, "percentage": 40.0}, + {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -787,6 +799,12 @@ def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_clas {"name": "HEALTHY", "count": 4, "percentage": 40.0}, {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, ], + "class_metrics_prediction": [ + {"name": "ORPHAN", "count": 2, "percentage": 20.0}, + {"name": "UNHEALTHY", "count": 2, "percentage": 20.0}, + {"name": "HEALTHY", "count": 4, "percentage": 40.0}, + {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -1073,6 +1091,11 @@ def test_calculation_dataset_for_hour(spark_fixture, dataset_for_hour): {"name": "COW", "count": 3, "percentage": 30.0}, {"name": "CAT", "count": 5, "percentage": 50.0}, ], + "class_metrics_prediction": [ + {"name": "DOG", "count": 3, "percentage": 30.0}, + {"name": "COW", "count": 3, "percentage": 30.0}, + {"name": "CAT", "count": 4, "percentage": 40.0}, + ], "feature_metrics": [ { "feature_name": "num1", diff --git a/spark/tests/multiclass_reference_test.py b/spark/tests/multiclass_reference_test.py index 33ac0783..a7ed373d 100644 --- a/spark/tests/multiclass_reference_test.py +++ b/spark/tests/multiclass_reference_test.py @@ -115,6 +115,12 @@ def test_calculation_dataset_target_int(spark_fixture, dataset_target_int): {"name": "2", "count": 3, "percentage": 30.0}, {"name": "0", "count": 3, "percentage": 30.0}, ], + "class_metrics_prediction": [ + {"name": "3", "count": 2, "percentage": 20.0}, + {"name": "0", "count": 2, "percentage": 20.0}, + {"name": "1", "count": 4, "percentage": 40.0}, + {"name": "2", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -334,6 +340,12 @@ def test_calculation_dataset_target_string(spark_fixture, dataset_target_string) {"name": "HEALTHY", "count": 3, "percentage": 30.0}, {"name": "UNKNOWN", "count": 3, "percentage": 30.0}, ], + "class_metrics_prediction": [ + {"name": "ORPHAN", "count": 2, "percentage": 20.0}, + {"name": "UNHEALTHY", "count": 2, "percentage": 20.0}, + {"name": "HEALTHY", "count": 4, "percentage": 40.0}, + {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1", @@ -553,6 +565,12 @@ def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_clas {"name": "HEALTHY", "count": 4, "percentage": 40.0}, {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, ], + "class_metrics_prediction": [ + {"name": "ORPHAN", "count": 2, "percentage": 20.0}, + {"name": "UNHEALTHY", "count": 2, "percentage": 20.0}, + {"name": "HEALTHY", "count": 4, "percentage": 40.0}, + {"name": "UNKNOWN", "count": 2, "percentage": 20.0}, + ], "feature_metrics": [ { "feature_name": "num1",