diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index f6001c6..118e988 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -92,7 +92,7 @@ jobs: auto-update-conda: false - name: Install sharrow run: | - python -m pip install -e . + python -m pip install . - name: Conda checkup run: | conda info -a @@ -137,7 +137,7 @@ jobs: conda install jupyter-book ruamel.yaml sphinx-autosummary-accessors -c conda-forge - name: Install sharrow run: | - python -m pip install --no-deps -e . + python -m pip install --no-deps . - name: Conda checkup run: | conda info -a diff --git a/envs/development.yml b/envs/development.yml index dca2a77..2fde740 100644 --- a/envs/development.yml +++ b/envs/development.yml @@ -10,7 +10,6 @@ dependencies: - filelock - ruff - jupyter - - larch>=5.7.1 - nbmake - networkx - notebook @@ -29,4 +28,5 @@ dependencies: - zarr - pip: + - larch6 - -e .. diff --git a/envs/testing.yml b/envs/testing.yml index 1ab2609..38d839c 100644 --- a/envs/testing.yml +++ b/envs/testing.yml @@ -22,6 +22,7 @@ dependencies: - pytest-xdist - nbmake - openmatrix + - h5py - zarr - pip: - larch6 diff --git a/pyproject.toml b/pyproject.toml index 79043bb..97f4094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,10 @@ Repository = "https://github.com/activitysim/sharrow" [tool.setuptools] packages = ["sharrow", "sharrow.utils"] +include-package-data = true + +[tool.setuptools.package-data] +sharrow = ["*"] [tool.setuptools_scm] fallback_version = "1999" diff --git a/sharrow/dataset.py b/sharrow/dataset.py index d4695df..d136552 100755 --- a/sharrow/dataset.py +++ b/sharrow/dataset.py @@ -269,7 +269,7 @@ def from_table( result = xr.Dataset() if isinstance(index, pd.MultiIndex): dims = tuple( - name if name is not None else "level_%i" % n + name if name is not None else f"level_{n}" for n, name in enumerate(index.names) ) for dim, lev in zip(dims, index.levels): diff --git a/sharrow/example_data.py b/sharrow/example_data.py index 69572c6..95aef53 100644 --- a/sharrow/example_data.py +++ b/sharrow/example_data.py @@ -1,4 +1,5 @@ import os +from importlib.resources import as_file, files import numpy as np import pandas as pd @@ -9,17 +10,14 @@ def get_skims_filename() -> str: return os.path.join(os.path.dirname(__file__), "example_data", "skims.omx") -def get_skims(): +def get_skims_omx(): import openmatrix from . import dataset - zfilename = os.path.join(os.path.dirname(__file__), "example_data", "skims.zarr") - if os.path.exists(zfilename): - skims = dataset.from_zarr(zfilename, consolidated=False) - else: - filename = os.path.join(os.path.dirname(__file__), "example_data", "skims.omx") - with openmatrix.open_file(filename) as f: + with as_file(files("sharrow").joinpath("example_data/skims.omx")) as filename: + skims = None + with openmatrix.open_file(str(filename)) as f: skims = dataset.from_omx_3d( f, index_names=("otaz", "dtaz", "time_period"), @@ -28,39 +26,56 @@ def get_skims(): time_period_sep="__", max_float_precision=32, ).compute() - skims.to_zarr(zfilename) + return skims + + +def get_skims_zarr(): + from . import dataset + + f = files("sharrow").joinpath("example_data/skims.zarr") + with as_file(f) as zfile: + if zfile.exists(): + skims = dataset.from_zarr(zfile, consolidated=False) + else: + skims = None + return skims + + +def get_skims(): + from . import dataset + + f = files("sharrow").joinpath("example_data/skims.zarr") + with as_file(f) as zfile: + if zfile.exists(): + skims = dataset.from_zarr(zfile, consolidated=False) + else: + skims = get_skims_omx() return skims def get_households(): - filename = os.path.join( - os.path.dirname(__file__), "example_data", "households.csv.gz" - ) - return pd.read_csv(filename, index_col="HHID") + with as_file(files("sharrow").joinpath("example_data/households.csv.gz")) as f: + return pd.read_csv(f, index_col="HHID") def get_persons(): - filename = os.path.join(os.path.dirname(__file__), "example_data", "persons.csv.gz") - return pd.read_csv(filename, index_col="PERID") + with as_file(files("sharrow").joinpath("example_data/persons.csv.gz")) as f: + return pd.read_csv(f, index_col="PERID") def get_land_use(): - filename = os.path.join( - os.path.dirname(__file__), "example_data", "land_use.csv.gz" - ) - return pd.read_csv(filename, index_col="TAZ") + with as_file(files("sharrow").joinpath("example_data/land_use.csv.gz")) as f: + return pd.read_csv(f, index_col="TAZ") def get_maz_to_taz(): - filename = os.path.join(os.path.dirname(__file__), "example_data", "maz_to_taz.csv") - return pd.read_csv(filename, index_col="MAZ") + with as_file(files("sharrow").joinpath("example_data/maz_to_taz.csv")) as f: + return pd.read_csv(f, index_col="MAZ") def get_maz_to_maz_walk(): - filename = os.path.join( - os.path.dirname(__file__), "example_data", "maz_to_maz_walk.csv" - ) - return pd.read_csv(filename) + with as_file(files("sharrow").joinpath("example_data/maz_to_maz_walk.csv")) as f: + return pd.read_csv(f) def get_data(): diff --git a/sharrow/relationships.py b/sharrow/relationships.py index f06d220..1adcc10 100644 --- a/sharrow/relationships.py +++ b/sharrow/relationships.py @@ -2,6 +2,7 @@ import logging import warnings from collections.abc import Mapping, Sequence +from numbers import Number from typing import Literal import networkx as nx @@ -10,7 +11,7 @@ import xarray as xr from .dataset import Dataset, construct -from .tree_branch import DataTreeBranch +from .tree_branch import CachedTree, DataTreeBranch try: from dask.array import Array as dask_array_type @@ -898,6 +899,7 @@ def get_expr( *, dtype="float32", with_coords: bool = True, + parser: Literal["pandas", "python"] = "pandas", ): """ Access or evaluate an expression. @@ -905,8 +907,10 @@ def get_expr( Parameters ---------- expression : str - engine : {'sharrow', 'numexpr', 'python'} - The engine used to resolve expressions. + engine : {'sharrow', 'numexpr', 'python', 'pandas-numexpr'} + The engine used to resolve expressions. The numexpr engine uses + that library directly, while the pandas-numexpr engine uses the + pandas `eval` method with the numexpr engine. allow_native : bool, default True If the expression is an array in a dataset of this tree, return that array directly. Set to false to force evaluation, which @@ -918,11 +922,19 @@ def get_expr( Attach coordinates from the root node of the tree to the result. If the coordinates are not needed in the result, the process of attaching them can be skipped. + parser : {'pandas', 'python'} + The parser to use when evaluating the expression. This argument + only applies to pandas-based engines ('python' and 'pandas-numexpr'). + It is ignored when using the 'sharrow' or 'numexpr' engines. Returns ------- DataArray """ + if np.issubdtype(dtype, np.number) and isinstance(dtype, type): + dtype = dtype.__name__ + elif dtype is bool: + dtype = "bool" try: if allow_native: result = self[expression] @@ -938,16 +950,49 @@ def get_expr( .isel(expressions=0) ) elif engine == "numexpr": + import numexpr as ne + from xarray import DataArray + + try: + result = DataArray( + ne.evaluate(expression, local_dict=CachedTree(self)), + ) + except Exception: + if dtype is None: + dtype = "float32" + result = ( + self.setup_flow({expression: expression}, dtype=dtype) + .load_dataarray() + .isel(expressions=0) + ) + else: + if dtype is not None: + result = result.astype(dtype) + # numexpr doesn't carry over the dimension names or coords + result = result.rename( + {result.dims[i]: self.root_dims[i] for i in range(result.ndim)} + ) + if with_coords: + result = result.assign_coords(self.root_dataset.coords) + + elif engine == "pandas-numexpr": from xarray import DataArray self._eval_cache = {} try: result = DataArray( - pd.eval(expression, resolvers=[self], engine="numexpr"), + pd.eval( + expression, + resolvers=[self], + engine="numexpr", + parser=parser, + ), ).astype(dtype) except NotImplementedError: result = DataArray( - pd.eval(expression, resolvers=[self], engine="python"), + pd.eval( + expression, resolvers=[self], engine="python", parser=parser + ), ).astype(dtype) else: # numexpr doesn't carry over the dimension names or coords @@ -964,7 +1009,9 @@ def get_expr( self._eval_cache = {} try: result = DataArray( - pd.eval(expression, resolvers=[self], engine="python"), + pd.eval( + expression, resolvers=[self], engine="python", parser=parser + ), ).astype(dtype) finally: del self._eval_cache @@ -974,7 +1021,7 @@ def get_expr( def eval( self, - expression: str, + expression: str | Number, engine: Literal[None, "numexpr", "sharrow", "python"] = None, *, dtype: np.dtype | str | None = None, @@ -992,7 +1039,7 @@ def eval( Parameters ---------- - expression : str + expression : str | Number engine : {None, 'numexpr', 'sharrow', 'python'} The engine used to resolve expressions. If None, the default is to try 'numexpr' first, then 'sharrow' if that fails. @@ -1007,33 +1054,45 @@ def eval( ------- DataArray """ - if not isinstance(expression, str): - raise TypeError("expression must be a string") - if engine is None: - try: - result = self.get_expr( - expression, - "numexpr", - allow_native=False, - dtype=dtype, - with_coords=with_coords, + # when passing in a numeric value or boolean, simply broadcast it to the root dims + if isinstance(expression, bool): + expression = int(expression) + if isinstance(expression, Number): + this_shape = [self.root_dataset.sizes.get(i) for i in self.root_dims] + result = xr.DataArray( + np.broadcast_to(expression, this_shape), dims=self.root_dims + ) + expression = str(expression) + else: + if not isinstance(expression, str): + raise TypeError( + f"expression must be a string, not a {type(expression)}" ) - except Exception: + if engine is None: + try: + result = self.get_expr( + expression, + "numexpr", + allow_native=False, + dtype=dtype, + with_coords=with_coords, + ) + except Exception: + result = self.get_expr( + expression, + "sharrow", + allow_native=False, + dtype=dtype, + with_coords=with_coords, + ) + else: result = self.get_expr( expression, - "sharrow", + engine, allow_native=False, dtype=dtype, with_coords=with_coords, ) - else: - result = self.get_expr( - expression, - engine, - allow_native=False, - dtype=dtype, - with_coords=with_coords, - ) if with_coords and "expressions" not in result.coords: # add the expression as a scalar coordinate (with no dimension) result = result.assign_coords(expressions=xr.DataArray(expression)) @@ -1081,6 +1140,8 @@ def eval_many( expressions = pd.Series(expressions, index=expressions) if isinstance(expressions, Mapping): expressions = pd.Series(expressions) + if len(expressions) == 0: + raise ValueError("no expressions provided") if result_type == "dataset": arrays = {} for k, v in expressions.items(): diff --git a/sharrow/tests/test_example_data.py b/sharrow/tests/test_example_data.py new file mode 100644 index 0000000..3e4ab86 --- /dev/null +++ b/sharrow/tests/test_example_data.py @@ -0,0 +1,60 @@ +import numpy as np +import pandas as pd + +import sharrow as sh + + +def test_skims(): + skims = sh.example_data.get_skims() + assert isinstance(skims, sh.Dataset) + np.testing.assert_almost_equal( + skims.DIST.values[:2, :3], + np.asarray([[0.12, 0.24, 0.44], [0.37, 0.14, 0.28]]), + ) + + +def test_skims_zarr(): + skims = sh.example_data.get_skims_zarr() + assert isinstance(skims, sh.Dataset) + np.testing.assert_almost_equal( + skims.DIST.values[:2, :3], + np.asarray([[0.12, 0.24, 0.44], [0.37, 0.14, 0.28]]), + ) + + +# def test_skims_omx(): +# skims = sh.example_data.get_skims_omx() +# assert isinstance(skims, sh.Dataset) +# np.testing.assert_almost_equal( +# skims.DIST.values[:2, :3], +# np.asarray([[0.12, 0.24, 0.44], [0.37, 0.14, 0.28]]), +# ) + + +def test_maz_to_taz(): + maz_to_taz = sh.example_data.get_maz_to_taz() + assert isinstance(maz_to_taz, pd.DataFrame) + assert maz_to_taz.index.name == "MAZ" + + +def test_maz_to_maz_walk(): + maz_to_maz_walk = sh.example_data.get_maz_to_maz_walk() + assert isinstance(maz_to_maz_walk, pd.DataFrame) + assert list(maz_to_maz_walk.columns) == ["OMAZ", "DMAZ", "DISTWALK"] + + +def test_land_use(): + land_use = sh.example_data.get_land_use() + assert isinstance(land_use, pd.DataFrame) + assert land_use.index.name == "TAZ" + + +def test_data(): + data = sh.example_data.get_data() + assert isinstance(data, dict) + assert isinstance(data["hhs"], pd.DataFrame) + assert isinstance(data["persons"], pd.DataFrame) + assert isinstance(data["land_use"], pd.DataFrame) + assert isinstance(data["skims"], sh.Dataset) + assert isinstance(data["maz_taz"], pd.DataFrame) + assert isinstance(data["maz_maz_walk"], pd.DataFrame) diff --git a/sharrow/tree_branch.py b/sharrow/tree_branch.py index fc3df36..46de098 100644 --- a/sharrow/tree_branch.py +++ b/sharrow/tree_branch.py @@ -29,3 +29,17 @@ def __getattr__(self, item): return self.tree[self.branch + "." + item] else: raise AttributeError(f"{item} not found in {self.branch}") + + +class CachedTree: + """A wrapper that caches the results of getitem calls.""" + + def __init__(self, tree): + self._tree = tree + self._cache = {} + + def __getitem__(self, item): + x = self._cache.get(item, None) + if x is None: + x = self._cache[item] = self._tree[item] + return x