Skip to content

Commit

Permalink
ruffen
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed Oct 28, 2024
1 parent 51d8f66 commit fc5df6c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
31 changes: 22 additions & 9 deletions sharrow/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import xarray as xr

from .dataset import Dataset, construct
from .tree_branch import DataTreeBranch, CachedTree
from .tree_branch import CachedTree, DataTreeBranch

try:
from dask.array import Array as dask_array_type
Expand Down Expand Up @@ -946,13 +946,13 @@ def get_expr(
.isel(expressions=0)
)
elif engine == "numexpr":
from xarray import DataArray
import numexpr as ne
from xarray import DataArray

try:
result = (DataArray(
result = DataArray(
ne.evaluate(expression, local_dict=CachedTree(self)),
))
)
except Exception:
if dtype is None:
dtype = "float32"
Expand All @@ -970,11 +970,18 @@ def get_expr(
self._eval_cache = {}
try:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="numexpr", parser=parser),
pd.eval(
expression,
resolvers=[self],
engine="numexpr",
parser=parser,
),
).astype(dtype)
except NotImplementedError:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="python", parser=parser),
pd.eval(
expression, resolvers=[self], engine="python", parser=parser
),
).astype(dtype)
else:
# numexpr doesn't carry over the dimension names or coords
Expand All @@ -991,7 +998,9 @@ def get_expr(
self._eval_cache = {}
try:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="python", parser=parser),
pd.eval(
expression, resolvers=[self], engine="python", parser=parser
),
).astype(dtype)
finally:
del self._eval_cache
Expand Down Expand Up @@ -1039,11 +1048,15 @@ def eval(
expression = int(expression)
if isinstance(expression, Number):
this_shape = [self.root_dataset.sizes.get(i) for i in self.root_dims]
result = xr.DataArray(np.broadcast_to(expression, this_shape), dims=self.root_dims)
result = xr.DataArray(
np.broadcast_to(expression, this_shape), dims=self.root_dims
)
expression = str(expression)
else:
if not isinstance(expression, str):
raise TypeError(f"expression must be a string, not a {type(expression)}")
raise TypeError(
f"expression must be a string, not a {type(expression)}"
)
if engine is None:
try:
result = self.get_expr(
Expand Down
1 change: 0 additions & 1 deletion sharrow/tree_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,3 @@ def __getitem__(self, item):
if x is None:
x = self._cache[item] = self._tree[item]
return x

0 comments on commit fc5df6c

Please sign in to comment.