diff --git a/gutenTAG/base_oscillations/custom_input.py b/gutenTAG/base_oscillations/custom_input.py index b07ee89..2f62a7c 100644 --- a/gutenTAG/base_oscillations/custom_input.py +++ b/gutenTAG/base_oscillations/custom_input.py @@ -100,7 +100,7 @@ def generate_only_base( raise ValueError( "Number of rows in the input timeseries file is less than the desired length" ) - col_type = df.dtypes[0] + col_type = df.dtypes.iloc[0] if col_type != np.float64: df = df.astype(np.float64) warnings.warn( diff --git a/gutenTAG/config/parser.py b/gutenTAG/config/parser.py index 460fe61..3a2ff3e 100644 --- a/gutenTAG/config/parser.py +++ b/gutenTAG/config/parser.py @@ -63,7 +63,9 @@ def parse(self, config: Dict) -> ResultType: for i, ts in enumerate(config.get(TIMESERIES, [])): name = ts.get(PARAMETERS.NAME, f"ts_{i}") - if self._skip_name(name) or not self._check_compatibility(ts): + bos, n_channel = self._extract_bos(ts, name) + + if self._skip_name(name) or not self._check_compatibility(ts, name, bos, n_channel): continue raw_ts_config = deepcopy(ts) @@ -72,8 +74,8 @@ def parse(self, config: Dict) -> ResultType: generation_options = GenerationOptions.from_dict(ts) generation_options.dataset_name = name - base_oscillations = self._build_base_oscillations(ts) - anomalies = self._build_anomalies(ts) + base_oscillations = self._build_base_oscillations(ts, bos) + anomalies = self._build_anomalies(ts, name) self.result.append( (base_oscillations, anomalies, generation_options, raw_ts_config) @@ -81,44 +83,61 @@ def parse(self, config: Dict) -> ResultType: return self.result - def _check_compatibility(self, ts: Dict) -> bool: - base_oscillations = ts.get( - BASE_OSCILLATIONS, - [ts.get(BASE_OSCILLATION)] * ts.get(PARAMETERS.CHANNELS, 0), - ) + def _check_compatibility(self, ts: Dict, name: str, bos: list[dict], n_channels: int) -> bool: anomalies = ts.get(ANOMALIES, []) for anomaly in anomalies: - base_oscillation = base_oscillations[ - anomaly.get( - PARAMETERS.CHANNEL, default_values[ANOMALIES][PARAMETERS.CHANNEL] - ) - ][PARAMETERS.KIND] + channel = anomaly.get(PARAMETERS.CHANNEL, default_values[ANOMALIES][PARAMETERS.CHANNEL]) + if channel >= n_channels: + self._report_error( + name, + f"Invalid channel index: {channel} >= {n_channels}." + ) + base_oscillation = bos[channel][PARAMETERS.KIND] for anomaly_kind in anomaly.get(PARAMETERS.KINDS, []): anomaly_kind = anomaly_kind[PARAMETERS.KIND] if not Compatibility.check(anomaly_kind, base_oscillation): - if self.skip_errors: - logging.warning( - f"Skip generation of time series {ts.get('name', '')} due to incompatible types: {anomaly_kind} -> {base_oscillation}." - ) - return False - else: - raise ValueError( - f"Incompatible types: {anomaly_kind} -> {base_oscillation}." - ) + self._report_error( + name, + f"Incompatible types: {anomaly_kind} -> {base_oscillation}.", + warning_msg=f"Skip generation due to incompatible types: {anomaly_kind} -> {base_oscillation}.", + ) + return False return True def _skip_name(self, name: str) -> bool: return self.only is not None and name != self.only - def _build_base_oscillations(self, d: Dict) -> List[BaseOscillationInterface]: + def _build_base_oscillations(self, d: Dict, bos: list[dict]) -> List[BaseOscillationInterface]: length = d.get( PARAMETERS.LENGTH, default_values[BASE_OSCILLATIONS][PARAMETERS.LENGTH] ) - bos = d.get( - BASE_OSCILLATIONS, [d.get(BASE_OSCILLATION)] * d.get(PARAMETERS.CHANNELS, 0) - ) return [self._build_single_base_oscillation(bo, length) for bo in bos] + def _extract_bos(self, d: Dict, name: str) -> Tuple[List[Dict], int]: + if BASE_OSCILLATIONS in d: + base_oscillations: List[Dict] = d.get(BASE_OSCILLATIONS) # type: ignore + elif BASE_OSCILLATION in d and PARAMETERS.CHANNELS not in d: + self._report_error( + name, + f"'{BASE_OSCILLATION}' requires parameter '{PARAMETERS.CHANNELS}'." + ) + else: + bo_template = d.get(BASE_OSCILLATION) + if isinstance(bo_template, list): + self._report_error( + name, + f"'{BASE_OSCILLATION}' must be a single object." + ) + base_oscillations = [bo_template] * d.get(PARAMETERS.CHANNELS, 0) # type: ignore + + n_channels = len(base_oscillations) + if n_channels == 0: + self._report_error( + name, + f"No base oscillations defined. Please provide either '{BASE_OSCILLATION}' and '{PARAMETERS.CHANNELS}' or '{BASE_OSCILLATIONS}'." + ) + return base_oscillations, n_channels + def _build_single_base_oscillation( self, d: Dict, length: int ) -> BaseOscillationInterface: @@ -129,13 +148,13 @@ def _build_single_base_oscillation( key = base_oscillation_config[PARAMETERS.KIND] return BaseOscillation.from_key(key, **base_oscillation_config) - def _build_anomalies(self, d: Dict) -> List[Anomaly]: + def _build_anomalies(self, d: Dict, ts_name: str) -> List[Anomaly]: return [ - self._build_single_anomaly(anomaly_config) + self._build_single_anomaly(anomaly_config, ts_name) for anomaly_config in d.get(ANOMALIES, []) ] - def _build_single_anomaly(self, d: Dict) -> Anomaly: + def _build_single_anomaly(self, d: Dict, ts_name: str) -> Anomaly: anomaly = Anomaly( Position( d.get( @@ -151,19 +170,20 @@ def _build_single_anomaly(self, d: Dict) -> Anomaly: ), ) - anomaly_kinds = self._build_anomaly_kinds(d, anomaly.anomaly_length) + anomaly_kinds = self._build_anomaly_kinds(d, anomaly.anomaly_length, ts_name) for kind in anomaly_kinds: anomaly.set_anomaly(kind) return anomaly - def _build_anomaly_kinds(self, d: Dict, length: int) -> List[BaseAnomaly]: - return [ - self._build_single_anomaly_kind(anomaly_kind, length) + def _build_anomaly_kinds(self, d: Dict, length: int, ts_name: str) -> List[BaseAnomaly]: + potential_anomalies = [ + self._build_single_anomaly_kind(anomaly_kind, length, ts_name) for anomaly_kind in d.get(PARAMETERS.KINDS, []) ] + return [anomaly for anomaly in potential_anomalies if anomaly is not None] - def _build_single_anomaly_kind(self, d: Dict, length: int) -> BaseAnomaly: + def _build_single_anomaly_kind(self, d: Dict, length: int, ts_name: str) -> Optional[BaseAnomaly]: kind = d[PARAMETERS.KIND] if kind == PARAMETERS.TREND: parameters = { @@ -179,8 +199,13 @@ def _build_single_anomaly_kind(self, d: Dict, length: int) -> BaseAnomaly: except TypeError as ex: if "unexpected keyword argument" in str(ex): parameter = str(ex).split("'")[-2] - raise ValueError( - f"Anomaly kind '{kind}' does not support parameter '{parameter}'" - ) from ex + raise ValueError(f"Time series {ts_name}: Anomaly kind '{kind}' does not support parameter '{parameter}'.") from ex else: raise ex + + def _report_error(self, name: str, msg: str, warning_msg: Optional[str] = None) -> None: + warning_msg = warning_msg or msg + if self.skip_errors: + logging.warning(warning_msg) + else: + raise ValueError(f"Time series {name}: {msg}") diff --git a/gutenTAG/generator/overview.py b/gutenTAG/generator/overview.py index 44aec72..3a91a96 100644 --- a/gutenTAG/generator/overview.py +++ b/gutenTAG/generator/overview.py @@ -7,7 +7,7 @@ class DictSanitizer: - NUMPY_TYPES = tuple(list(np.core._type_aliases.allTypes.values()) + [np.ndarray]) # type: ignore # mypy does not find allTypes + NUMPY_TYPES = tuple(list(np._core._type_aliases.allTypes.values()) + [np.ndarray]) # type: ignore # mypy does not find allTypes def sanitize(self, obj: Dict) -> Dict: for key, value in obj.items(): diff --git a/pyproject.toml b/pyproject.toml index 3d88810..9f51e45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,9 @@ classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", ] -dynamic = ["readme", "version", "requires-python", "scripts", "dependencies"] +dynamic = [ + "readme", "version", "requires-python", "scripts", "dependencies", "entry-points" +] [build-system] requires = ["setuptools"]