Commit
feat: add data quality multiclass current (#56)
rivamarco committed Jul 1, 2024
1 parent cdff426 commit ccb6b79
Showing 13 changed files with 607 additions and 190 deletions.
17 changes: 14 additions & 3 deletions spark/jobs/current_job.py
@@ -10,6 +10,7 @@
 from models.current_dataset import CurrentDataset
 from models.reference_dataset import ReferenceDataset
 from utils.current_binary import CurrentMetricsService
+from utils.current_multiclass import CurrentMetricsMulticlassService
 from utils.models import JobStatus, ModelOut, ModelType
 from utils.db import update_job_status, write_to_db

@@ -54,9 +55,9 @@ def main(
     match model.model_type:
         case ModelType.BINARY:
             metrics_service = CurrentMetricsService(
-                spark_session,
-                current_dataset.current,
-                reference_dataset.reference,
+                spark_session=spark_session,
+                current=current_dataset.current,
+                reference=reference_dataset.reference,
                 model=model,
             )
             statistics = calculate_statistics_current(current_dataset)
@@ -76,10 +77,20 @@
             )
             complete_record["DRIFT"] = orjson.dumps(drift).decode("utf-8")
         case ModelType.MULTI_CLASS:
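+            # Unlike the binary branch above, this branch computes only
+            # statistics and data quality (no model quality or drift records)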
+            metrics_service = CurrentMetricsMulticlassService(
+                spark_session=spark_session,
+                current=current_dataset.current,
+                reference=reference_dataset.reference,
+                model=model,
+            )
             statistics = calculate_statistics_current(current_dataset)
+            data_quality = metrics_service.calculate_data_quality()
             complete_record["STATISTICS"] = statistics.model_dump_json(
                 serialize_as_any=True
             )
+            complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
+                serialize_as_any=True
+            )
 
     schema = StructType(
         [
187 changes: 186 additions & 1 deletion spark/jobs/metrics/data_quality_calculator.py
@@ -1,7 +1,12 @@
-from typing import List
+from math import inf
+from typing import List, Dict
 
+import numpy as np
+import pyspark.sql.functions as F
 from pandas import DataFrame
+from pyspark.ml.feature import Bucketizer
+from pyspark.sql import SparkSession
+from pyspark.sql.types import IntegerType
 
 from models.data_quality import (
     NumericalFeatureMetrics,
@@ -198,3 +203,183 @@ def class_metrics(
             )
             for label, metrics in class_metrics_dict.items()
         ]
+
+    @staticmethod
+    def calculate_combined_data_quality_numerical(
+        model: ModelOut,
+        current_dataframe: DataFrame,
+        current_count: int,
+        reference_dataframe: DataFrame,
+        spark_session: SparkSession,
+    ) -> List[NumericalFeatureMetrics]:
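+        # Nested helper: builds one histogram per numerical feature over the
+        # union of current and reference, so both share the same bucket edges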
+        def calculate_combined_histogram(
+            model: ModelOut,
+            current_dataframe: DataFrame,
+            reference_dataframe: DataFrame,
+            spark_session: SparkSession,
+        ) -> Dict[str, Histogram]:
+            numerical_features = [
+                numerical.name for numerical in model.get_numerical_features()
+            ]
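+            # Tag each row with its origin so the union can be split apart
+            # again after bucketizing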
+            current = current_dataframe.withColumn("type", F.lit("current"))
+            reference = reference_dataframe.withColumn("type", F.lit("reference"))
+
+            def create_histogram(feature: str):
+                reference_and_current = current.select([feature, "type"]).unionByName(
+                    reference.select([feature, "type"])
+                )
+
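+                # Shared min/max over the combined data (nulls and NaNs
+                # excluded) so both histograms get identical bucket edges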
+                max_value = reference_and_current.agg(
+                    F.max(
+                        F.when(
+                            F.col(feature).isNotNull() & ~F.isnan(feature),
+                            F.col(feature),
+                        )
+                    )
+                ).collect()[0][0]
+                min_value = reference_and_current.agg(
+                    F.min(
+                        F.when(
+                            F.col(feature).isNotNull() & ~F.isnan(feature),
+                            F.col(feature),
+                        )
+                    )
+                ).collect()[0][0]
+
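+                # Ten equal-width buckets (11 edges); deduplicate the edges
+                # while preserving their order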
+                buckets_spacing = np.linspace(min_value, max_value, 11).tolist()
+                lookup = set()
+                generated_buckets = [
+                    x
+                    for x in buckets_spacing
+                    if x not in lookup and lookup.add(x) is None
+                ]
+                # Workaround: if every value is identical, linspace collapses to
+                # a single edge, so fall back to (-inf, value, +inf) splits to
+                # keep the Bucketizer from failing
+                if len(generated_buckets) == 1:
+                    buckets_spacing = [generated_buckets[0], generated_buckets[0]]
+                    buckets = [-float(inf), generated_buckets[0], float(inf)]
+                else:
+                    buckets = generated_buckets
+
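+                # handleInvalid="keep" routes nulls/NaNs into an extra bucket
+                # instead of raising an error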
+                bucketizer = Bucketizer(
+                    splits=buckets, inputCol=feature, outputCol="bucket"
+                )
+                result = bucketizer.setHandleInvalid("keep").transform(
+                    reference_and_current
+                )
+
+                current_df = (
+                    result.filter(F.col("type") == "current")
+                    .groupBy("bucket")
+                    .agg(F.count(F.col(feature)).alias("curr_count"))
+                )
+                reference_df = (
+                    result.filter(F.col("type") == "reference")
+                    .groupBy("bucket")
+                    .agg(F.count(F.col(feature)).alias("ref_count"))
+                )
+
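+                # Left-join the counts onto the full 0-9 bucket range so that
+                # empty buckets show up as zeros rather than missing rows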
+                buckets_number = list(range(10))
+                bucket_df = spark_session.createDataFrame(
+                    buckets_number, IntegerType()
+                ).withColumnRenamed("value", "bucket")
+                tot_df = (
+                    bucket_df.join(current_df, on=["bucket"], how="left")
+                    .join(reference_df, on=["bucket"], how="left")
+                    .fillna(0)
+                    .orderBy("bucket")
+                )
+                # Same single-value workaround: with (-inf, value, +inf) splits
+                # every value lands in bucket 1, so keep only that row
+                if len(generated_buckets) == 1:
+                    tot_df = tot_df.filter(F.col("bucket") == 1)
+                cur = tot_df.select("curr_count").rdd.flatMap(lambda x: x).collect()
+                ref = tot_df.select("ref_count").rdd.flatMap(lambda x: x).collect()
+                return Histogram(
+                    buckets=buckets_spacing, reference_values=ref, current_values=cur
+                )
+
+            return {
+                feature: create_histogram(feature) for feature in numerical_features
+            }
+
+        numerical_features = [
+            numerical.name for numerical in model.get_numerical_features()
+        ]
+
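+        # One aliased aggregate expression per feature and statistic; they
+        # all run in a single .agg() pass further below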
+        mean_agg = [
+            (F.mean(check_not_null(x))).alias(f"{x}-mean") for x in numerical_features
+        ]
+
+        max_agg = [
+            (F.max(check_not_null(x))).alias(f"{x}-max") for x in numerical_features
+        ]
+
+        min_agg = [
+            (F.min(check_not_null(x))).alias(f"{x}-min") for x in numerical_features
+        ]
+
+        median_agg = [
+            (F.median(check_not_null(x))).alias(f"{x}-median")
+            for x in numerical_features
+        ]
+
+        perc_25_agg = [
+            (F.percentile(check_not_null(x), 0.25)).alias(f"{x}-perc_25")
+            for x in numerical_features
+        ]
+
+        perc_75_agg = [
+            (F.percentile(check_not_null(x), 0.75)).alias(f"{x}-perc_75")
+            for x in numerical_features
+        ]
+
+        std_agg = [
+            (F.std(check_not_null(x))).alias(f"{x}-std") for x in numerical_features
+        ]
+
+        missing_values_agg = [
+            (F.count(F.when(F.col(x).isNull() | F.isnan(x), x))).alias(
+                f"{x}-missing_values"
+            )
+            for x in numerical_features
+        ]
+
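+        # Missing-value percentages are computed against the current row
+        # count supplied by the caller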
+        missing_values_perc_agg = [
+            (
+                (F.count(F.when(F.col(x).isNull() | F.isnan(x), x)) / current_count)
+                * 100
+            ).alias(f"{x}-missing_values_perc")
+            for x in numerical_features
+        ]
+
+        # Global
+        global_stat = current_dataframe.select(numerical_features).agg(
+            *(
+                mean_agg
+                + max_agg
+                + min_agg
+                + median_agg
+                + perc_25_agg
+                + perc_75_agg
+                + std_agg
+                + missing_values_agg
+                + missing_values_perc_agg
+            )
+        )
+
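+        # Pull the one-row aggregate into pandas; split_dict (an existing
+        # module helper) regroups the "feature-stat" aliases per feature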
+        global_dict = global_stat.toPandas().iloc[0].to_dict()
+        global_data_quality = split_dict(global_dict)
+
+        numerical_features_histogram = calculate_combined_histogram(
+            model, current_dataframe, reference_dataframe, spark_session
+        )
+
+        numerical_features_metrics = [
+            NumericalFeatureMetrics.from_dict(
+                feature_name,
+                metrics,
+                histogram=numerical_features_histogram.get(feature_name),
+            )
+            for feature_name, metrics in global_data_quality.items()
+        ]
+
+        return numerical_features_metrics