Skip to content

Commit

Permalink
add validation on _window_size param in topo extractor (#1284)
Browse files Browse the repository at this point in the history
* fix: add positive appropriation for a window_size param

* test: test positive _window_size in topo extractor
  • Loading branch information
Lopa10ko committed Apr 22, 2024
1 parent 0bdece1 commit 79ec471
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def fit(self, input_data: InputData):
self._window_size = int(input_data.features.shape[1] * self.window_size_as_share)
self._window_size = max(self._window_size, 2)
self._window_size = min(self._window_size, input_data.features.shape[1] - 2)
self._window_size = max(self._window_size, 1)
return self

def transform(self, input_data: InputData) -> OutputData:
Expand Down
14 changes: 14 additions & 0 deletions test/unit/data_operations/test_data_operation_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import pytest

from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
Expand All @@ -10,6 +11,7 @@
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import Task, TaskTypesEnum, TsForecastingParams
from fedot.core.utils import fedot_project_root
from .test_time_series_operations import get_timeseries


def get_ts_pipeline(window_size):
Expand Down Expand Up @@ -102,3 +104,15 @@ def test_params_filter_with_non_default():
assert default_params == {}
assert 'n_neighbors' in list(updated_params.keys())
assert len(list(updated_params.keys())) == 1


@pytest.mark.parametrize(('length', 'features_count', 'target_count', 'window_size'),
[(40, 1, 1, 10), (5, 1, 1, 10), (4, 1, 1, 10), (2, 1, 1, 10), (1, 1, 1, 10)])
def test_positive_window_size_in_fast_topo(length, features_count, target_count, window_size):
data = get_timeseries(length=length, features_count=features_count, target_count=target_count, random=True)
lagged_node = PipelineNode('lagged')
lagged_node.parameters = {'window_size': window_size}
lagged_data = lagged_node.fit(data)
topo_node = PipelineNode('topological_features')
topo_node.fit(lagged_data)
assert topo_node.fitted_operation._window_size > 0

0 comments on commit 79ec471

Please sign in to comment.