Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added job for regression model quality #51

Merged
merged 3 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions spark/jobs/metrics/model_quality_regression_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.functions import abs as pyspark_abs

from models.regression_model_quality import RegressionMetricType, ModelQualityRegression
from utils.models import ModelOut
from pyspark.ml.evaluation import RegressionEvaluator


class ModelQualityRegressionCalculator:
    """Computes regression model-quality metrics over a Spark DataFrame.

    Metrics natively supported by pyspark's RegressionEvaluator are delegated
    to it; MAPE and adjusted R2 are computed by hand. Any metric that fails to
    evaluate is reported as NaN instead of raising.
    """

    @staticmethod
    def __eval_model_quality_metric(
        model: ModelOut,
        dataframe: DataFrame,
        dataframe_count: int,
        metric_name: RegressionMetricType,
    ) -> float:
        """Evaluate a single metric on *dataframe*; returns NaN on any failure.

        *dataframe_count* is passed in (rather than recomputed) so callers can
        reuse one count across all metrics.
        """
        try:
            if metric_name == RegressionMetricType.ADJ_R2:
                # Adjusted R2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
                # n: number of observations, p: number of independent
                # variables (features).
                # Source: https://medium.com/analytics-vidhya/adjusted-r-squared-formula-explanation-1ce033e25699
                num_features: float = len(model.features)
                num_rows: float = dataframe_count
                r2: float = (
                    ModelQualityRegressionCalculator.__eval_model_quality_metric(
                        model, dataframe, dataframe_count, RegressionMetricType.R2
                    )
                )
                return 1 - (1 - r2) * (
                    (num_rows - 1) / (num_rows - num_features - 1)
                )
            if metric_name == RegressionMetricType.MAPE:
                # MAPE = 100 * mean(abs(predicted - actual) / actual)
                # Source: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
                predicted = col(model.outputs.prediction.name)
                actual = col(model.target.name)
                with_ape = dataframe.withColumn(
                    "mape", pyspark_abs((predicted - actual) / actual)
                )
                mean_ape = with_ape.agg({"mape": "avg"}).collect()[0][0]
                return mean_ape * 100
            if metric_name in (
                RegressionMetricType.MAE,
                RegressionMetricType.MSE,
                RegressionMetricType.RMSE,
                RegressionMetricType.R2,
                RegressionMetricType.VAR,
            ):
                # These map one-to-one onto RegressionEvaluator metric names.
                evaluator = RegressionEvaluator(
                    metricName=metric_name.value,
                    labelCol=model.target.name,
                    predictionCol=model.outputs.prediction.name,
                )
                return evaluator.evaluate(dataframe)
        except Exception:
            # Best-effort: a metric that cannot be computed is reported as NaN.
            return float("nan")

    @staticmethod
    def __calc_mq_metrics(
        model: ModelOut, dataframe: DataFrame, dataframe_count: int
    ) -> ModelQualityRegression:
        """Evaluate every RegressionMetricType and bundle the results."""
        metric_values: dict = {}
        for metric in RegressionMetricType:
            metric_values[metric.value] = (
                ModelQualityRegressionCalculator.__eval_model_quality_metric(
                    model, dataframe, dataframe_count, metric
                )
            )
        return ModelQualityRegression(**metric_values)

    @staticmethod
    def numerical_metrics(
        model: ModelOut, dataframe: DataFrame, dataframe_count: int
    ) -> ModelQualityRegression:
        """Public entry point: compute all regression quality metrics.

        TODO: decide whether rows with null prediction or ground truth should
        be dropped (and the count recomputed) before evaluation.
        """
        return ModelQualityRegressionCalculator.__calc_mq_metrics(
            model, dataframe, dataframe_count
        )
23 changes: 23 additions & 0 deletions spark/jobs/models/regression_model_quality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pydantic import BaseModel

from enum import Enum


class RegressionMetricType(str, Enum):
    """Enumerates the regression model-quality metrics computed by the jobs.

    Each member's string value matches the corresponding field name on
    ModelQualityRegression; for MAE/MSE/RMSE/R2/VAR it is also the
    ``metricName`` passed to pyspark's RegressionEvaluator.
    """

    MAE = "mae"  # mean absolute error
    MAPE = "mape"  # mean absolute percentage error
    MSE = "mse"  # mean squared error
    RMSE = "rmse"  # root mean squared error
    R2 = "r2"  # coefficient of determination
    ADJ_R2 = "adj_r2"  # R2 adjusted for the number of features
    VAR = "var"  # explained variance


class ModelQualityRegression(BaseModel):
    """Container for computed regression model-quality metric values.

    Field names mirror the RegressionMetricType string values. Metrics that
    could not be evaluated are stored as NaN by the calculator.
    """

    mae: float  # mean absolute error
    mape: float  # mean absolute percentage error (0-100 scale)
    mse: float  # mean squared error
    rmse: float  # root mean squared error
    r2: float  # coefficient of determination
    adj_r2: float  # adjusted R2
    var: float  # explained variance
9 changes: 9 additions & 0 deletions spark/jobs/reference_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from metrics.statistics import calculate_statistics_reference
from models.reference_dataset import ReferenceDataset
from utils.reference_regression import ReferenceMetricsRegressionService
from utils.reference_binary import ReferenceMetricsService
from utils.models import JobStatus, ModelOut, ModelType
from utils.db import update_job_status, write_to_db
Expand Down Expand Up @@ -76,6 +77,14 @@ def main(
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
case ModelType.REGRESSION:
metrics_service = ReferenceMetricsRegressionService(
reference=reference_dataset
)
model_quality = metrics_service.calculate_model_quality()
complete_record["MODEL_QUALITY"] = model_quality.model_dump_json(
serialize_as_any=True
)

schema = StructType(
[
Expand Down
15 changes: 15 additions & 0 deletions spark/jobs/utils/reference_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from models.regression_model_quality import ModelQualityRegression
from models.reference_dataset import ReferenceDataset
from metrics.model_quality_regression_calculator import ModelQualityRegressionCalculator


class ReferenceMetricsRegressionService:
    """Facade that computes model-quality metrics for a regression reference dataset."""

    def __init__(self, reference: ReferenceDataset):
        """Keep a handle to the reference dataset the metrics are computed on."""
        self.reference = reference

    def calculate_model_quality(self) -> ModelQualityRegression:
        """Return the full set of regression quality metrics for the reference data."""
        ref = self.reference
        return ModelQualityRegressionCalculator.numerical_metrics(
            model=ref.model,
            dataframe=ref.reference,
            dataframe_count=ref.reference_count,
        )
88 changes: 88 additions & 0 deletions spark/tests/regression_reference_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import datetime
import uuid

import pytest

from jobs.utils.reference_regression import ReferenceMetricsRegressionService
from jobs.models.reference_dataset import ReferenceDataset
from jobs.utils.models import (
ModelOut,
ModelType,
DataType,
OutputType,
ColumnDefinition,
SupportedTypes,
Granularity,
)
from tests.utils.pytest_utils import my_approx


@pytest.fixture()
def reference_bike(spark_fixture, test_data_dir):
    """Load the bike-sharing reference CSV as a Spark DataFrame."""
    csv_path = f"{test_data_dir}/reference/regression/reference_bike.csv"
    yield spark_fixture.read.csv(csv_path, header=True)


def test_calculation_dataset_target_int(spark_fixture, reference_bike):
    """Regression model-quality metrics on the bike dataset with an int target."""
    prediction_col = ColumnDefinition(name="predictions", type=SupportedTypes.float)
    output = OutputType(
        prediction=prediction_col,
        prediction_proba=None,
        output=[prediction_col],
    )
    target = ColumnDefinition(name="ground_truth", type=SupportedTypes.int)
    timestamp = ColumnDefinition(name="dteday", type=SupportedTypes.datetime)
    granularity = Granularity.HOUR
    # (column name, declared type) pairs — order matches the dataset layout.
    feature_specs = [
        ("season", SupportedTypes.int),
        ("yr", SupportedTypes.int),
        ("mnth", SupportedTypes.int),
        ("holiday", SupportedTypes.int),
        ("weekday", SupportedTypes.int),
        ("workingday", SupportedTypes.int),
        ("weathersit", SupportedTypes.float),
        ("temp", SupportedTypes.float),
        ("atemp", SupportedTypes.float),
        ("hum", SupportedTypes.float),
        ("windspeed", SupportedTypes.float),
        ("instant", SupportedTypes.int),
    ]
    features = [
        ColumnDefinition(name=col_name, type=col_type)
        for col_name, col_type in feature_specs
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="regression model",
        description="description",
        model_type=ModelType.REGRESSION,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    reference_dataset = ReferenceDataset(
        raw_dataframe=reference_bike,
        model=model,
    )
    assert reference_dataset.reference_count == 731

    regression_service = ReferenceMetricsRegressionService(reference=reference_dataset)
    model_quality_metrics = regression_service.calculate_model_quality()

    # Values pinned from a known-good run on reference_bike.csv.
    expected = my_approx(
        {
            "mae": 126.6230232558139,
            "mape": 33.33458358635063,
            "mse": 42058.59416703146,
            "rmse": 205.08192062449447,
            "r2": 0.9106667318989127,
            "adj_r2": 0.9091736967774461,
            "var": 388091.1367098835,
        }
    )

    assert model_quality_metrics.model_dump() == expected
Loading