Skip to content

Commit

Permalink
Eliminate _check_dimension and _get_time_dim_or_default
Browse files Browse the repository at this point in the history
generalize validation logic into `ProcessArgs`
  • Loading branch information
soxofaan committed Aug 4, 2023
1 parent aad903d commit d0c34f4
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 75 deletions.
69 changes: 23 additions & 46 deletions openeo_driver/ProcessGraphDeserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,11 @@ def apply_neighborhood(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
def apply_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
data_cube = args.get_required("data", expected_type=DriverDataCube)
process = args.get_deep("process", "process_graph", expected_type=dict)
dimension = args.get_required("dimension", expected_type=str)
dimension = args.get_required(
"dimension", expected_type=str, validator=ProcessArgs.validator_one_of(data_cube.metadata.dimension_names())
)
target_dimension = args.get_optional("target_dimension", default=None, expected_type=str)
context = args.get_optional("context", default=None)
# do check_dimension here for error handling
dimension = _check_dimension(cube=data_cube, dim=dimension, process="apply_dimension")

cube = data_cube.apply_dimension(
process=process, dimension=dimension, target_dimension=target_dimension, context=context, env=env
Expand Down Expand Up @@ -747,10 +747,10 @@ def apply(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
def reduce_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
data_cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict)
dimension = args.get_required("dimension", expected_type=str)
dimension = args.get_required(
"dimension", expected_type=str, validator=ProcessArgs.validator_one_of(data_cube.metadata.dimension_names())
)
context = args.get_optional("context", default=None)
# do check_dimension here for error handling
dimension = _check_dimension(cube=data_cube, dim=dimension, process="reduce_dimension")
return data_cube.reduce_dimension(reducer=reduce_pg, dimension=dimension, context=context, env=env)


Expand Down Expand Up @@ -915,40 +915,35 @@ def rename_labels(args: dict, env: EvalEnv) -> DriverDataCube:
)


def _check_dimension(cube: DriverDataCube, dim: str, process: str) -> str:
    """
    Helper to validate that a requested dimension is available in a cube.

    :param cube: data cube whose metadata provides the available dimension names
    :param dim: name of the requested dimension
    :param process: process id to report in the error when validation fails
    :return: the requested dimension name, unchanged
    :raises ProcessParameterInvalidException: if `dim` is not one of the cube's dimension names
    """
    # Look the names up once: they are used for both the check and the error message.
    available = cube.metadata.dimension_names()

    if dim not in available:
        raise ProcessParameterInvalidException(
            parameter="dimension",
            process=process,
            reason=f"got {dim!r}, but should be one of {available!r}",
        )

    return dim


@process
def aggregate_temporal(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
data_cube = args.get_required("data", expected_type=DriverDataCube)
reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict)
context = args.get_optional("context", default=None)
intervals = args.get_required("intervals")
reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict)
labels = args.get_optional("labels", default=None)
dimension = _get_time_dim_or_default(args, data_cube)
return data_cube.aggregate_temporal(intervals=intervals,labels=labels,reducer=reduce_pg, dimension=dimension, context=context)
dimension = args.get_optional(
"dimension",
default=lambda: data_cube.metadata.temporal_dimension.name,
validator=ProcessArgs.validator_one_of(data_cube.metadata.dimension_names()),
)
context = args.get_optional("context", default=None)

return data_cube.aggregate_temporal(
intervals=intervals, labels=labels, reducer=reduce_pg, dimension=dimension, context=context
)


@process_registry_100.add_function
def aggregate_temporal_period(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
data_cube = args.get_required("data", expected_type=DriverDataCube)
period = args.get_required("period")
reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict)
dimension = args.get_optional(
"dimension",
default=lambda: data_cube.metadata.temporal_dimension.name,
validator=ProcessArgs.validator_one_of(data_cube.metadata.dimension_names()),
)
context = args.get_optional("context", default=None)
period = args.get_required("period")
dimension = _get_time_dim_or_default(args, data_cube, "aggregate_temporal_period")

dry_run_tracer: DryRunDataTracer = env.get(ENV_DRY_RUN_TRACER)
if dry_run_tracer:
Expand Down Expand Up @@ -1025,24 +1020,6 @@ def _period_to_intervals(start, end, period) -> List[Tuple[pd.Timestamp, pd.Time
return intervals


def _get_time_dim_or_default(args: ProcessArgs, data_cube, process_id="aggregate_temporal"):
    """
    Resolve the "dimension" argument of a temporal aggregation process.

    When the argument is absent (or None), the name of the cube's temporal dimension is used instead.
    Whichever name is picked is validated against the cube's dimensions with `_check_dimension`.

    :param args: process arguments to read the optional "dimension" entry from
    :param data_cube: cube providing `metadata.temporal_dimension` and `metadata.dimension_names()`
    :param process_id: process id to report in error messages
    :return: the validated dimension name
    :raises ProcessParameterInvalidException: if no dimension was given and the temporal dimension
        lookup raises `MetadataException`, or if the resulting name is not a dimension of the cube
    """
    requested = args.get_optional("dimension", None)
    if requested is None:
        # Fall back on the cube's temporal dimension.
        try:
            requested = data_cube.metadata.temporal_dimension.name
        except MetadataException:
            raise ProcessParameterInvalidException(
                parameter="dimension",
                process=process_id,
                reason="No dimension was set, and no temporal dimension could be found. Available dimensions: {n!r}".format(
                    n=data_cube.metadata.dimension_names()
                ),
            )
    # Both the explicit and the default name go through the same validation.
    return _check_dimension(cube=data_cube, dim=requested, process=process_id)


@process_registry_100.add_function
def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube = args.get_required("data", expected_type=DriverDataCube)
Expand Down
71 changes: 63 additions & 8 deletions openeo_driver/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,11 @@ def cast(cls, args: Union[dict, "ProcessArgs"], process_id: Optional[str] = None
return args

def get_required(
self, name: str, *, expected_type: Optional[Union[type, Tuple[type, ...]]] = None
self,
name: str,
*,
expected_type: Optional[Union[type, Tuple[type, ...]]] = None,
validator: Optional[Callable[[Any], bool]] = None,
) -> ArgumentValue:
"""
Get a required argument by name.
Expand All @@ -301,33 +305,69 @@ def get_required(
value = self[name]
except KeyError:
raise ProcessParameterRequiredException(process=self.process_id, parameter=name) from None
self._check_type(name=name, value=value, expected_type=expected_type)
self._check_value(name=name, value=value, expected_type=expected_type, validator=validator)
return value

def _check_type(self, *, name: str, value: Any, expected_type: Optional[Union[type, Tuple[type, ...]]] = None):
def _check_value(
    self,
    *,
    name: str,
    value: Any,
    expected_type: Optional[Union[type, Tuple[type, ...]]] = None,
    validator: Optional[Callable[[Any], bool]] = None,
):
    """
    Check an argument value against an optional type constraint and an optional validator.

    :param name: argument name, reported in the error on failure
    :param value: the value to check
    :param expected_type: class (or tuple of classes) for an `isinstance` check; skipped when falsy
    :param validator: callable taking the value; a falsy result or a raised exception means "invalid";
        skipped when falsy
    :raises ProcessParameterInvalidException: on type mismatch or failed validation
    """
    if expected_type and not isinstance(value, expected_type):
        raise ProcessParameterInvalidException(
            parameter=name, process=self.process_id, reason=f"Expected {expected_type} but got {type(value)}."
        )

    if not validator:
        return

    failure = "Failed validation."
    try:
        is_valid = validator(value)
    except Exception as exc:
        # A raising validator counts as a failed validation: its message becomes the reported reason.
        is_valid = False
        failure = str(exc)

    if not is_valid:
        raise ProcessParameterInvalidException(parameter=name, process=self.process_id, reason=failure)

def get_optional(
self, name: str, default: Any = None, *, expected_type: Optional[Union[type, Tuple[type, ...]]] = None
self,
name: str,
default: Union[Any, Callable[[], Any]] = None,
*,
expected_type: Optional[Union[type, Tuple[type, ...]]] = None,
validator: Optional[Callable[[Any], bool]] = None,
) -> ArgumentValue:
"""
Get an optional argument with default
:param name: argument name
:param default: default value or a function/factory to generate the default value
:param expected_type: expected class (or list of multiple options) the value should be (unless it's None)
:param validator: optional validation callable
"""
value = self.get(name, default)
if name in self:
value = self.get(name)
else:
value = default() if callable(default) else default
if value is not None:
self._check_type(name=name, value=value, expected_type=expected_type)
self._check_value(name=name, value=value, expected_type=expected_type, validator=validator)

return value

def get_deep(self, *steps: str, expected_type: Optional[Union[type, Tuple[type, ...]]] = None) -> ArgumentValue:
def get_deep(
self,
*steps: str,
expected_type: Optional[Union[type, Tuple[type, ...]]] = None,
validator: Optional[Callable[[Any], bool]] = None,
) -> ArgumentValue:
"""
Walk recursively through a dictionary to get to a value.
Originally: `extract_deep`
"""
# TODO: current implementation requires the argument. Allow it to be optional too?
value = self
for step in steps:
keys = [step] if not isinstance(step, list) else step
Expand All @@ -338,7 +378,7 @@ def get_deep(self, *steps: str, expected_type: Optional[Union[type, Tuple[type,
else:
raise ProcessParameterInvalidException(process=self.process_id, parameter=steps[0], reason=f"{step=}")

self._check_type(name=steps[0], value=value, expected_type=expected_type)
self._check_value(name=steps[0], value=value, expected_type=expected_type, validator=validator)
return value

def get_aliased(self, names: List[str]) -> ArgumentValue:
Expand Down Expand Up @@ -385,3 +425,18 @@ def get_enum(self, name: str, options: typing.Container[ArgumentValue]) -> Argum
reason=f"Invalid enum value {value!r}. Expected one of {options}.",
)
return value

@staticmethod
def validator_one_of(options: list, show_value: bool = True):
    """
    Build a validator function that checks that a value is in the given list.

    :param options: allowed values
    :param show_value: whether to include the rejected value in the error message
    :return: callable that returns True for an allowed value and raises `ValueError` otherwise
    """

    def validator(value):
        if value in options:
            return True
        message = (
            f"Must be one of {options!r} but got {value!r}."
            if show_value
            else f"Must be one of {options!r}."
        )
        raise ValueError(message)

    return validator
71 changes: 70 additions & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,55 @@ def test_get_required_with_type(self):
):
_ = args.get_required("color", expected_type=DriverDataCube)

def test_get_required_with_validator(self):
    """`get_required` returns values accepted by the validator and raises on rejected ones."""
    args = ProcessArgs({"color": "red", "size": 5}, process_id="wibble")

    # Accepted values are returned as-is.
    assert args.get_required("color", expected_type=str, validator=lambda v: len(v) == 3) == "red"
    primary_colors = ProcessArgs.validator_one_of(["red", "green", "blue"])
    assert args.get_required("color", expected_type=str, validator=primary_colors) == "red"
    assert args.get_required("size", expected_type=int, validator=lambda v: v % 3 == 2) == 5

    # A validator returning a falsy value produces the generic failure reason.
    color_failed = re.escape(
        "The value passed for parameter 'color' in process 'wibble' is invalid: Failed validation."
    )
    with pytest.raises(ProcessParameterInvalidException, match=color_failed):
        _ = args.get_required("color", expected_type=str, validator=lambda v: len(v) == 10)

    size_failed = re.escape(
        "The value passed for parameter 'size' in process 'wibble' is invalid: Failed validation."
    )
    with pytest.raises(ProcessParameterInvalidException, match=size_failed):
        _ = args.get_required("size", expected_type=int, validator=lambda v: v % 3 == 1)

    # A validator raising an exception: its message is used as the reason.
    not_listed = re.escape(
        "The value passed for parameter 'color' in process 'wibble' is invalid: Must be one of ['yellow', 'violet'] but got 'red'."
    )
    other_colors = ProcessArgs.validator_one_of(["yellow", "violet"])
    with pytest.raises(ProcessParameterInvalidException, match=not_listed):
        _ = args.get_required("color", expected_type=str, validator=other_colors)

def test_get_optional(self):
    """`get_optional` returns the given value when present and the default otherwise."""
    args = ProcessArgs({"foo": "bar"}, process_id="wibble")

    # Without an explicit default.
    assert args.get_optional("foo") == "bar"
    assert args.get_optional("other") is None

    # With an explicit default: only used for the missing argument.
    fallback = 123
    assert args.get_optional("foo", fallback) == "bar"
    assert args.get_optional("other", fallback) == 123

def test_get_optional_callable_default(self):
    """A callable default is invoked to produce the value, but only when the argument is missing."""
    args = ProcessArgs({"foo": "bar"}, process_id="wibble")

    def make_default():
        return 123

    assert args.get_optional("foo", default=make_default) == "bar"
    assert args.get_optional("other", default=make_default) == 123

    # Possible, but probably a bad idea: a stateful factory gives a different default on each call.
    stack = [1, 2, 3]
    default = stack.pop
    assert args.get_optional("other", default=default) == 3
    assert args.get_optional("other", default=default) == 2

def test_get_optional_with_type(self):
args = ProcessArgs({"foo": "bar"}, process_id="wibble")
assert args.get_optional("foo", expected_type=str) == "bar"
Expand All @@ -480,7 +522,24 @@ def test_get_optional_with_type(self):
"The value passed for parameter 'foo' in process 'wibble' is invalid: Expected <class 'openeo_driver.datacube.DriverDataCube'> but got <class 'str'>."
),
):
_ = args.get_required("foo", expected_type=DriverDataCube)
_ = args.get_optional("foo", expected_type=DriverDataCube)

def test_get_optional_with_validator(self):
    """`get_optional` applies the validator to a present value and raises when it is rejected."""
    args = ProcessArgs({"foo": "bar"}, process_id="wibble")
    # Use `str.islower` (a real predicate): `c.lower()` returns a non-empty string, which is
    # always truthy and would make this validator accept any input.
    assert args.get_optional("foo", validator=lambda s: all(c.islower() for c in s)) == "bar"
    assert args.get_optional("foo", validator=ProcessArgs.validator_one_of(["bar", "meh"])) == "bar"
    with pytest.raises(
        ProcessParameterInvalidException,
        match=re.escape("The value passed for parameter 'foo' in process 'wibble' is invalid: Failed validation."),
    ):
        _ = args.get_optional("foo", validator=lambda s: all(c.isupper() for c in s))
    with pytest.raises(
        ProcessParameterInvalidException,
        match=re.escape(
            "The value passed for parameter 'foo' in process 'wibble' is invalid: Must be one of ['nope', 'meh'] but got 'bar'."
        ),
    ):
        _ = args.get_optional("foo", validator=ProcessArgs.validator_one_of(["nope", "meh"]))

def test_get_deep(self):
args = ProcessArgs({"foo": {"bar": {"color": "red", "size": {"x": 5, "y": 8}}}}, process_id="wibble")
Expand Down Expand Up @@ -508,6 +567,16 @@ def test_get_deep_with_type(self):
):
_ = args.get_deep("foo", "bar", "size", "x", expected_type=(DriverDataCube, str))

def test_get_deep_with_validator(self):
    """`get_deep` applies the validator to the nested value it resolves."""
    nested = {"foo": {"bar": {"color": "red", "size": {"x": 5, "y": 8}}}}
    args = ProcessArgs(nested, process_id="wibble")

    def multiple_of_five(v):
        return v % 5 == 0

    assert args.get_deep("foo", "bar", "size", "x", validator=multiple_of_five) == 5

    # The error reports the first step ("foo") as the parameter name.
    expected = re.escape(
        "The value passed for parameter 'foo' in process 'wibble' is invalid: Failed validation."
    )
    with pytest.raises(ProcessParameterInvalidException, match=expected):
        _ = args.get_deep("foo", "bar", "size", "y", validator=multiple_of_five)

def test_get_aliased(self):
args = ProcessArgs({"size": 5, "color": "red"}, process_id="wibble")
assert args.get_aliased(["size", "dimensions"]) == 5
Expand Down
Loading

0 comments on commit d0c34f4

Please sign in to comment.