Skip to content

Commit

Permalink
feat: added job for regression model quality (#51)
Browse files Browse the repository at this point in the history
* feat: added job for regression model quality

* style(spark): run ruff format

* fix(spark): avoid dataframe override in metrics calcs

---------

Co-authored-by: lorenzodagostinoradicalbit <[email protected]>
  • Loading branch information
lorenzodagostinoradicalbit and lorenzodagostinoradicalbit committed Jun 28, 2024
1 parent 20e2d6b commit 9bf435a
Show file tree
Hide file tree
Showing 6 changed files with 955 additions and 0 deletions.
88 changes: 88 additions & 0 deletions spark/jobs/metrics/model_quality_regression_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.functions import abs as pyspark_abs

from models.regression_model_quality import RegressionMetricType, ModelQualityRegression
from utils.models import ModelOut
from pyspark.ml.evaluation import RegressionEvaluator


class ModelQualityRegressionCalculator:
    """Computes regression model-quality metrics over a Spark DataFrame.

    mae/mse/rmse/r2/var are delegated to pyspark's RegressionEvaluator;
    mape and adj_r2 are derived here. Every metric is best-effort: any
    evaluation failure yields NaN instead of aborting the job.
    """

    @staticmethod
    def __eval_model_quality_metric(
        model: ModelOut,
        dataframe: DataFrame,
        dataframe_count: int,
        metric_name: RegressionMetricType,
    ) -> float:
        """Evaluate a single regression metric.

        Args:
            model: model definition supplying the target column name, the
                prediction column name and the feature list.
            dataframe: data holding ground-truth and prediction columns.
            dataframe_count: row count of ``dataframe`` (passed in so it is
                not recomputed per metric).
            metric_name: the metric to compute.

        Returns:
            The metric value, or ``float("nan")`` if evaluation raises.
        """
        try:
            match metric_name:
                case RegressionMetricType.ADJ_R2:
                    # Source: https://medium.com/analytics-vidhya/adjusted-r-squared-formula-explanation-1ce033e25699
                    # adj_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
                    # n: number of observations
                    # p: number of independent variables (features)
                    p: float = len(model.features)
                    n: float = dataframe_count
                    # Recurse to get plain R2 from the evaluator branch below.
                    r2: float = (
                        ModelQualityRegressionCalculator.__eval_model_quality_metric(
                            model, dataframe, dataframe_count, RegressionMetricType.R2
                        )
                    )
                    # If n - p - 1 == 0 the ZeroDivisionError is caught below
                    # and reported as NaN.
                    return 1 - (1 - r2) * ((n - 1) / (n - p - 1))
                case RegressionMetricType.MAPE:
                    # Source: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
                    # mape = 100 * mean(abs((actual - predicted) / actual))
                    # Computed as a per-row "mape" column, then averaged.
                    # NOTE(review): a zero/null actual makes Spark's division
                    # yield null, which avg() skips — confirm that silently
                    # dropping such rows is the intended MAPE behavior.
                    _dataframe = dataframe.withColumn(
                        "mape",
                        pyspark_abs(
                            (
                                col(model.outputs.prediction.name)
                                - col(model.target.name)
                            )
                            / col(model.target.name)
                        ),
                    )
                    return _dataframe.agg({"mape": "avg"}).collect()[0][0] * 100
                case (
                    RegressionMetricType.MAE
                    | RegressionMetricType.MSE
                    | RegressionMetricType.RMSE
                    | RegressionMetricType.R2
                    | RegressionMetricType.VAR
                ):
                    # The enum value matches RegressionEvaluator's metricName
                    # spelling, so it can be passed through directly.
                    return RegressionEvaluator(
                        metricName=metric_name.value,
                        labelCol=model.target.name,
                        predictionCol=model.outputs.prediction.name,
                    ).evaluate(dataframe)
        except Exception:
            # Best-effort: bad columns, empty data or a zero denominator all
            # surface as NaN for that single metric.
            return float("nan")

    @staticmethod
    def __calc_mq_metrics(
        model: ModelOut, dataframe: DataFrame, dataframe_count: int
    ) -> ModelQualityRegression:
        """Evaluate every RegressionMetricType and bundle the results.

        Relies on each enum value matching a ModelQualityRegression field name.
        """
        return ModelQualityRegression(
            **{
                metric_name.value: ModelQualityRegressionCalculator.__eval_model_quality_metric(
                    model,
                    dataframe,
                    dataframe_count,
                    metric_name,
                )
                for metric_name in RegressionMetricType
            }
        )

    @staticmethod
    def numerical_metrics(
        model: ModelOut, dataframe: DataFrame, dataframe_count: int
    ) -> ModelQualityRegression:
        """Public entry point: all regression quality metrics for ``dataframe``."""
        # TODO: understand if we should filter out rows with null values in prediction || ground_truth
        # # drop row where prediction or ground_truth is null
        # _dataframe = dataframe.dropna(subset=[model.outputs.prediction.name, model.target.name])
        # _dataframe_count = dataframe.count()
        return ModelQualityRegressionCalculator.__calc_mq_metrics(
            model, dataframe, dataframe_count
        )
23 changes: 23 additions & 0 deletions spark/jobs/models/regression_model_quality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pydantic import BaseModel

from enum import Enum


class RegressionMetricType(str, Enum):
    """Supported regression model-quality metrics.

    Each value doubles as the matching ModelQualityRegression field name;
    for mae/mse/rmse/r2/var it is also the ``metricName`` string accepted by
    pyspark's RegressionEvaluator.
    """

    MAE = "mae"
    MAPE = "mape"
    MSE = "mse"
    RMSE = "rmse"
    R2 = "r2"
    ADJ_R2 = "adj_r2"
    VAR = "var"


class ModelQualityRegression(BaseModel):
    """Container for regression model-quality metric values.

    Field names match RegressionMetricType values; a metric that failed to
    compute is stored as NaN (see ModelQualityRegressionCalculator).
    """

    mae: float  # mean absolute error
    mape: float  # mean absolute percentage error, expressed in percent
    mse: float  # mean squared error
    rmse: float  # root mean squared error
    r2: float  # coefficient of determination
    adj_r2: float  # R2 adjusted for the number of features
    var: float  # explained variance
9 changes: 9 additions & 0 deletions spark/jobs/reference_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from metrics.statistics import calculate_statistics_reference
from models.reference_dataset import ReferenceDataset
from utils.reference_regression import ReferenceMetricsRegressionService
from utils.reference_binary import ReferenceMetricsService
from utils.models import JobStatus, ModelOut, ModelType
from utils.db import update_job_status, write_to_db
Expand Down Expand Up @@ -76,6 +77,14 @@ def main(
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
case ModelType.REGRESSION:
metrics_service = ReferenceMetricsRegressionService(
reference=reference_dataset
)
model_quality = metrics_service.calculate_model_quality()
complete_record["MODEL_QUALITY"] = model_quality.model_dump_json(
serialize_as_any=True
)

schema = StructType(
[
Expand Down
15 changes: 15 additions & 0 deletions spark/jobs/utils/reference_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from models.regression_model_quality import ModelQualityRegression
from models.reference_dataset import ReferenceDataset
from metrics.model_quality_regression_calculator import ModelQualityRegressionCalculator


class ReferenceMetricsRegressionService:
    """Computes model-quality metrics for a regression reference dataset."""

    def __init__(self, reference: ReferenceDataset):
        # Wrapper bundling the model definition, the Spark dataframe and its
        # precomputed row count.
        self.reference = reference

    def calculate_model_quality(self) -> ModelQualityRegression:
        """Delegate metric computation to ModelQualityRegressionCalculator."""
        ref = self.reference
        return ModelQualityRegressionCalculator.numerical_metrics(
            model=ref.model,
            dataframe=ref.reference,
            dataframe_count=ref.reference_count,
        )
88 changes: 88 additions & 0 deletions spark/tests/regression_reference_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import datetime
import uuid

import pytest

from jobs.utils.reference_regression import ReferenceMetricsRegressionService
from jobs.models.reference_dataset import ReferenceDataset
from jobs.utils.models import (
ModelOut,
ModelType,
DataType,
OutputType,
ColumnDefinition,
SupportedTypes,
Granularity,
)
from tests.utils.pytest_utils import my_approx


@pytest.fixture()
def reference_bike(spark_fixture, test_data_dir):
    # Bike-sharing reference data read with a header row; no schema is
    # supplied, so every column comes back as a string (presumably cast
    # downstream per the model's column definitions — confirm).
    yield spark_fixture.read.csv(
        f"{test_data_dir}/reference/regression/reference_bike.csv", header=True
    )


def test_calculation_dataset_target_int(spark_fixture, reference_bike):
    """End-to-end regression model-quality metrics on the bike dataset.

    Builds a regression ModelOut over the bike-sharing reference data, wraps
    it in a ReferenceDataset, and checks every metric produced by
    ReferenceMetricsRegressionService against precomputed expected values.
    """
    # (column name, type) pairs for the model's features, in declaration order.
    feature_specs = [
        ("season", SupportedTypes.int),
        ("yr", SupportedTypes.int),
        ("mnth", SupportedTypes.int),
        ("holiday", SupportedTypes.int),
        ("weekday", SupportedTypes.int),
        ("workingday", SupportedTypes.int),
        ("weathersit", SupportedTypes.float),
        ("temp", SupportedTypes.float),
        ("atemp", SupportedTypes.float),
        ("hum", SupportedTypes.float),
        ("windspeed", SupportedTypes.float),
        ("instant", SupportedTypes.int),
    ]
    prediction_column = ColumnDefinition(
        name="predictions", type=SupportedTypes.float
    )
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="regression model",
        description="description",
        model_type=ModelType.REGRESSION,
        data_type=DataType.TABULAR,
        timestamp=ColumnDefinition(name="dteday", type=SupportedTypes.datetime),
        granularity=Granularity.HOUR,
        outputs=OutputType(
            prediction=prediction_column,
            prediction_proba=None,
            output=[ColumnDefinition(name="predictions", type=SupportedTypes.float)],
        ),
        target=ColumnDefinition(name="ground_truth", type=SupportedTypes.int),
        features=[
            ColumnDefinition(name=name, type=col_type)
            for name, col_type in feature_specs
        ],
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    reference_dataset = ReferenceDataset(raw_dataframe=reference_bike, model=model)
    assert reference_dataset.reference_count == 731

    model_quality_metrics = ReferenceMetricsRegressionService(
        reference=reference_dataset
    ).calculate_model_quality()

    assert model_quality_metrics.model_dump() == my_approx(
        {
            "mae": 126.6230232558139,
            "mape": 33.33458358635063,
            "mse": 42058.59416703146,
            "rmse": 205.08192062449447,
            "r2": 0.9106667318989127,
            "adj_r2": 0.9091736967774461,
            "var": 388091.1367098835,
        }
    )
Loading

0 comments on commit 9bf435a

Please sign in to comment.