Validate Thrift enum arguments in the group_by helpers.

Aggregation(), Window() and GroupBy() now check that operation, timeUnit and
accuracy are members of the matching Thrift enum before building the object.

diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py
index e4d9eca36..51f308eb6 100644
--- a/api/py/ai/chronon/group_by.py
+++ b/api/py/ai/chronon/group_by.py
@@ -33,12 +33,12 @@
 
 
 def collector(
-    op: ttypes.Operation,
-) -> Callable[[ttypes.Operation], Tuple[ttypes.Operation, Dict[str, str]]]:
+    op: int,
+) -> Callable[[int], Tuple[int, Dict[str, str]]]:
     return lambda k: (op, {"k": str(k)})
 
 
-def generic_collector(op: ttypes.Operation, required, **kwargs):
+def generic_collector(op: int, required, **kwargs):
     def _collector(*args, **other_args):
         arguments = kwargs.copy() if kwargs else {}
         for idx, arg in enumerate(required):
@@ -123,8 +123,8 @@ def op_to_str(operation: OperationType):
 
 # See docs/Aggregations.md
 def Aggregation(
-    input_column: str = None,
-    operation: Union[ttypes.Operation, Tuple[ttypes.Operation, Dict[str, str]]] = None,
+    input_column: Optional[str] = None,
+    operation: Optional[Union[int, Tuple[int, Dict[str, str]]]] = None,
     windows: Optional[List[ttypes.Window]] = None,
     buckets: Optional[List[str]] = None,
     tags: Optional[Dict[str, str]] = None,
@@ -154,12 +154,14 @@ def Aggregation(
     arg_map = {}
     if isinstance(operation, tuple):
         operation, arg_map = operation[0], operation[1]
+    assert utils.is_valid_ttype_enum_value(operation, ttypes.Operation), f"Invalid operation: {operation}"
     agg = ttypes.Aggregation(input_column, operation, arg_map, windows, buckets)
     agg.tags = tags
     return agg
 
 
-def Window(length: int, timeUnit: ttypes.TimeUnit) -> ttypes.Window:
+def Window(length: int, timeUnit: int) -> ttypes.Window:
+    assert utils.is_valid_ttype_enum_value(timeUnit, ttypes.TimeUnit), f"Invalid timeUnit: {timeUnit}"
     return ttypes.Window(length, timeUnit)
 
 
@@ -342,7 +344,7 @@ def GroupBy(
     env: Optional[Dict[str, Dict[str, str]]] = None,
     table_properties: Optional[Dict[str, str]] = None,
     output_namespace: Optional[str] = None,
-    accuracy: Optional[ttypes.Accuracy] = None,
+    accuracy: Optional[int] = None,
     lag: int = 0,
     offline_schedule: str = "@daily",
     name: Optional[str] = None,
@@ -471,6 +473,8 @@ def GroupBy(
         A GroupBy object containing specified aggregations.
     """
     assert sources, "Sources are not specified"
+    if accuracy is not None:
+        assert utils.is_valid_ttype_enum_value(accuracy, ttypes.Accuracy), f"Invalid accuracy: {accuracy}"
 
     agg_inputs = []
     if aggregations is not None:
diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py
index 870231a45..6305612cd 100644
--- a/api/py/ai/chronon/utils.py
+++ b/api/py/ai/chronon/utils.py
@@ -23,7 +23,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass, fields
 from enum import Enum
-from typing import Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 import ai.chronon.api.ttypes as api
 import ai.chronon.repo.extract_objects as eo
@@ -575,3 +575,17 @@ def get_config_path(join_name: str) -> str:
     assert "." in join_name, f"Invalid join name: {join_name}"
     team_name, config_name = join_name.split(".", 1)
     return f"production/joins/{team_name}/{config_name}"
+
+
+def is_valid_ttype_enum_value(value: Any, enum_type: Any) -> bool:
+    """Return True if ``value`` is a member of the Thrift enum ``enum_type``.
+
+    Membership is decided by looking ``value`` up in the enum class's
+    ``_VALUES_TO_NAMES`` mapping; ``enum_type`` must provide that attribute.
+    """
+    assert hasattr(enum_type, '_VALUES_TO_NAMES'), f"enum_type {enum_type} is not a valid Thrift enum type"
+    try:
+        return value in enum_type._VALUES_TO_NAMES
+    except TypeError:
+        # Unhashable values (e.g. a list) can never be enum members.
+        return False
diff --git a/api/py/test/test_utils.py b/api/py/test/test_utils.py
index 3716542cb..94ff225fd 100644
--- a/api/py/test/test_utils.py
+++ b/api/py/test/test_utils.py
@@ -390,3 +390,17 @@ def test_get_dependencies_with_events(query_with_partition_column: Query):
     partition_col = query_with_partition_column.partitionColumn or "ds"
     expected_spec = f"event_table/{partition_col}={{{{ macros.ds_add(ds, -2) }}}}"
     assert dep["spec"] == expected_spec
+
+
+def test_is_valid_ttype_enum_value_valid():
+    assert utils.is_valid_ttype_enum_value(api.Accuracy.TEMPORAL, api.Accuracy)
+    assert utils.is_valid_ttype_enum_value(api.Accuracy.SNAPSHOT, api.Accuracy)
+
+
+def test_is_valid_ttype_enum_value_invalid():
+    assert not utils.is_valid_ttype_enum_value(-1, api.Accuracy)
+    assert not utils.is_valid_ttype_enum_value(100, api.Accuracy)
+
+
+def test_is_valid_ttype_enum_value_unhashable():
+    assert not utils.is_valid_ttype_enum_value([api.Accuracy.TEMPORAL], api.Accuracy)