Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add regression current statistics #72

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions spark/jobs/current_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metrics.statistics import calculate_statistics_current
from models.current_dataset import CurrentDataset
from models.reference_dataset import ReferenceDataset
from utils.reference_regression import ReferenceMetricsRegressionService
from utils.current_binary import CurrentMetricsService
from utils.current_multiclass import CurrentMetricsMulticlassService
from utils.models import JobStatus, ModelOut, ModelType
Expand Down Expand Up @@ -91,6 +92,16 @@ def main(
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
case ModelType.REGRESSION:
metrics_service = ReferenceMetricsRegressionService(
reference=reference_dataset
)
statistics = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)

schema = StructType(
[
Expand Down
92 changes: 92 additions & 0 deletions spark/tests/regression_current_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import datetime
import uuid
import pytest

from metrics.statistics import calculate_statistics_current
from models.current_dataset import CurrentDataset
from utils.models import (
ColumnDefinition,
DataType,
Granularity,
ModelOut,
ModelType,
OutputType,
SupportedTypes,
)


@pytest.fixture()
def reference_bike(spark_fixture, test_data_dir):
    """Yield the bike regression CSV loaded through ``spark_fixture``.

    NOTE: despite the ``reference_`` prefix in the fixture name, the file
    read here lives under ``current/regression``. The first CSV row is
    treated as the header.
    """
    csv_path = f"{test_data_dir}/current/regression/bike.csv"
    bike_frame = spark_fixture.read.csv(csv_path, header=True)
    yield bike_frame


@pytest.fixture()
def current_dataset(reference_bike):
    """Yield a ``CurrentDataset`` pairing the bike data with a regression model.

    The model definition declares ``predictions`` as the single output,
    ``ground_truth`` as the target, ``dteday`` as the timestamp column and
    eleven feature columns.
    """
    # (column name, declared type) for every feature, in declaration order.
    feature_specs = [
        ("season", SupportedTypes.int),
        ("yr", SupportedTypes.int),
        ("mnth", SupportedTypes.int),
        ("holiday", SupportedTypes.int),
        ("weekday", SupportedTypes.int),
        ("workingday", SupportedTypes.int),
        ("weathersit", SupportedTypes.float),
        ("temp", SupportedTypes.float),
        ("atemp", SupportedTypes.float),
        ("hum", SupportedTypes.float),
        ("windspeed", SupportedTypes.float),
    ]
    feature_columns = [
        ColumnDefinition(name=column_name, type=column_type)
        for column_name, column_type in feature_specs
    ]

    model_outputs = OutputType(
        prediction=ColumnDefinition(name="predictions", type=SupportedTypes.float),
        prediction_proba=None,
        output=[ColumnDefinition(name="predictions", type=SupportedTypes.float)],
    )

    regression_model = ModelOut(
        uuid=uuid.uuid4(),
        name="regression model",
        description="description",
        model_type=ModelType.REGRESSION,
        data_type=DataType.TABULAR,
        timestamp=ColumnDefinition(name="dteday", type=SupportedTypes.datetime),
        granularity=Granularity.HOUR,
        outputs=model_outputs,
        target=ColumnDefinition(name="ground_truth", type=SupportedTypes.int),
        features=feature_columns,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    yield CurrentDataset(raw_dataframe=reference_bike, model=regression_model)


def test_current_statistics(current_dataset):
    """Check the statistics computed for the current regression dataset.

    Verifies that the observation count matches the dataset's own count,
    that the missing-cell percentage is consistent with the other reported
    fields, and that the full serialized statistics match the expected
    values for the bike fixture.
    """
    stats = calculate_statistics_current(current_dataset)

    assert current_dataset.current_count == stats.n_observations

    # Percentages are floating-point results; compare with a tolerance so the
    # test does not depend on the exact order of multiply/divide operations
    # used by the implementation.
    total_cells = stats.n_variables * stats.n_observations
    assert stats.missing_cells_perc == pytest.approx(
        100 * stats.missing_cells / total_cells
    )

    expected = {
        "n_variables": 14,
        "n_observations": 100,
        "missing_cells": 7,
        "missing_cells_perc": 0.5,
        "duplicate_rows": 2,
        "duplicate_rows_perc": 2.0,
        "numeric": 13,
        "categorical": 0,
        "datetime": 1,
    }

    # Same reasoning as above: the dict carries float percentages, so compare
    # it approximately; the key set must still match exactly.
    assert stats.model_dump(serialize_as_any=True) == pytest.approx(expected)
2 changes: 1 addition & 1 deletion spark/tests/regression_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def expected_data_quality_json():


@pytest.fixture()
def reference_dataset(spark_fixture, reference_bike):
def reference_dataset(reference_bike):
output = OutputType(
prediction=ColumnDefinition(name="predictions", type=SupportedTypes.float),
prediction_proba=None,
Expand Down
101 changes: 101 additions & 0 deletions spark/tests/resources/current/regression/bike.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
instant,dteday,season,yr,mnth,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,predictions,ground_truth
3,2011-01-03,1,0,1,0,1,1,1,0.196364,0.189405,0.437273,0.248309,113.83,120
4,2011-01-04,1,0,1,0,2,1,1,0.2,0.212122,0.590435,0.160296,95.03,108
5,2011-01-05,1,0,1,0,3,1,1,0.226957,0.22927,0.436957,0.1869,104.1,82
6,2011-01-06,1,0,1,0,4,1,1,0.204348,0.233209,0.518261,0.0895652,136.77,88
7,2011-01-07,1,0,1,0,5,1,2,0.196522,0.208839,0.498696,0.168726,130.9,148
8,2011-01-08,1,0,,0,6,0,2,0.165,0.162254,0.535833,0.266804,100.39,68
9,2011-01-09,1,0,1,0,0,0,1,0.138333,0.116175,0.434167,0.36195,161.6,54
10,2011-01-10,1,0,1,0,1,1,1,0.150833,0.150888,0.482917,0.223267,52.61,41
11,2011-01-11,1,0,1,0,2,1,2,0.169091,0.191464,0.686364,0.122132,54.54,43
12,2011-01-12,1,0,1,0,3,1,1,0.172727,0.160473,0.599545,0.304627,34.11,25
13,2011-01-13,1,0,1,0,4,1,1,0.165,0.150883,0.470417,0.301,46.14,38
14,2011-01-14,1,0,1,0,5,1,1,0.16087,0.188413,0.537826,0.126548,68.29,54
15,2011-01-15,1,0,1,0,6,0,2,,0.248112,0.49875,0.157963,257.48,222
16,2011-01-16,1,0,1,0,0,0,1,0.231667,0.234217,0.48375,0.188433,253.11,251
17,2011-01-17,1,0,1,1,1,0,2,0.175833,0.176771,0.5375,0.194017,132.76,117
18,2011-01-18,1,0,1,0,2,1,2,0.216667,0.232333,0.861667,0.146775,152.52,9
19,2011-01-19,1,0,1,0,3,1,2,0.292174,0.298422,0.741739,0.208317,158.84,78
20,2011-01-20,1,0,1,0,4,1,2,0.261667,0.25505,0.538333,0.195904,116.11,83
21,2011-01-21,1,0,1,0,5,1,1,0.1775,0.157833,0.457083,0.353242,76.25,75
22,2011-01-22,1,0,1,0,6,0,1,0.0591304,0.0790696,0.4,0.17197,207.3,93
23,2011-01-23,1,0,1,0,0,0,1,0.0965217,0.0988391,0.436522,0.2466,171.66,150
24,2011-01-24,1,0,1,0,1,1,1,0.0973913,0.11793,0.491739,0.15833,75.92,86
25,2011-01-25,1,0,1,0,2,1,2,0.223478,0.234526,0.616957,0.129796,167.34,186
26,2011-01-26,1,0,1,0,3,1,3,0.2175,0.2036,0.8625,0.29385,43.15,34
27,2011-01-27,1,0,1,0,4,1,1,0.195,0.2197,0.6875,0.113837,46.09,15
28,2011-01-28,1,0,1,0,5,1,2,0.203478,0.223317,0.793043,0.1233,170.72,38
29,2011-01-29,1,0,1,0,6,0,1,0.196522,0.212126,0.651739,0.145365,142.45,123
30,2011-01-30,1,0,1,0,0,0,1,0.216522,0.250322,0.722174,0.0739826,196.17,140
31,2011-01-31,1,0,1,,1,1,2,0.180833,,0.60375,0.187192,49.55,42
32,2011-02-01,1,0,2,0,2,1,2,0.192174,0.23453,0.829565,0.053213,182.08,47
33,2011-02-02,1,0,2,0,3,1,2,0.26,0.254417,0.775417,0.264308,95.26,72
34,2011-02-03,1,0,2,0,4,1,1,0.186957,0.177878,0.437826,0.277752,74.92,61
35,2011-02-04,1,0,2,0,5,1,2,0.211304,0.228587,0.585217,0.127839,180.28,88
36,2011-02-05,1,0,2,0,6,0,2,0.233333,0.243058,0.929167,0.161079,201.03,100
37,2011-02-06,1,0,2,0,0,0,1,0.285833,0.291671,0.568333,0.1418,366.02,354
38,2011-02-07,1,0,2,0,1,1,1,0.271667,0.303658,0.738333,0.0454083,186.74,120
39,2011-02-08,1,0,2,0,2,1,1,0.220833,0.198246,0.537917,0.36195,84.22,64
40,2011-02-09,1,0,2,0,3,1,2,0.134783,0.144283,0.494783,0.188839,59.29,53
41,2011-02-10,1,0,2,0,4,1,1,0.144348,0.149548,0.437391,0.221935,70.96,47
42,2011-02-11,1,0,2,0,5,1,1,0.189091,0.213509,0.506364,0.10855,135.62,149
43,2011-02-12,1,0,2,0,6,0,1,0.2225,0.232954,0.544167,0.203367,254.52,288
44,2011-02-13,1,0,2,0,0,0,1,0.316522,0.324113,0.457391,0.260883,561.73,397
45,2011-02-14,1,0,2,0,1,1,1,0.415,0.39835,0.375833,0.417908,421.74,208
46,2011-02-15,1,0,2,0,2,1,1,0.266087,0.254274,0.314348,0.291374,220.26,140
47,2011-02-16,1,0,2,0,3,1,1,0.318261,0.3162,0.423478,0.251791,222.85,218
48,2011-02-17,1,0,2,0,4,1,1,0.435833,0.428658,0.505,0.230104,335.41,259
49,2011-02-18,1,0,2,0,5,1,1,0.521667,0.511983,0.516667,0.264925,984.07,579
50,2011-02-19,1,0,2,0,6,0,1,0.399167,0.391404,0.187917,0.507463,1055.73,532
51,2011-02-20,1,0,2,0,0,0,1,0.285217,0.27733,0.407826,0.223235,477.84,639
52,2011-02-21,1,0,2,1,1,0,2,0.303333,0.284075,0.605,0.307846,283.46,195
53,2011-02-22,1,0,2,0,2,1,1,0.182222,0.186033,0.577778,0.195683,65.41,74
54,2011-02-23,1,0,2,0,3,1,1,0.221739,0.245717,0.423043,0.094113,148.39,139
55,2011-02-24,1,0,2,0,4,1,2,0.295652,0.289191,0.697391,0.250496,132.79,100
56,2011-02-25,1,0,2,0,5,1,2,0.364348,0.350461,0.712174,0.346539,254.06,120
57,2011-02-26,1,0,2,0,,0,1,,0.282192,0.537917,,432.56,424
58,2011-02-27,1,0,2,0,0,0,1,0.343478,0.351109,0.68,0.125248,753.05,694
59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,117.27,81
59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,117.27,81
60,2011-03-01,1,0,3,0,2,1,1,0.266667,0.263879,0.535,0.216425,161.42,137
61,2011-03-02,1,0,3,0,3,1,1,0.335,0.320071,0.449583,0.307833,253.87,231
62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,145.57,123
62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,145.57,123
63,2011-03-04,1,0,3,0,5,1,2,0.261667,0.255679,0.610417,0.203346,193.62,214
64,2011-03-05,1,0,3,0,6,0,2,0.384167,0.378779,0.789167,0.251871,702.21,640
65,2011-03-06,1,0,3,0,0,0,2,0.376522,0.366252,0.948261,0.343287,588.13,114
66,2011-03-07,1,0,3,0,1,1,1,0.261739,0.238461,0.551304,0.341352,216.56,244
67,2011-03-08,1,0,3,0,2,1,1,0.2925,0.3024,0.420833,0.12065,310.66,316
68,2011-03-09,1,0,3,0,3,1,2,0.295833,0.286608,0.775417,0.22015,181.41,191
69,2011-03-10,1,0,3,0,4,1,3,0.389091,0.385668,0.0,0.261877,449.07,46
70,2011-03-11,1,0,3,0,5,1,2,0.316522,0.305,0.649565,0.23297,266.37,247
71,2011-03-12,1,0,3,0,6,0,1,0.329167,0.32575,0.594583,0.220775,828.91,724
72,2011-03-13,1,0,3,0,0,0,1,0.384348,0.380091,0.527391,0.270604,977.09,982
73,2011-03-14,1,0,3,0,1,1,1,0.325217,0.332,0.496957,0.136926,343.25,359
74,2011-03-15,1,0,3,0,2,1,2,0.317391,0.318178,0.655652,0.184309,254.91,289
75,2011-03-16,1,0,3,0,3,1,2,0.365217,0.36693,0.776522,0.203117,289.19,321
76,2011-03-17,1,0,3,0,4,1,1,0.415,0.410333,0.602917,0.209579,384.92,424
77,2011-03-18,1,0,3,0,5,1,1,0.54,0.527009,0.525217,0.231017,1575.57,884
78,2011-03-19,1,0,3,0,6,0,1,0.4725,0.466525,0.379167,0.368167,1630.82,1424
79,2011-03-20,1,0,3,0,0,0,1,0.3325,0.32575,0.47375,0.207721,1128.02,1047
80,2011-03-21,2,0,3,0,1,1,2,0.430435,0.409735,0.737391,0.288783,295.64,401
81,2011-03-22,2,0,3,0,2,1,1,0.441667,0.440642,0.624583,0.22575,511.68,460
82,2011-03-23,2,0,3,0,3,1,2,0.346957,0.337939,0.839565,0.234261,217.51,203
83,2011-03-24,2,0,3,0,4,1,2,0.285,0.270833,0.805833,0.243787,169.93,166
84,2011-03-25,2,0,3,0,5,1,1,0.264167,0.256312,0.495,0.230725,269.15,300
85,2011-03-26,2,0,3,0,6,0,1,0.265833,0.257571,0.394167,0.209571,773.12,981
86,2011-03-27,2,0,3,0,0,0,2,0.253043,0.250339,0.493913,0.1843,466.77,472
87,2011-03-28,2,0,3,0,1,1,1,0.264348,0.257574,0.302174,0.212204,295.03,222
88,2011-03-29,2,0,3,0,2,1,1,0.3025,0.292908,0.314167,0.226996,372.79,317
89,2011-03-30,2,0,3,0,3,1,2,0.3,0.29735,0.646667,0.172888,218.23,168
90,2011-03-31,2,0,3,0,4,1,3,0.268333,0.257575,0.918333,0.217646,172.12,179
91,2011-04-01,2,0,4,0,5,1,2,0.3,0.283454,0.68625,0.258708,302.14,307
92,2011-04-02,2,0,4,0,6,0,2,0.315,0.315637,0.65375,0.197146,755.73,898
93,2011-04-03,2,0,4,0,0,0,1,0.378333,0.378767,0.48,0.182213,1508.42,1651
94,2011-04-04,2,0,4,0,1,1,1,0.573333,0.542929,0.42625,0.385571,821.83,734
95,2011-04-05,2,0,4,0,2,1,2,0.414167,0.39835,0.642083,0.388067,232.43,167
96,2011-04-06,2,0,4,0,3,1,1,0.390833,0.387608,0.470833,0.263063,448.77,413
97,2011-04-07,2,0,4,0,4,1,1,0.4375,0.433696,0.602917,0.162312,561.67,571
98,2011-04-08,2,0,4,0,5,1,2,0.335833,0.324479,0.83625,0.226992,268.22,172
99,2011-04-09,2,0,4,0,6,0,2,0.3425,0.341529,0.8775,0.133083,876.47,879
100,2011-04-10,2,0,4,0,0,0,2,0.426667,0.426737,0.8575,0.146767,1234.71,1188