Skip to content

Commit

Permalink
slice_dataset now casts --sel "start/stop" flags to string (to agree …
Browse files Browse the repository at this point in the history
…with API documentation) and "step" to int (because step must be an int in xarray).

PiperOrigin-RevId: 707620263
  • Loading branch information
langmore authored and Weatherbench2 authors committed Dec 18, 2024
1 parent 75fb2b9 commit 2f849ab
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
36 changes: 23 additions & 13 deletions scripts/slice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@
help=(
'Selection criteria, to pass to xarray.Dataset.sel. Passed as'
' key=value pairs, with key = VARNAME_{start,stop,step,list}. '
'If key ends with start, stop, or step, the value should be strings '
'(defaulting to None). If key ends with "list", the value should be '
'If key ends with start, stop, or step, the values are used in a slice '
'as slice(str(start), str(stop), int(step)). start/stop/step default to'
' None. If key ends with "list", the value should be '
'a list of "+" delimited ints/floats/strings.'
),
)
Expand All @@ -84,8 +85,9 @@
help=(
'Selection criteria, to pass to xarray.Dataset.drop_sel. Passed as'
' key=value pairs, with key = VARNAME_{start,stop,step,list}. '
'If key ends with start, stop, or step, the value should be strings '
'(defaulting to None). If key ends with "list", the value should be '
'If key ends with start, stop, or step, the values are used in a slice '
'as slice(str(start), str(stop), int(step)). start/stop/step default to'
' None. If key ends with "list", the value should be '
'a list of "+" delimited ints/floats/strings.'
),
)
Expand Down Expand Up @@ -138,9 +140,15 @@

def _get_selections(
flag_values: dict[str, flag_utils.DimValueType],
is_sel_or_dropsel: bool,
) -> list[dict[str, t.Union[str, int, list[int], slice]]]:
"""Gets parts used to select based on flags."""

def maybe_tostr(v):
if is_sel_or_dropsel:
return str(v)
return v

list_selectors = {}
value_selectors = {}
for k, v in flag_values.items():
Expand All @@ -157,18 +165,20 @@ def _get_selections(
if '++' in v:
raise ValueError(f'Found ambiguous "++" in {dim=} flag value {v}')
list_selectors[dim] = [
flag_utils.get_dim_value(v_i) for v_i in v.split('+')
maybe_tostr(flag_utils.get_dim_value(v_i)) for v_i in v.split('+')
]
else: # Else handle non-list types
v = flag_utils.get_dim_value(v)
if dim not in value_selectors:
value_selectors[dim] = [None, None, None]
if placement == 'start':
value_selectors[dim][0] = v
value_selectors[dim][0] = maybe_tostr(v)
elif placement == 'stop':
value_selectors[dim][1] = v
else:
value_selectors[dim][2] = v
value_selectors[dim][1] = maybe_tostr(v)
else: # Else 'step'
# In Xarray, step must be an int.
# https://github.com/pydata/xarray/issues/5228
value_selectors[dim][2] = int(v)

selections = []
for dim, selector in list_selectors.items():
Expand All @@ -191,13 +201,13 @@ def main(argv: abc.Sequence[str]) -> None:
ds = ds[KEEP_VARIABLES.value]
input_chunks = {k: v for k, v in input_chunks.items() if k in ds.dims}

for selection in _get_selections(ISEL.value):
for selection in _get_selections(ISEL.value, is_sel_or_dropsel=False):
ds = ds.isel(selection)
for selection in _get_selections(SEL.value):
for selection in _get_selections(SEL.value, is_sel_or_dropsel=True):
ds = ds.sel(selection)
for selection in _get_selections(DROP_ISEL.value):
for selection in _get_selections(DROP_ISEL.value, is_sel_or_dropsel=False):
ds = ds.drop_isel(selection)
for selection in _get_selections(DROP_SEL.value):
for selection in _get_selections(DROP_SEL.value, is_sel_or_dropsel=True):
ds = ds.drop_sel(selection)

template = xbeam.make_template(ds)
Expand Down
27 changes: 25 additions & 2 deletions scripts/slice_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,39 @@ def test_valid_selections(self):
flag_values={
'A_start': '1 day',
'A_stop': '10 days',
'A_step': '2 days',
'A_step': 2,
'B_stop': 2.2,
'C_step': 3,
'D_list': 'planes+trains+automobiles',
},
is_sel_or_dropsel=False,
)
expected_sel = [
{'A': slice('1 day', '10 days', '2 days')},
{'A': slice('1 day', '10 days', 2)},
{'B': slice(None, 2.2, None)},
{'C': slice(None, None, 3)},
{'D': ['planes', 'trains', 'automobiles']},
]
self.assertCountEqual(expected_sel, sel)

def test_valid_selections_is_sel_or_dropsel(self):
sel = slice_dataset._get_selections(
flag_values={
'A_start': '1 day',
'A_stop': '10 days',
'A_step': 2,
'B_stop': 2020, # As in the year 2020 for a date
'D_list': 'planes+trains+automobiles',
},
is_sel_or_dropsel=True,
)
expected_sel = [
{'A': slice('1 day', '10 days', 2)},
{'B': slice(None, '2020', None)},
{'D': ['planes', 'trains', 'automobiles']},
]
self.assertCountEqual(expected_sel, sel)

def test_valid_index_selections(self):
isel = slice_dataset._get_selections(
flag_values={
Expand All @@ -56,6 +75,7 @@ def test_valid_index_selections(self):
'Z_start': 1,
'W_step': 2,
},
is_sel_or_dropsel=False,
)
expected_isel = [
{'A': [9, -1, 0]},
Expand All @@ -76,6 +96,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_bad': 2,
},
is_sel_or_dropsel=False,
)

with self.subTest('Not ending in (start|stop|step|list) raises 2'):
Expand All @@ -86,6 +107,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_step_and_more': 2,
},
is_sel_or_dropsel=False,
)

with self.subTest('Not ending in (start|stop|step|list) raises 2'):
Expand All @@ -96,6 +118,7 @@ def test_invalid_placement_raises(self):
'X_stop': 10,
'X_step_': 2,
},
is_sel_or_dropsel=False,
)


Expand Down

0 comments on commit 2f849ab

Please sign in to comment.