From 6834ec878257ae9fb2d76480c92b358fbfd272ce Mon Sep 17 00:00:00 2001 From: Boje Deforce <72612139+B-Deforce@users.noreply.github.com> Date: Sun, 16 Jul 2023 11:00:45 -0400 Subject: [PATCH] Custom Input: Add type conversion to float and warning (#31) * Added auto conversion and warning for integer input data. Resolved #30 * Added integer conversion test #31 * Updated integer conversion test #31 * Updated integer conversion #31 --- gutenTAG/base_oscillations/custom_input.py | 46 +++++++++++++------ .../test_custom_input.py | 41 ++++++++++++++--- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/gutenTAG/base_oscillations/custom_input.py b/gutenTAG/base_oscillations/custom_input.py index 8d65194..232b665 100644 --- a/gutenTAG/base_oscillations/custom_input.py +++ b/gutenTAG/base_oscillations/custom_input.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +import warnings from . import BaseOscillation from .interface import BaseOscillationInterface @@ -18,16 +19,19 @@ def get_base_oscillation_kind(self) -> str: def get_timeseries_periods(self) -> Optional[int]: return None - def generate_only_base(self, - ctx: BOGenerationContext, - input_timeseries_path_test: str = None, - use_column_test: Union[str, int] = None, - length: Optional[int] = None, - input_timeseries_path_train: Optional[str] = None, - use_column_train: Optional[Union[str, int]] = None, - semi_supervised: Optional[bool] = None, - supervised: Optional[bool] = None, - *args, **kwargs) -> np.ndarray: + def generate_only_base( + self, + ctx: BOGenerationContext, + input_timeseries_path_test: str = None, + use_column_test: Union[str, int] = None, + length: Optional[int] = None, + input_timeseries_path_train: Optional[str] = None, + use_column_train: Optional[Union[str, int]] = None, + semi_supervised: Optional[bool] = None, + supervised: Optional[bool] = None, + *args, + **kwargs, + ) -> np.ndarray: """Generates a numpy array of timeseries data from a CSV file based on the specified parameters. The following requirements must be met by the input file: @@ -69,14 +73,20 @@ def generate_only_base(self, If the number of rows in the input timeseries file is less than the desired length. """ length = length or self.length - input_timeseries_path_train = input_timeseries_path_train or self.input_timeseries_path_train - input_timeseries_path_test = input_timeseries_path_test or self.input_timeseries_path_test + input_timeseries_path_train = ( + input_timeseries_path_train or self.input_timeseries_path_train + ) + input_timeseries_path_test = ( + input_timeseries_path_test or self.input_timeseries_path_test + ) use_column_train = use_column_train or self.use_column_train use_column_test = use_column_test or self.use_column_test if semi_supervised or supervised: if input_timeseries_path_train is None: - raise ValueError("No path to an input timeseries file for the training timeseries specified!") + raise ValueError( + "No path to an input timeseries file for the training timeseries specified!" + ) df = pd.read_csv(input_timeseries_path_train, usecols=[use_column_train]) @@ -87,7 +97,15 @@ def generate_only_base(self, df = pd.read_csv(input_timeseries_path_test, usecols=[use_column_test]) if len(df) < length: - raise ValueError("Number of rows in the input timeseries file is less than the desired length") + raise ValueError( + "Number of rows in the input timeseries file is less than the desired length" + ) + col_type = df.dtypes[0] + if col_type != np.float_: + df = df.astype(float) + warnings.warn( + f"Input data was of {col_type} type and has been automatically converted to float." + ) return df.iloc[:length, 0] diff --git a/tests/test_base_oscillations/test_custom_input.py b/tests/test_base_oscillations/test_custom_input.py index 6fbe8dc..b03f723 100644 --- a/tests/test_base_oscillations/test_custom_input.py +++ b/tests/test_base_oscillations/test_custom_input.py @@ -2,6 +2,8 @@ from pathlib import Path import pandas as pd +import numpy as np +import os from numpy.random import SeedSequence from numpy.testing import assert_array_equal @@ -16,11 +18,21 @@ def setUp(self) -> None: self.input_path2 = Path("tests/custom_input_ts/dummy_timeseries_2.csv") self.column_idx = 1 self.length = 100 - self.expected_test = pd.read_csv(self.input_path1, usecols=[self.column_idx]).iloc[:self.length, 0].values - self.expected_train = pd.read_csv(self.input_path2, usecols=[self.column_idx]).iloc[:self.length, 0].values + self.expected_test = ( + pd.read_csv(self.input_path1, usecols=[self.column_idx]) + .iloc[: self.length, 0] + .values + ) + self.expected_train = ( + pd.read_csv(self.input_path2, usecols=[self.column_idx]) + .iloc[: self.length, 0] + .values + ) def test_all_args_specified_just_unsupervised(self): - timeseries = CustomInput(supervised=False, semi_supervised=False).generate_only_base( + timeseries = CustomInput( + supervised=False, semi_supervised=False + ).generate_only_base( ctx=self.ctx, length=self.length, input_timeseries_path_test=str(self.input_path1), @@ -40,7 +52,7 @@ def test_all_args_specified_supervised(self): use_column_test=self.column_idx, input_timeseries_path_train=str(self.input_path2), use_column_train=self.column_idx, - supervised=True + supervised=True, ) self.assertEqual(len(timeseries), 100) @@ -54,7 +66,7 @@ def test_all_args_specified_semi_supervised(self): use_column_test=self.column_idx, input_timeseries_path_train=str(self.input_path2), use_column_train=self.column_idx, - semi_supervised=True + semi_supervised=True, ) self.assertEqual(len(timeseries), 100) @@ -97,7 +109,7 @@ def test_trainfile_none(self): input_timeseries_path_test="wrong-folder/missing-file.csv", use_column_test=1, use_column_train=1, - supervised=True + supervised=True, ) self.assertRegex(str(e.exception).lower(), "no path.*training timeseries") @@ -127,3 +139,20 @@ def test_input_too_short(self): use_column_test=1, ) self.assertRegex(str(e.exception).lower(), "less than the desired length") + + def test_integer_conversion(self): + df = pd.DataFrame({"data": [1, 2, 3, 4, 5]}) + df.to_csv("test_data.csv", index=False) + custom_input = CustomInput("test_data.csv") + # test if warning is raised + with self.assertWarns(UserWarning): + # Generate the time series data + timeseries = custom_input.generate_only_base( + ctx=self.ctx, + length=5, + input_timeseries_path_test="test_data.csv", + use_column_test="data", + ) + # test if data is properly converted + self.assertEqual(timeseries.dtype, np.float64) + os.remove("test_data.csv")