Skip to content

Commit

Permalink
Custom Input: Add type conversion to float and warning (#31)
Browse files Browse the repository at this point in the history
* Added auto conversion and warning for integer input data. Resolved #30

* Added integer conversion test #31

* Updated integer conversion test #31

* Updated integer conversion #31
  • Loading branch information
B-Deforce authored Jul 16, 2023
1 parent 484d064 commit 6834ec8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
46 changes: 32 additions & 14 deletions gutenTAG/base_oscillations/custom_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import warnings

from . import BaseOscillation
from .interface import BaseOscillationInterface
Expand All @@ -18,16 +19,19 @@ def get_base_oscillation_kind(self) -> str:
def get_timeseries_periods(self) -> Optional[int]:
return None

def generate_only_base(self,
ctx: BOGenerationContext,
input_timeseries_path_test: str = None,
use_column_test: Union[str, int] = None,
length: Optional[int] = None,
input_timeseries_path_train: Optional[str] = None,
use_column_train: Optional[Union[str, int]] = None,
semi_supervised: Optional[bool] = None,
supervised: Optional[bool] = None,
*args, **kwargs) -> np.ndarray:
def generate_only_base(
self,
ctx: BOGenerationContext,
input_timeseries_path_test: str = None,
use_column_test: Union[str, int] = None,
length: Optional[int] = None,
input_timeseries_path_train: Optional[str] = None,
use_column_train: Optional[Union[str, int]] = None,
semi_supervised: Optional[bool] = None,
supervised: Optional[bool] = None,
*args,
**kwargs,
) -> np.ndarray:
"""Generates a numpy array of timeseries data from a CSV file based on the specified parameters.
The following requirements must be met by the input file:
Expand Down Expand Up @@ -69,14 +73,20 @@ def generate_only_base(self,
If the number of rows in the input timeseries file is less than the desired length.
"""
length = length or self.length
input_timeseries_path_train = input_timeseries_path_train or self.input_timeseries_path_train
input_timeseries_path_test = input_timeseries_path_test or self.input_timeseries_path_test
input_timeseries_path_train = (
input_timeseries_path_train or self.input_timeseries_path_train
)
input_timeseries_path_test = (
input_timeseries_path_test or self.input_timeseries_path_test
)
use_column_train = use_column_train or self.use_column_train
use_column_test = use_column_test or self.use_column_test

if semi_supervised or supervised:
if input_timeseries_path_train is None:
raise ValueError("No path to an input timeseries file for the training timeseries specified!")
raise ValueError(
"No path to an input timeseries file for the training timeseries specified!"
)

df = pd.read_csv(input_timeseries_path_train, usecols=[use_column_train])

Expand All @@ -87,7 +97,15 @@ def generate_only_base(self,
df = pd.read_csv(input_timeseries_path_test, usecols=[use_column_test])

if len(df) < length:
raise ValueError("Number of rows in the input timeseries file is less than the desired length")
raise ValueError(
"Number of rows in the input timeseries file is less than the desired length"
)
col_type = df.dtypes[0]
if col_type != np.float_:
df = df.astype(float)
warnings.warn(
f"Input data was of {col_type} type and has been automatically converted to float."
)
return df.iloc[:length, 0]


Expand Down
41 changes: 35 additions & 6 deletions tests/test_base_oscillations/test_custom_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path

import pandas as pd
import numpy as np
import os
from numpy.random import SeedSequence
from numpy.testing import assert_array_equal

Expand All @@ -16,11 +18,21 @@ def setUp(self) -> None:
self.input_path2 = Path("tests/custom_input_ts/dummy_timeseries_2.csv")
self.column_idx = 1
self.length = 100
self.expected_test = pd.read_csv(self.input_path1, usecols=[self.column_idx]).iloc[:self.length, 0].values
self.expected_train = pd.read_csv(self.input_path2, usecols=[self.column_idx]).iloc[:self.length, 0].values
self.expected_test = (
pd.read_csv(self.input_path1, usecols=[self.column_idx])
.iloc[: self.length, 0]
.values
)
self.expected_train = (
pd.read_csv(self.input_path2, usecols=[self.column_idx])
.iloc[: self.length, 0]
.values
)

def test_all_args_specified_just_unsupervised(self):
timeseries = CustomInput(supervised=False, semi_supervised=False).generate_only_base(
timeseries = CustomInput(
supervised=False, semi_supervised=False
).generate_only_base(
ctx=self.ctx,
length=self.length,
input_timeseries_path_test=str(self.input_path1),
Expand All @@ -40,7 +52,7 @@ def test_all_args_specified_supervised(self):
use_column_test=self.column_idx,
input_timeseries_path_train=str(self.input_path2),
use_column_train=self.column_idx,
supervised=True
supervised=True,
)

self.assertEqual(len(timeseries), 100)
Expand All @@ -54,7 +66,7 @@ def test_all_args_specified_semi_supervised(self):
use_column_test=self.column_idx,
input_timeseries_path_train=str(self.input_path2),
use_column_train=self.column_idx,
semi_supervised=True
semi_supervised=True,
)

self.assertEqual(len(timeseries), 100)
Expand Down Expand Up @@ -97,7 +109,7 @@ def test_trainfile_none(self):
input_timeseries_path_test="wrong-folder/missing-file.csv",
use_column_test=1,
use_column_train=1,
supervised=True
supervised=True,
)
self.assertRegex(str(e.exception).lower(), "no path.*training timeseries")

Expand Down Expand Up @@ -127,3 +139,20 @@ def test_input_too_short(self):
use_column_test=1,
)
self.assertRegex(str(e.exception).lower(), "less than the desired length")

def test_integer_conversion(self):
df = pd.DataFrame({"data": [1, 2, 3, 4, 5]})
df.to_csv("test_data.csv", index=False)
custom_input = CustomInput("test_data.csv")
# test if warning is raised
with self.assertWarns(UserWarning):
# Generate the time series data
timeseries = custom_input.generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test="test_data.csv",
use_column_test="data",
)
# test if data is properly converted
self.assertEqual(timeseries.dtype, np.float64)
os.remove("test_data.csv")

0 comments on commit 6834ec8

Please sign in to comment.