Skip to content

Commit

Permalink
Raw numexpr engine (#70)
Browse files Browse the repository at this point in the history
* ValueError on no expressions

* allow number as expression in eval

* broadcast expressions

* raw numexpr engine

* ruffen

* numexpr dims

* fix for when passing dtypes instead of dtype names

* ruffen

* fix subdtype

* try again

* explicitly include package data

* change to use importlib.resources

* ruffen

* ruffen

* hide test omx

* don't test in editable mode
  • Loading branch information
jpn-- authored Dec 3, 2024
1 parent 652da47 commit 52c594b
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 56 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ jobs:
auto-update-conda: false
- name: Install sharrow
run: |
python -m pip install -e .
python -m pip install .
- name: Conda checkup
run: |
conda info -a
Expand Down Expand Up @@ -137,7 +137,7 @@ jobs:
conda install jupyter-book ruamel.yaml sphinx-autosummary-accessors -c conda-forge
- name: Install sharrow
run: |
python -m pip install --no-deps -e .
python -m pip install --no-deps .
- name: Conda checkup
run: |
conda info -a
Expand Down
2 changes: 1 addition & 1 deletion envs/development.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies:
- filelock
- ruff
- jupyter
- larch>=5.7.1
- nbmake
- networkx
- notebook
Expand All @@ -29,4 +28,5 @@ dependencies:
- zarr

- pip:
- larch6
- -e ..
1 change: 1 addition & 0 deletions envs/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- pytest-xdist
- nbmake
- openmatrix
- h5py
- zarr
- pip:
- larch6
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ Repository = "https://github.com/activitysim/sharrow"

[tool.setuptools]
packages = ["sharrow", "sharrow.utils"]
include-package-data = true

[tool.setuptools.package-data]
sharrow = ["*"]

[tool.setuptools_scm]
fallback_version = "1999"
Expand Down
2 changes: 1 addition & 1 deletion sharrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def from_table(
result = xr.Dataset()
if isinstance(index, pd.MultiIndex):
dims = tuple(
name if name is not None else "level_%i" % n
name if name is not None else f"level_{n}"
for n, name in enumerate(index.names)
)
for dim, lev in zip(dims, index.levels):
Expand Down
63 changes: 39 additions & 24 deletions sharrow/example_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from importlib.resources import as_file, files

import numpy as np
import pandas as pd
Expand All @@ -9,17 +10,14 @@ def get_skims_filename() -> str:
return os.path.join(os.path.dirname(__file__), "example_data", "skims.omx")


def get_skims():
def get_skims_omx():
import openmatrix

from . import dataset

zfilename = os.path.join(os.path.dirname(__file__), "example_data", "skims.zarr")
if os.path.exists(zfilename):
skims = dataset.from_zarr(zfilename, consolidated=False)
else:
filename = os.path.join(os.path.dirname(__file__), "example_data", "skims.omx")
with openmatrix.open_file(filename) as f:
with as_file(files("sharrow").joinpath("example_data/skims.omx")) as filename:
skims = None
with openmatrix.open_file(str(filename)) as f:
skims = dataset.from_omx_3d(
f,
index_names=("otaz", "dtaz", "time_period"),
Expand All @@ -28,39 +26,56 @@ def get_skims():
time_period_sep="__",
max_float_precision=32,
).compute()
skims.to_zarr(zfilename)
return skims


def get_skims_zarr():
from . import dataset

f = files("sharrow").joinpath("example_data/skims.zarr")
with as_file(f) as zfile:
if zfile.exists():
skims = dataset.from_zarr(zfile, consolidated=False)
else:
skims = None
return skims


def get_skims():
from . import dataset

f = files("sharrow").joinpath("example_data/skims.zarr")
with as_file(f) as zfile:
if zfile.exists():
skims = dataset.from_zarr(zfile, consolidated=False)
else:
skims = get_skims_omx()
return skims


def get_households():
filename = os.path.join(
os.path.dirname(__file__), "example_data", "households.csv.gz"
)
return pd.read_csv(filename, index_col="HHID")
with as_file(files("sharrow").joinpath("example_data/households.csv.gz")) as f:
return pd.read_csv(f, index_col="HHID")


def get_persons():
filename = os.path.join(os.path.dirname(__file__), "example_data", "persons.csv.gz")
return pd.read_csv(filename, index_col="PERID")
with as_file(files("sharrow").joinpath("example_data/persons.csv.gz")) as f:
return pd.read_csv(f, index_col="PERID")


def get_land_use():
filename = os.path.join(
os.path.dirname(__file__), "example_data", "land_use.csv.gz"
)
return pd.read_csv(filename, index_col="TAZ")
with as_file(files("sharrow").joinpath("example_data/land_use.csv.gz")) as f:
return pd.read_csv(f, index_col="TAZ")


def get_maz_to_taz():
filename = os.path.join(os.path.dirname(__file__), "example_data", "maz_to_taz.csv")
return pd.read_csv(filename, index_col="MAZ")
with as_file(files("sharrow").joinpath("example_data/maz_to_taz.csv")) as f:
return pd.read_csv(f, index_col="MAZ")


def get_maz_to_maz_walk():
filename = os.path.join(
os.path.dirname(__file__), "example_data", "maz_to_maz_walk.csv"
)
return pd.read_csv(filename)
with as_file(files("sharrow").joinpath("example_data/maz_to_maz_walk.csv")) as f:
return pd.read_csv(f)


def get_data():
Expand Down
117 changes: 89 additions & 28 deletions sharrow/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import warnings
from collections.abc import Mapping, Sequence
from numbers import Number
from typing import Literal

import networkx as nx
Expand All @@ -10,7 +11,7 @@
import xarray as xr

from .dataset import Dataset, construct
from .tree_branch import DataTreeBranch
from .tree_branch import CachedTree, DataTreeBranch

try:
from dask.array import Array as dask_array_type
Expand Down Expand Up @@ -898,15 +899,18 @@ def get_expr(
*,
dtype="float32",
with_coords: bool = True,
parser: Literal["pandas", "python"] = "pandas",
):
"""
Access or evaluate an expression.
Parameters
----------
expression : str
engine : {'sharrow', 'numexpr', 'python'}
The engine used to resolve expressions.
engine : {'sharrow', 'numexpr', 'python', 'pandas-numexpr'}
The engine used to resolve expressions. The numexpr engine uses
that library directly, while the pandas-numexpr engine uses the
pandas `eval` method with the numexpr engine.
allow_native : bool, default True
If the expression is an array in a dataset of this tree, return
that array directly. Set to false to force evaluation, which
Expand All @@ -918,11 +922,19 @@ def get_expr(
Attach coordinates from the root node of the tree to the result.
If the coordinates are not needed in the result, the process
of attaching them can be skipped.
parser : {'pandas', 'python'}
The parser to use when evaluating the expression. This argument
only applies to pandas-based engines ('python' and 'pandas-numexpr').
It is ignored when using the 'sharrow' or 'numexpr' engines.
Returns
-------
DataArray
"""
if np.issubdtype(dtype, np.number) and isinstance(dtype, type):
dtype = dtype.__name__
elif dtype is bool:
dtype = "bool"
try:
if allow_native:
result = self[expression]
Expand All @@ -938,16 +950,49 @@ def get_expr(
.isel(expressions=0)
)
elif engine == "numexpr":
import numexpr as ne
from xarray import DataArray

try:
result = DataArray(
ne.evaluate(expression, local_dict=CachedTree(self)),
)
except Exception:
if dtype is None:
dtype = "float32"
result = (
self.setup_flow({expression: expression}, dtype=dtype)
.load_dataarray()
.isel(expressions=0)
)
else:
if dtype is not None:
result = result.astype(dtype)
# numexpr doesn't carry over the dimension names or coords
result = result.rename(
{result.dims[i]: self.root_dims[i] for i in range(result.ndim)}
)
if with_coords:
result = result.assign_coords(self.root_dataset.coords)

elif engine == "pandas-numexpr":
from xarray import DataArray

self._eval_cache = {}
try:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="numexpr"),
pd.eval(
expression,
resolvers=[self],
engine="numexpr",
parser=parser,
),
).astype(dtype)
except NotImplementedError:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="python"),
pd.eval(
expression, resolvers=[self], engine="python", parser=parser
),
).astype(dtype)
else:
# numexpr doesn't carry over the dimension names or coords
Expand All @@ -964,7 +1009,9 @@ def get_expr(
self._eval_cache = {}
try:
result = DataArray(
pd.eval(expression, resolvers=[self], engine="python"),
pd.eval(
expression, resolvers=[self], engine="python", parser=parser
),
).astype(dtype)
finally:
del self._eval_cache
Expand All @@ -974,7 +1021,7 @@ def get_expr(

def eval(
self,
expression: str,
expression: str | Number,
engine: Literal[None, "numexpr", "sharrow", "python"] = None,
*,
dtype: np.dtype | str | None = None,
Expand All @@ -992,7 +1039,7 @@ def eval(
Parameters
----------
expression : str
expression : str | Number
engine : {None, 'numexpr', 'sharrow', 'python'}
The engine used to resolve expressions. If None, the default is
to try 'numexpr' first, then 'sharrow' if that fails.
Expand All @@ -1007,33 +1054,45 @@ def eval(
-------
DataArray
"""
if not isinstance(expression, str):
raise TypeError("expression must be a string")
if engine is None:
try:
result = self.get_expr(
expression,
"numexpr",
allow_native=False,
dtype=dtype,
with_coords=with_coords,
# when passing in a numeric value or boolean, simply broadcast it to the root dims
if isinstance(expression, bool):
expression = int(expression)
if isinstance(expression, Number):
this_shape = [self.root_dataset.sizes.get(i) for i in self.root_dims]
result = xr.DataArray(
np.broadcast_to(expression, this_shape), dims=self.root_dims
)
expression = str(expression)
else:
if not isinstance(expression, str):
raise TypeError(
f"expression must be a string, not a {type(expression)}"
)
except Exception:
if engine is None:
try:
result = self.get_expr(
expression,
"numexpr",
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
except Exception:
result = self.get_expr(
expression,
"sharrow",
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
else:
result = self.get_expr(
expression,
"sharrow",
engine,
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
else:
result = self.get_expr(
expression,
engine,
allow_native=False,
dtype=dtype,
with_coords=with_coords,
)
if with_coords and "expressions" not in result.coords:
# add the expression as a scalar coordinate (with no dimension)
result = result.assign_coords(expressions=xr.DataArray(expression))
Expand Down Expand Up @@ -1081,6 +1140,8 @@ def eval_many(
expressions = pd.Series(expressions, index=expressions)
if isinstance(expressions, Mapping):
expressions = pd.Series(expressions)
if len(expressions) == 0:
raise ValueError("no expressions provided")
if result_type == "dataset":
arrays = {}
for k, v in expressions.items():
Expand Down
Loading

0 comments on commit 52c594b

Please sign in to comment.