Skip to content

Commit

Permalink
feat(spark): add data quality for current regression
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzodagostinoradicalbit committed Jul 2, 2024
1 parent bfcd273 commit ae7665a
Show file tree
Hide file tree
Showing 5 changed files with 629 additions and 105 deletions.
11 changes: 11 additions & 0 deletions spark/jobs/current_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from metrics.statistics import calculate_statistics_current
from models.current_dataset import CurrentDataset
from models.reference_dataset import ReferenceDataset

from utils.current_binary import CurrentMetricsService
from utils.current_multiclass import CurrentMetricsMulticlassService
from utils.current_regression import CurrentMetricsRegressionService
from utils.models import JobStatus, ModelOut, ModelType
from utils.db import update_job_status, write_to_db

Expand Down Expand Up @@ -96,10 +98,19 @@ def main(
)
complete_record["DRIFT"] = orjson.dumps(drift).decode("utf-8")
case ModelType.REGRESSION:
metrics_service = CurrentMetricsRegressionService(
reference=reference_dataset,
current=current_dataset,
spark_session=spark_session,
)
statistics = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality(is_current=True)
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)

schema = StructType(
[
Expand Down
221 changes: 124 additions & 97 deletions spark/jobs/metrics/data_quality_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,102 +206,96 @@ def class_metrics(
]

@staticmethod
def calculate_combined_data_quality_numerical(
model: ModelOut,
def calculate_combined_histogram(
current_dataframe: DataFrame,
current_count: int,
reference_dataframe: DataFrame,
spark_session: SparkSession,
) -> List[NumericalFeatureMetrics]:
def calculate_combined_histogram(
model: ModelOut,
current_dataframe: DataFrame,
reference_dataframe: DataFrame,
spark_session: SparkSession,
) -> Dict[str, Histogram]:
numerical_features = [
numerical.name for numerical in model.get_numerical_features()
]
current = current_dataframe.withColumn("type", F.lit("current"))
reference = reference_dataframe.withColumn("type", F.lit("reference"))

def create_histogram(feature: str):
reference_and_current = current.select([feature, "type"]).unionByName(
reference.select([feature, "type"])
)
columns: List[str],
) -> Dict[str, Histogram]:
current = current_dataframe.withColumn("type", F.lit("current"))
reference = reference_dataframe.withColumn("type", F.lit("reference"))

def create_histogram(feature: str):
reference_and_current = current.select([feature, "type"]).unionByName(
reference.select([feature, "type"])
)

max_value = reference_and_current.agg(
F.max(
F.when(
F.col(feature).isNotNull() & ~F.isnan(feature),
F.col(feature),
)
)
).collect()[0][0]
min_value = reference_and_current.agg(
F.min(
F.when(
F.col(feature).isNotNull() & ~F.isnan(feature),
F.col(feature),
)
max_value = reference_and_current.agg(
F.max(
F.when(
F.col(feature).isNotNull() & ~F.isnan(feature),
F.col(feature),
)
).collect()[0][0]

buckets_spacing = np.linspace(min_value, max_value, 11).tolist()
lookup = set()
generated_buckets = [
x
for x in buckets_spacing
if x not in lookup and lookup.add(x) is None
]
# workaround if all values are the same to not have errors
if len(generated_buckets) == 1:
buckets_spacing = [generated_buckets[0], generated_buckets[0]]
buckets = [-float(inf), generated_buckets[0], float(inf)]
else:
buckets = generated_buckets

bucketizer = Bucketizer(
splits=buckets, inputCol=feature, outputCol="bucket"
)
result = bucketizer.setHandleInvalid("keep").transform(
reference_and_current
).collect()[0][0]
min_value = reference_and_current.agg(
F.min(
F.when(
F.col(feature).isNotNull() & ~F.isnan(feature),
F.col(feature),
)
)
).collect()[0][0]

current_df = (
result.filter(F.col("type") == "current")
.groupBy("bucket")
.agg(F.count(F.col(feature)).alias("curr_count"))
)
reference_df = (
result.filter(F.col("type") == "reference")
.groupBy("bucket")
.agg(F.count(F.col(feature)).alias("ref_count"))
)
buckets_spacing = np.linspace(min_value, max_value, 11).tolist()
lookup = set()
generated_buckets = [
x for x in buckets_spacing if x not in lookup and lookup.add(x) is None
]
# workaround if all values are the same to not have errors
if len(generated_buckets) == 1:
buckets_spacing = [generated_buckets[0], generated_buckets[0]]
buckets = [-float(inf), generated_buckets[0], float(inf)]
else:
buckets = generated_buckets

bucketizer = Bucketizer(
splits=buckets, inputCol=feature, outputCol="bucket"
)
result = bucketizer.setHandleInvalid("keep").transform(
reference_and_current
)

buckets_number = list(range(10))
bucket_df = spark_session.createDataFrame(
buckets_number, IntegerType()
).withColumnRenamed("value", "bucket")
tot_df = (
bucket_df.join(current_df, on=["bucket"], how="left")
.join(reference_df, on=["bucket"], how="left")
.fillna(0)
.orderBy("bucket")
)
# workaround if all values are the same to not have errors
if len(generated_buckets) == 1:
tot_df = tot_df.filter(F.col("bucket") == 1)
cur = tot_df.select("curr_count").rdd.flatMap(lambda x: x).collect()
ref = tot_df.select("ref_count").rdd.flatMap(lambda x: x).collect()
return Histogram(
buckets=buckets_spacing, reference_values=ref, current_values=cur
)
current_df = (
result.filter(F.col("type") == "current")
.groupBy("bucket")
.agg(F.count(F.col(feature)).alias("curr_count"))
)
reference_df = (
result.filter(F.col("type") == "reference")
.groupBy("bucket")
.agg(F.count(F.col(feature)).alias("ref_count"))
)

buckets_number = list(range(10))
bucket_df = spark_session.createDataFrame(
buckets_number, IntegerType()
).withColumnRenamed("value", "bucket")
tot_df = (
bucket_df.join(current_df, on=["bucket"], how="left")
.join(reference_df, on=["bucket"], how="left")
.fillna(0)
.orderBy("bucket")
)
# workaround if all values are the same to not have errors
if len(generated_buckets) == 1:
tot_df = tot_df.filter(F.col("bucket") == 1)
cur = tot_df.select("curr_count").rdd.flatMap(lambda x: x).collect()
ref = tot_df.select("ref_count").rdd.flatMap(lambda x: x).collect()
return Histogram(
buckets=buckets_spacing, reference_values=ref, current_values=cur
)

return {
feature: create_histogram(feature) for feature in numerical_features
}
return {feature: create_histogram(feature) for feature in columns}

@staticmethod
def calculate_combined_data_quality_numerical(
model: ModelOut,
current_dataframe: DataFrame,
current_count: int,
reference_dataframe: DataFrame,
spark_session: SparkSession,
) -> List[NumericalFeatureMetrics]:
numerical_features = [
numerical.name for numerical in model.get_numerical_features()
]
Expand Down Expand Up @@ -370,8 +364,13 @@ def create_histogram(feature: str):
global_dict = global_stat.toPandas().iloc[0].to_dict()
global_data_quality = split_dict(global_dict)

numerical_features_histogram = calculate_combined_histogram(
model, current_dataframe, reference_dataframe, spark_session
numerical_features_histogram = (
DataQualityCalculator.calculate_combined_histogram(
current_dataframe,
reference_dataframe,
spark_session,
numerical_features,
)
)

numerical_features_metrics = [
Expand All @@ -388,7 +387,44 @@ def create_histogram(feature: str):
def regression_target_metrics(
target_column: str, dataframe: DataFrame, dataframe_count: int
) -> NumericalTargetMetrics:
target_metrics = (
target_metrics = DataQualityCalculator.regression_target_metrics_for_dataframe(
target_column, dataframe, dataframe_count
)

_histogram = (
dataframe.select(target_column).rdd.flatMap(lambda x: x).histogram(10)
)
histogram = Histogram(buckets=_histogram[0], reference_values=_histogram[1])

return NumericalTargetMetrics.from_dict(
target_column, target_metrics, histogram
)

@staticmethod
def regression_target_metrics_current(
    target_column: str,
    curr_df: DataFrame,
    curr_count: int,
    ref_df: DataFrame,
    spark_session: SparkSession,
) -> NumericalTargetMetrics:
    """Compute target-column data-quality metrics for the *current* dataset.

    Summary statistics are computed from the current dataframe alone, while
    the histogram is built over buckets shared between the current and the
    reference dataframes, so the two distributions are directly comparable.

    Args:
        target_column: name of the regression target column.
        curr_df: current dataset dataframe.
        curr_count: row count of ``curr_df`` (precomputed by the caller).
        ref_df: reference dataset dataframe, used only for the combined
            histogram bucketing.
        spark_session: active Spark session, required by the histogram helper.

    Returns:
        NumericalTargetMetrics built from the statistics and the combined
        histogram.
    """
    # Scalar summary metrics come from the current data only.
    target_metrics = DataQualityCalculator.regression_target_metrics_for_dataframe(
        target_column, curr_df, curr_count
    )
    # calculate_combined_histogram returns a dict keyed by column name;
    # we request a single column and pick its histogram out.
    _histogram = DataQualityCalculator.calculate_combined_histogram(
        curr_df, ref_df, spark_session, [target_column]
    )
    histogram = _histogram[target_column]

    return NumericalTargetMetrics.from_dict(
        target_column, target_metrics, histogram
    )

@staticmethod
def regression_target_metrics_for_dataframe(
target_column: str, dataframe: DataFrame, dataframe_count: int
) -> dict:
return (
dataframe.select(target_column)
.filter(F.isnotnull(target_column))
.agg(
Expand Down Expand Up @@ -419,12 +455,3 @@ def regression_target_metrics(
.iloc[0]
.to_dict()
)

_histogram = (
dataframe.select(target_column).rdd.flatMap(lambda x: x).histogram(10)
)
histogram = Histogram(buckets=_histogram[0], reference_values=_histogram[1])

return NumericalTargetMetrics.from_dict(
target_column, target_metrics, histogram
)
4 changes: 0 additions & 4 deletions spark/jobs/metrics/model_quality_regression_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ def __calc_mq_metrics(
def numerical_metrics(
model: ModelOut, dataframe: DataFrame, dataframe_count: int
) -> ModelQualityRegression:
# TODO: understand if we should filter out rows with null values in prediction || ground_truth
# # drop row where prediction or ground_truth is null
# _dataframe = dataframe.dropna(subset=[model.outputs.prediction.name, model.target.name])
# _dataframe_count = dataframe.count()
return ModelQualityRegressionCalculator.__calc_mq_metrics(
model, dataframe, dataframe_count
)
76 changes: 76 additions & 0 deletions spark/jobs/utils/current_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import List, Optional

from pyspark.sql import SparkSession

from metrics.data_quality_calculator import DataQualityCalculator
from models.current_dataset import CurrentDataset
from models.data_quality import (
CategoricalFeatureMetrics,
NumericalFeatureMetrics,
NumericalTargetMetrics,
RegressionDataQuality,
)
from models.reference_dataset import ReferenceDataset


class CurrentMetricsRegressionService:
    """Data-quality metrics service for a regression model's current dataset.

    Wraps ``DataQualityCalculator`` with the datasets held by this service:
    feature metrics combine the current and reference dataframes (shared
    histogram buckets), while target metrics can be computed either on the
    current dataset alone or against the reference (``is_current=True``).
    """

    def __init__(
        self,
        spark_session: SparkSession,
        current: CurrentDataset,
        reference: ReferenceDataset,
    ):
        # spark_session: active session, required by the combined-histogram helpers.
        self.spark_session = spark_session
        # current: dataset under evaluation (dataframe, row count, model metadata).
        self.current = current
        # reference: baseline dataset used for combined histograms.
        self.reference = reference

    def calculate_data_quality_numerical(self) -> List[NumericalFeatureMetrics]:
        """Per-feature metrics for the model's numerical features, with
        histograms bucketed jointly over current and reference data."""
        return DataQualityCalculator.calculate_combined_data_quality_numerical(
            model=self.current.model,
            current_dataframe=self.current.current,
            current_count=self.current.current_count,
            reference_dataframe=self.reference.reference,
            spark_session=self.spark_session,
        )

    def calculate_data_quality_categorical(self) -> List[CategoricalFeatureMetrics]:
        """Per-feature metrics for the model's categorical features,
        computed on the current dataset only."""
        return DataQualityCalculator.categorical_metrics(
            model=self.current.model,
            dataframe=self.current.current,
            dataframe_count=self.current.current_count,
        )

    def calculate_target_metrics(self) -> NumericalTargetMetrics:
        """Target metrics with a histogram built from the current data alone."""
        return DataQualityCalculator.regression_target_metrics(
            target_column=self.current.model.target.name,
            dataframe=self.current.current,
            dataframe_count=self.current.current_count,
        )

    def calculate_current_target_metrics(self) -> NumericalTargetMetrics:
        """Target metrics with a histogram bucketed jointly over the current
        and reference data, for side-by-side comparison."""
        return DataQualityCalculator.regression_target_metrics_current(
            target_column=self.current.model.target.name,
            curr_df=self.current.current,
            curr_count=self.current.current_count,
            ref_df=self.reference.reference,
            spark_session=self.spark_session,
        )

    def calculate_data_quality(
        self, is_current: bool = False
    ) -> RegressionDataQuality:
        """Assemble the full data-quality report for the current dataset.

        Args:
            is_current: when True, target metrics use reference-aware
                (combined-histogram) bucketing; when False, current-only
                bucketing. Never ``None`` — annotated as plain ``bool``
                (the previous ``Optional[bool]`` annotation was misleading).

        Returns:
            RegressionDataQuality with observation count, target metrics and
            per-feature metrics.
        """
        feature_metrics = []
        # NOTE(review): feature presence is checked on the *reference* model
        # while metrics are computed with the *current* model — presumably the
        # two share one model definition; confirm and unify if so.
        if self.reference.model.get_numerical_features():
            feature_metrics.extend(self.calculate_data_quality_numerical())
        if self.reference.model.get_categorical_features():
            feature_metrics.extend(self.calculate_data_quality_categorical())
        target_metrics = (
            self.calculate_target_metrics()
            if not is_current
            else self.calculate_current_target_metrics()
        )
        return RegressionDataQuality(
            n_observations=self.current.current_count,
            target_metrics=target_metrics,
            feature_metrics=feature_metrics,
        )
Loading

0 comments on commit ae7665a

Please sign in to comment.