Skip to content

Commit

Permalink
feat: reference multiclass model quality (#49)
Browse files Browse the repository at this point in the history
* feat: add confusion matrix

* feat: model quality reference multiclass

* feat: add calc in job

* refactor: improve compose

* fix: handle null
  • Loading branch information
rivamarco committed Jun 28, 2024
1 parent 9bf435a commit e0e68d2
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ services:
stdin_open: true
tty: true
command: -A
ports:
- 4040:4040 # Spark UI port if forwarded from k9s
volumes:
- ./docker/k3s_data/kubeconfig/kubeconfig.yaml:/root/.kube/config

Expand Down
9 changes: 7 additions & 2 deletions spark/jobs/models/reference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ def get_string_indexed_dataframe(self):
self.model.target.name, "classes"
)
prediction_target_df = predictions_df.union(target_df)
indexer = StringIndexer(inputCol="classes", outputCol="classes_index")
indexer = StringIndexer(
inputCol="classes",
outputCol="classes_index",
stringOrderType="alphabetAsc",
handleInvalid="skip",
)
indexer_model = indexer.fit(prediction_target_df)
indexer_prediction = indexer_model.setInputCol(
self.model.outputs.prediction.name
Expand All @@ -115,7 +120,7 @@ def get_string_indexed_dataframe(self):
indexed_target_df = indexer_target.transform(indexed_prediction_df)

index_label_map = {
str(float(index)): label
str(float(index)): str(label)
for index, label in enumerate(indexer_model.labelsArray[0])
}
return index_label_map, indexed_target_df
5 changes: 4 additions & 1 deletion spark/jobs/reference_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,19 @@ def main(
serialize_as_any=True
)
case ModelType.MULTI_CLASS:
# TODO add data quality and model quality
metrics_service = ReferenceMetricsMulticlassService(
reference=reference_dataset
)
statistics = calculate_statistics_reference(reference_dataset)
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality()
complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality).decode(
"utf-8"
)
case ModelType.REGRESSION:
metrics_service = ReferenceMetricsRegressionService(
reference=reference_dataset
Expand Down
86 changes: 85 additions & 1 deletion spark/jobs/utils/reference_multiclass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import List
from typing import List, Dict

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import DataFrame

from metrics.data_quality_calculator import DataQualityCalculator
from models.reference_dataset import ReferenceDataset
Expand All @@ -16,6 +20,86 @@ def __init__(self, reference: ReferenceDataset):
index_label_map, indexed_reference = reference.get_string_indexed_dataframe()
self.index_label_map = index_label_map
self.indexed_reference = indexed_reference
self.model_quality_multiclass_classificator_global = {
"f1": "f1",
"accuracy": "accuracy",
"weightedPrecision": "weighted_precision",
"weightedRecall": "weighted_recall",
"weightedTruePositiveRate": "weighted_true_positive_rate",
"weightedFalsePositiveRate": "weighted_false_positive_rate",
"weightedFMeasure": "weighted_f_measure",
}
self.model_quality_multiclass_classificator_by_label = {
"truePositiveRateByLabel": "true_positive_rate",
"falsePositiveRateByLabel": "false_positive_rate",
"precisionByLabel": "precision",
"recallByLabel": "recall",
"fMeasureByLabel": "f_measure",
}

def __evaluate_multi_class_classification(
    self, dataset: DataFrame, metric_name: str, class_index: float
) -> float:
    """Evaluate one Spark multiclass metric on *dataset*.

    Uses the "-idx" string-indexed prediction/target columns produced by
    ``ReferenceDataset.get_string_indexed_dataframe``. Any Spark failure is
    swallowed and reported as NaN so downstream serialization never breaks.
    """
    try:
        evaluator = MulticlassClassificationEvaluator(
            metricName=metric_name,
            predictionCol=f"{self.reference.model.outputs.prediction.name}-idx",
            labelCol=f"{self.reference.model.target.name}-idx",
            metricLabel=class_index,
        )
        score = evaluator.evaluate(dataset)
    except Exception:
        score = float("nan")
    return score

# FIXME use pydantic struct like data quality
def __calc_multiclass_by_label_metrics(self) -> List[Dict]:
    """Compute the per-class metrics, one entry per indexed label.

    Each entry maps the human-readable class name to a dict of
    metric-label -> value pairs, evaluated at that class's float index.
    """
    per_label_metrics = self.model_quality_multiclass_classificator_by_label
    results = []
    for index, label in self.index_label_map.items():
        # index keys look like "0.0"; Spark wants the numeric class index
        class_index = float(index)
        metrics = {}
        for metric_name, metric_label in per_label_metrics.items():
            metrics[metric_label] = self.__evaluate_multi_class_classification(
                self.indexed_reference, metric_name, class_index
            )
        results.append({"class_name": label, "metrics": metrics})
    return results

def __calc_multiclass_global_metrics(self) -> Dict:
    """Compute the dataset-wide (non-per-class) multiclass metrics.

    A metricLabel of 0.0 is passed for every metric — presumably ignored
    by Spark for these aggregate metrics; confirm against the evaluator docs.
    """
    global_metrics = {}
    for metric_name, metric_label in (
        self.model_quality_multiclass_classificator_global.items()
    ):
        global_metrics[metric_label] = self.__evaluate_multi_class_classification(
            self.indexed_reference, metric_name, 0.0
        )
    return global_metrics

def __calc_confusion_matrix(self):
    """Build the confusion matrix as nested Python lists.

    Selects the string-indexed (prediction, target) column pair and hands
    the resulting RDD to the legacy mllib ``MulticlassMetrics`` calculator.
    """
    prediction_col = f"{self.reference.model.outputs.prediction.name}-idx"
    target_col = f"{self.reference.model.target.name}-idx"
    pairs_rdd = self.indexed_reference.select(prediction_col, target_col).rdd
    calculator = MulticlassMetrics(pairs_rdd)
    return calculator.confusionMatrix().toArray().tolist()

def calculate_model_quality(self) -> Dict:
    """Assemble the full model-quality payload for a multiclass model.

    Returns a dict with the class names, per-class metrics, and global
    metrics (the confusion matrix is folded into the global metrics).
    """
    class_metrics = self.__calc_multiclass_by_label_metrics()
    global_metrics = self.__calc_multiclass_global_metrics()
    global_metrics["confusion_matrix"] = self.__calc_confusion_matrix()
    return {
        "classes": list(self.index_label_map.values()),
        "class_metrics": class_metrics,
        "global_metrics": global_metrics,
    }

def calculate_data_quality_numerical(self) -> List[NumericalFeatureMetrics]:
return DataQualityCalculator.numerical_metrics(
Expand Down
8 changes: 4 additions & 4 deletions spark/tests/models/reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def test_indexer(spark_fixture, dataset_target_string):
index_label_map, indexed_dataset = reference_dataset.get_string_indexed_dataframe()

assert index_label_map == {
"0.0": "HEALTY",
"1.0": "UNHEALTHY",
"2.0": "UNKNOWN",
"3.0": "ORPHAN",
"0.0": "HEALTHY",
"1.0": "ORPHAN",
"2.0": "UNHEALTHY",
"3.0": "UNKNOWN",
}
Loading

0 comments on commit e0e68d2

Please sign in to comment.