-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add statistics in current dataset for multiclass (#53)
* feat: add statistics for current multiclass * feat: add test and handle multiclass in job
- Loading branch information
Showing
10 changed files
with
280 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
import datetime | ||
import uuid | ||
|
||
import pytest | ||
|
||
from jobs.metrics.statistics import calculate_statistics_current | ||
from jobs.utils.models import ( | ||
ModelOut, | ||
ModelType, | ||
DataType, | ||
OutputType, | ||
ColumnDefinition, | ||
SupportedTypes, | ||
Granularity, | ||
) | ||
from models.current_dataset import CurrentDataset | ||
from tests.utils.pytest_utils import my_approx | ||
|
||
|
||
@pytest.fixture()
def dataset_target_int(spark_fixture, test_data_dir):
    """Yield the (current, reference) CSV pair for the multiclass dataset with integer targets."""
    base = f"{test_data_dir}/reference/multiclass"
    current = spark_fixture.read.csv(
        f"{base}/current/dataset_target_int.csv", header=True
    )
    reference = spark_fixture.read.csv(
        f"{base}/reference/dataset_target_int.csv", header=True
    )
    yield current, reference
|
||
|
||
@pytest.fixture()
def dataset_target_string(spark_fixture, test_data_dir):
    """Yield the (current, reference) CSV pair for the multiclass dataset with string targets."""
    base = f"{test_data_dir}/reference/multiclass"
    current = spark_fixture.read.csv(
        f"{base}/current/dataset_target_string.csv", header=True
    )
    reference = spark_fixture.read.csv(
        f"{base}/reference/dataset_target_string.csv", header=True
    )
    yield current, reference
|
||
|
||
@pytest.fixture()
def dataset_perfect_classes(spark_fixture, test_data_dir):
    """Yield the (current, reference) CSV pair where prediction always matches target."""
    base = f"{test_data_dir}/reference/multiclass"
    current = spark_fixture.read.csv(
        f"{base}/current/dataset_perfect_classes.csv", header=True
    )
    reference = spark_fixture.read.csv(
        f"{base}/reference/dataset_perfect_classes.csv", header=True
    )
    yield current, reference
|
||
|
||
def test_calculation_dataset_target_int(spark_fixture, dataset_target_int):
    """Check dataset-level statistics for a MULTI_CLASS model with an int target.

    Builds a model definition whose prediction/target columns are ints,
    wraps only the *current* half of the fixture in a CurrentDataset, and
    asserts the statistics computed by ``calculate_statistics_current``.
    """
    output = OutputType(
        prediction=ColumnDefinition(name="prediction", type=SupportedTypes.int),
        prediction_proba=None,
        output=[ColumnDefinition(name="prediction", type=SupportedTypes.int)],
    )
    target = ColumnDefinition(name="target", type=SupportedTypes.int)
    timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
    granularity = Granularity.HOUR
    features = [
        ColumnDefinition(name="cat1", type=SupportedTypes.string),
        ColumnDefinition(name="cat2", type=SupportedTypes.string),
        ColumnDefinition(name="num1", type=SupportedTypes.float),
        ColumnDefinition(name="num2", type=SupportedTypes.float),
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="model",
        description="description",
        model_type=ModelType.MULTI_CLASS,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    # Only the current dataframe is needed for current-dataset statistics;
    # the reference half of the fixture is deliberately discarded.
    current_dataframe, _ = dataset_target_int
    current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

    stats = calculate_statistics_current(current_dataset)

    # 7 variables = 4 features + prediction + target + datetime;
    # with an int target, 4 columns are numeric and 2 categorical.
    assert stats == my_approx(
        {
            "categorical": 2,
            "datetime": 1,
            "duplicate_rows": 0,
            "duplicate_rows_perc": 0.0,
            "missing_cells": 3,
            "missing_cells_perc": 4.285714285714286,
            "n_observations": 10,
            "n_variables": 7,
            "numeric": 4,
        }
    )
|
||
|
||
def test_calculation_dataset_target_string(spark_fixture, dataset_target_string):
    """Check dataset-level statistics for a MULTI_CLASS model with a string target.

    Same shape as the int-target test, but prediction/target are strings,
    which shifts two columns from numeric to categorical in the expected
    statistics.
    """
    output = OutputType(
        prediction=ColumnDefinition(name="prediction", type=SupportedTypes.string),
        prediction_proba=None,
        output=[ColumnDefinition(name="prediction", type=SupportedTypes.string)],
    )
    target = ColumnDefinition(name="target", type=SupportedTypes.string)
    timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
    granularity = Granularity.HOUR
    features = [
        ColumnDefinition(name="cat1", type=SupportedTypes.string),
        ColumnDefinition(name="cat2", type=SupportedTypes.string),
        ColumnDefinition(name="num1", type=SupportedTypes.float),
        ColumnDefinition(name="num2", type=SupportedTypes.float),
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="model",
        description="description",
        model_type=ModelType.MULTI_CLASS,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    # Only the current dataframe is needed for current-dataset statistics;
    # the reference half of the fixture is deliberately discarded.
    current_dataframe, _ = dataset_target_string
    current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

    stats = calculate_statistics_current(current_dataset)

    # 7 variables = 4 features + prediction + target + datetime;
    # with a string target, 4 columns are categorical and 2 numeric.
    assert stats == my_approx(
        {
            "categorical": 4,
            "datetime": 1,
            "duplicate_rows": 0,
            "duplicate_rows_perc": 0.0,
            "missing_cells": 3,
            "missing_cells_perc": 4.285714285714286,
            "n_observations": 10,
            "n_variables": 7,
            "numeric": 2,
        }
    )
|
||
|
||
def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_classes):
    """Check dataset-level statistics on the perfect-classification dataset.

    Uses the dataset where prediction always equals target (string-typed);
    statistics are class-agnostic, so the expected values match the
    string-target test.
    """
    output = OutputType(
        prediction=ColumnDefinition(name="prediction", type=SupportedTypes.string),
        prediction_proba=None,
        output=[ColumnDefinition(name="prediction", type=SupportedTypes.string)],
    )
    target = ColumnDefinition(name="target", type=SupportedTypes.string)
    timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
    granularity = Granularity.HOUR
    features = [
        ColumnDefinition(name="cat1", type=SupportedTypes.string),
        ColumnDefinition(name="cat2", type=SupportedTypes.string),
        ColumnDefinition(name="num1", type=SupportedTypes.float),
        ColumnDefinition(name="num2", type=SupportedTypes.float),
    ]
    model = ModelOut(
        uuid=uuid.uuid4(),
        name="model",
        description="description",
        model_type=ModelType.MULTI_CLASS,
        data_type=DataType.TABULAR,
        timestamp=timestamp,
        granularity=granularity,
        outputs=output,
        target=target,
        features=features,
        frameworks="framework",
        algorithm="algorithm",
        created_at=str(datetime.datetime.now()),
        updated_at=str(datetime.datetime.now()),
    )

    # Only the current dataframe is needed for current-dataset statistics;
    # the reference half of the fixture is deliberately discarded.
    current_dataframe, _ = dataset_perfect_classes
    current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

    stats = calculate_statistics_current(current_dataset)

    # 7 variables = 4 features + prediction + target + datetime;
    # string-typed prediction/target give 4 categorical and 2 numeric columns.
    assert stats == my_approx(
        {
            "categorical": 4,
            "datetime": 1,
            "duplicate_rows": 0,
            "duplicate_rows_perc": 0.0,
            "missing_cells": 3,
            "missing_cells_perc": 4.285714285714286,
            "n_observations": 10,
            "n_variables": 7,
            "numeric": 2,
        }
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions
11
spark/tests/resources/reference/multiclass/reference/dataset_perfect_classes.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
cat1,cat2,num1,num2,prediction,target,datetime | ||
A,X,1.0,1.4,HEALTHY,HEALTHY,2024-06-16 00:01:00-05:00 | ||
B,X,1.5,100.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:02:00-05:00 | ||
A,Y,3.0,123.0,HEALTHY,HEALTHY,2024-06-16 00:03:00-05:00 | ||
B,X,0.5,,UNKNOWN,UNKNOWN,2024-06-16 00:04:00-05:00 | ||
B,X,0.5,,ORPHAN,ORPHAN,2024-06-16 00:05:00-05:00 | ||
B,X,,200.0,HEALTHY,HEALTHY,2024-06-16 00:06:00-05:00 | ||
C,X,1.0,300.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:07:00-05:00 | ||
A,X,1.0,499.0,UNKNOWN,UNKNOWN,2024-06-16 00:08:00-05:00 | ||
A,X,1.0,499.0,HEALTHY,HEALTHY,2024-06-16 00:09:00-05:00 | ||
A,X,1.0,499.0,ORPHAN,ORPHAN,2024-06-16 00:10:00-05:00 |
11 changes: 11 additions & 0 deletions
11
spark/tests/resources/reference/multiclass/reference/dataset_target_int.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
cat1,cat2,num1,num2,prediction,target,datetime | ||
A,X,1.0,1.4,1,1,2024-06-16 00:01:00-05:00 | ||
B,X,1.5,100.0,0,0,2024-06-16 00:02:00-05:00 | ||
A,Y,3.0,123.0,1,1,2024-06-16 00:03:00-05:00 | ||
B,X,0.5,,2,0,2024-06-16 00:04:00-05:00 | ||
B,X,0.5,,3,2,2024-06-16 00:05:00-05:00 | ||
B,X,,200.0,1,3,2024-06-16 00:06:00-05:00 | ||
C,X,1.0,300.0,0,0,2024-06-16 00:07:00-05:00 | ||
A,X,1.0,499.0,2,2,2024-06-16 00:08:00-05:00 | ||
A,X,1.0,499.0,1,1,2024-06-16 00:09:00-05:00 | ||
A,X,1.0,499.0,3,2,2024-06-16 00:10:00-05:00 |
11 changes: 11 additions & 0 deletions
11
spark/tests/resources/reference/multiclass/reference/dataset_target_string.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
cat1,cat2,num1,num2,prediction,target,datetime | ||
A,X,1.0,1.4,HEALTHY,HEALTHY,2024-06-16 00:01:00-05:00 | ||
B,X,1.5,100.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:02:00-05:00 | ||
A,Y,3.0,123.0,HEALTHY,HEALTHY,2024-06-16 00:03:00-05:00 | ||
B,X,0.5,,UNKNOWN,UNHEALTHY,2024-06-16 00:04:00-05:00 | ||
B,X,0.5,,ORPHAN,UNKNOWN,2024-06-16 00:05:00-05:00 | ||
B,X,,200.0,HEALTHY,ORPHAN,2024-06-16 00:06:00-05:00 | ||
C,X,1.0,300.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:07:00-05:00 | ||
A,X,1.0,499.0,UNKNOWN,UNKNOWN,2024-06-16 00:08:00-05:00 | ||
A,X,1.0,499.0,HEALTHY,HEALTHY,2024-06-16 00:09:00-05:00 | ||
A,X,1.0,499.0,ORPHAN,UNKNOWN,2024-06-16 00:10:00-05:00 |