More options for eval engine #68

Merged · 10 commits · Oct 27, 2024
263 changes: 252 additions & 11 deletions sharrow/relationships.py
@@ -1,13 +1,16 @@
import ast
import logging
import warnings
from collections.abc import Mapping, Sequence
from typing import Literal

import networkx as nx
import numpy as np
import pandas as pd
import xarray as xr

from .dataset import Dataset, construct
from .tree_branch import DataTreeBranch

try:
from dask.array import Array as dask_array_type
@@ -69,7 +72,10 @@ def _ixname():
return f"index{inum}"

for k, v in idxs.items():
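        # pass DataArray indexers through unchanged so their own dims are
        # preserved; only wrap plain arrays in a new DataArray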
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
if isinstance(v, xr.DataArray):
loaders[k] = v
else:
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
if _names:
ds = source[_names]
else:
@@ -91,7 +97,10 @@ def _ixname():
return f"index{inum}"

for k, v in idxs.items():
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
if isinstance(v, xr.DataArray):
loaders[k] = v
else:
loaders[k] = xr.DataArray(v, dims=[_ixname() for n in range(v.ndim)])
if _names:
ds = source[_names]
else:
@@ -575,8 +584,6 @@ def add_dataset(self, name, dataset, relationships=(), as_root=False):
self.digitize_relationships(inplace=True)

def add_items(self, items):
from collections.abc import Mapping, Sequence

if isinstance(items, Sequence):
for i in items:
self.add_items(i)
@@ -621,7 +628,15 @@ def _get_relationship(self, edge):
)

def __getitem__(self, item):
return self.get(item)
if hasattr(self, "_eval_cache") and item in self._eval_cache:
return self._eval_cache[item]
try:
return self.get(item)
except KeyError as err:
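            # not a variable in any node; if the name matches a whole node,
            # return a branch view into that node's dataset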
s = self._graph.nodes.get(item, {}).get("dataset", None)
if s is not None:
return DataTreeBranch(self, item)
raise err

def get(self, item, default=None, broadcast=True, coords=True):
"""
@@ -687,6 +702,11 @@ def get(self, item, default=None, broadcast=True, coords=True):
add_coords[i] = base_dataset.coords[i]
if add_coords:
result = result.assign_coords(add_coords)
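        # give the broadcast result a deterministic dimension order: the
        # tree's explicit dim_order if one was set, else the root dims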
if broadcast:
if self.dim_order is None:
result = result.transpose(*self.root_dims)
else:
result = result.transpose(*self.dim_order)
return result

def finditem(self, item, maybe_in=None):
@@ -828,6 +848,32 @@ def _getitem(
_positions[r.child_name] = _idx
if top_dim_name is not None:
top_dim_names[r.child_name] = top_dim_name
        if len(set(top_dim_names.values())) < len(top_dim_names):
            # some or all of the top dims share a name; wrap the positions
            # and labels as DataArrays over those shared dims
            _positions = {
                k: xr.DataArray(v, dims=[top_dim_names[k]])
                for (k, v) in _positions.items()
            }
            _labels = {
                k: xr.DataArray(v, dims=[top_dim_names[k]])
                for (k, v) in _labels.items()
            }
            # the top dim names have served their purpose, so clear them
            top_dim_names = {}
y = xgather(result, _positions, _labels)
if len(result.dims) == 1 and len(y.dims) == 1:
y = y.rename({y.dims[0]: result.dims[0]})
@@ -844,19 +890,34 @@ def _getitem(

raise KeyError(item)

def get_expr(self, expression, engine="sharrow", allow_native=True):
def get_expr(
self,
expression,
engine="sharrow",
allow_native=True,
*,
dtype="float32",
with_coords: bool = True,
):
"""
Access or evaluate an expression.

Parameters
----------
expression : str
engine : {'sharrow', 'numexpr'}
engine : {'sharrow', 'numexpr', 'python'}
The engine used to resolve expressions.
allow_native : bool, default True
If the expression is an array in a dataset of this tree, return
            that array directly. Set to False to force evaluation, which
will also ensure proper broadcasting consistent with this data tree.
dtype : str or dtype, default 'float32'
The dtype to use when creating new arrays. This only applies when
the expression is not returned as a native variable from the tree.
with_coords : bool, default True
Attach coordinates from the root node of the tree to the result.
If the coordinates are not needed in the result, the process
of attaching them can be skipped.

Returns
-------
Expand All @@ -869,21 +930,185 @@ def get_expr(self, expression, engine="sharrow", allow_native=True):
raise KeyError
except (KeyError, IndexError):
if engine == "sharrow":
if dtype is None:
dtype = "float32"
result = (
self.setup_flow({expression: expression})
self.setup_flow({expression: expression}, dtype=dtype)
.load_dataarray()
.isel(expressions=0)
)
elif engine == "numexpr":
from xarray import DataArray

result = DataArray(
pd.eval(expression, resolvers=[self], engine="numexpr"),
)
                self._eval_cache = {}
                try:
                    result = DataArray(
                        pd.eval(expression, resolvers=[self], engine="numexpr"),
                    )
                except NotImplementedError:
                    # numexpr cannot evaluate every expression; fall back to
                    # the slower python engine
                    result = DataArray(
                        pd.eval(expression, resolvers=[self], engine="python"),
                    )
                else:
                    # numexpr doesn't carry over the dimension names or coords
                    result = result.rename(
                        {result.dims[i]: self.root_dims[i] for i in range(result.ndim)}
                    )
                    if with_coords:
                        result = result.assign_coords(self.root_dataset.coords)
                finally:
                    del self._eval_cache
                if dtype is not None:
                    # cast only when an explicit dtype was requested, so that
                    # dtype=None keeps the expression's natural result dtype
                    result = result.astype(dtype)
            elif engine == "python":
                from xarray import DataArray

                self._eval_cache = {}
                try:
                    result = DataArray(
                        pd.eval(expression, resolvers=[self], engine="python"),
                    )
                finally:
                    del self._eval_cache
                if dtype is not None:
                    result = result.astype(dtype)
else:
raise ValueError(f"unknown engine {engine}") from None
return result
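
    # A minimal usage sketch of the expanded options (illustrative only; the
    # small dataset, the ``sh.DataTree(base=ds)`` construction, and the
    # variable names are assumptions, not part of this changeset)::
    #
    #     import numpy as np
    #     import sharrow as sh
    #     import xarray as xr
    #
    #     ds = xr.Dataset({"x": ("i", np.arange(5.0)), "y": ("i", np.ones(5))})
    #     tree = sh.DataTree(base=ds)
    #
    #     arr = tree.get_expr("x")  # native variable, returned as-is
    #     z = tree.get_expr(
    #         "x + y", engine="numexpr", allow_native=False, dtype="float64"
    #     )
    #     z_bare = tree.get_expr("x + y", engine="python", with_coords=False)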

def eval(
self,
expression: str,
engine: Literal[None, "numexpr", "sharrow", "python"] = None,
*,
dtype: np.dtype | str | None = None,
name: str | None = None,
with_coords: bool = True,
):
"""
Evaluate an expression.

The resulting DataArray will have dimensions that match the root
Dataset of this tree, and the content will be broadcast to those
dimensions if necessary. The expression evaluated will be assigned
as a scalar coordinate named 'expressions', to facilitate concatenation
with other `eval` results if desired.

Parameters
----------
expression : str
engine : {None, 'numexpr', 'sharrow', 'python'}
The engine used to resolve expressions. If None, the default is
to try 'numexpr' first, then 'sharrow' if that fails.
dtype : str or dtype, optional
The dtype to use for the result. If the engine is `sharrow` and
no value is given, this will default to `float32`, otherwise the
default is to use the dtype of the result of the expression.
        name : str, optional
            The name to give the resulting DataArray.
        with_coords : bool, default True
            Attach coordinates from the root node of the tree to the result.
            Set to False to skip attaching them when they are not needed.

Returns
-------
DataArray
"""
if not isinstance(expression, str):
raise TypeError("expression must be a string")
if engine is None:
try:
result = self.get_expr(
expression,
"numexpr",
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
except Exception:
result = self.get_expr(
expression,
"sharrow",
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
else:
result = self.get_expr(
expression,
engine,
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
if with_coords and "expressions" not in result.coords:
# add the expression as a scalar coordinate (with no dimension)
result = result.assign_coords(expressions=xr.DataArray(expression))
if name is not None:
result.name = name
return result
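
    # Usage sketch with the hypothetical ``tree`` above: each ``eval`` result
    # carries a scalar 'expressions' coordinate, which makes results easy to
    # concatenate along a new 'expressions' dimension::
    #
    #     a = tree.eval("x + y", name="sum_xy")
    #     b = tree.eval("x * y", name="prod_xy")
    #     both = xr.concat([a, b], dim="expressions")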

def eval_many(
self,
expressions: Sequence[str] | Mapping[str, str] | pd.Series,
*,
engine: Literal[None, "numexpr", "sharrow", "python"] = None,
dtype=None,
result_type: Literal["dataset", "dataarray"] = "dataset",
with_coords: bool = True,
) -> xr.Dataset | xr.DataArray:
"""
Evaluate multiple expressions.

Parameters
----------
expressions : Sequence[str] or Mapping[str,str] or pd.Series
The expressions to evaluate. If a sequence, the names of the
resulting DataArrays will be the same as the expressions. If a
mapping or Series, the keys or index will be used as the names.
engine : {None, 'numexpr', 'sharrow', 'python'}
The engine used to resolve expressions. If None, the default is to
try 'numexpr' first, then 'sharrow' if that fails.
dtype : str or dtype, optional
The dtype to use for the result. If the engine is `sharrow` and
no value is given, this will default to `float32`, otherwise the
default is to use the dtype of the result of the concatenation of
the expressions.
        result_type : {'dataset', 'dataarray'}
            Whether to return a Dataset (with a variable for each expression)
            or a DataArray (with a dimension across all expressions).
        with_coords : bool, default True
            Attach coordinates from the root node of the tree to the result.
            Set to False to skip attaching them when they are not needed.

Returns
-------
Dataset or DataArray
"""
if result_type not in {"dataset", "dataarray"}:
raise ValueError("result_type must be one of ['dataset', 'dataarray']")
if not isinstance(expressions, (Mapping, pd.Series)):
expressions = pd.Series(expressions, index=expressions)
if isinstance(expressions, Mapping):
expressions = pd.Series(expressions)
if result_type == "dataset":
arrays = {}
for k, v in expressions.items():
a = self.eval(
v, engine=engine, dtype=dtype, name=k, with_coords=with_coords
)
if "expressions" in a.coords:
a = a.drop_vars("expressions")
arrays[k] = a.assign_attrs(expression=v)
result = xr.Dataset(arrays)
else:
arrays = {}
for k, v in expressions.items():
a = self.eval(
v, engine=engine, dtype=dtype, name=k, with_coords=with_coords
)
if "expressions" in a.coords:
a = a.drop_vars("expressions")
a = a.expand_dims("expressions", -1)
arrays[k] = a
result = xr.concat(list(arrays.values()), "expressions")
if with_coords:
result = result.assign_coords(
expressions=expressions.index,
source=xr.DataArray(expressions.values, dims="expressions"),
)
return result
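
    # Usage sketch with the hypothetical ``tree`` above: the same expressions
    # can come back as a Dataset (one variable per name) or stacked into a
    # single DataArray along an 'expressions' dimension::
    #
    #     exprs = {"sum_xy": "x + y", "prod_xy": "x * y"}
    #     ds_out = tree.eval_many(exprs, result_type="dataset")
    #     da_out = tree.eval_many(exprs, result_type="dataarray")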

@property
def subspaces(self):
"""Mapping[str,Dataset] : Direct access to node Dataset objects by name."""
@@ -1583,3 +1808,19 @@ def merged_dataset(self, columns=None, uniquify=False):
        if coords:
            result = result.assign_coords(coords)
        return result

def __iter__(self):
"""Iterate over all the datasets."""
import itertools

if hasattr(self, "_eval_cache"):
z = (self._eval_cache,)
else:
z = ()
return itertools.chain(*z, *(v for k, v in self.subspaces_iter()))

def __setitem__(self, key, value):
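        # assignment is only supported while an expression is being evaluated
        # (i.e., while the temporary ``_eval_cache`` exists), presumably so
        # intermediate values can be stashed without mutating the tree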
if hasattr(self, "_eval_cache"):
self._eval_cache[key] = value
else:
raise NotImplementedError("setitem not supported")