Skip to content

Commit

Permalink
12k->14.5k
Browse files Browse the repository at this point in the history
  • Loading branch information
hmaarrfk committed May 5, 2024
1 parent 3290e24 commit 1a659bc
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
69 changes: 58 additions & 11 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]:
raise NotImplementedError()

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
self,
variables: Mapping[Any, Variable] | None = None,
*,
fastpath=False,
) -> IndexVars:
"""Maybe create new coordinate variables from this index.
Expand Down Expand Up @@ -575,13 +578,19 @@ class PandasIndex(Index):

__slots__ = ("index", "dim", "coord_dtype")

def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
# make a shallow copy: cheap and because the index name may be updated
# here or in other constructors (cannot use pd.Index.rename as this
# constructor is also called from PandasMultiIndex)
index = safe_cast_to_index(array).copy()
def __init__(
self, array: Any, dim: Hashable, coord_dtype: Any = None, *, fastpath=False
):
if fastpath:
index = array
else:
index = safe_cast_to_index(array)

if index.name is None:
# make a shallow copy: cheap and because the index name may be updated
# here or in other constructors (cannot use pd.Index.rename as this
# constructor is also called from PandasMultiIndex)
index = index.copy()
index.name = dim

self.index = index
Expand All @@ -596,7 +605,7 @@ def _replace(self, index, dim=None, coord_dtype=None):
dim = self.dim
if coord_dtype is None:
coord_dtype = self.coord_dtype
return type(self)(index, dim, coord_dtype)
return type(self)(index, dim, coord_dtype, fastpath=True)

@classmethod
def from_variables(
Expand Down Expand Up @@ -641,6 +650,8 @@ def from_variables(

obj = cls(data, dim, coord_dtype=var.dtype)
assert not isinstance(obj.index, pd.MultiIndex)
# Rename safely
obj.index = obj.index.copy()
obj.index.name = name

return obj
Expand Down Expand Up @@ -684,7 +695,7 @@ def concat(
return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype)

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False
) -> IndexVars:
from xarray.core.variable import IndexVariable

Expand All @@ -701,7 +712,9 @@ def create_variables(
encoding = None

data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype)
var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding)
var = IndexVariable(
self.dim, data, attrs=attrs, encoding=encoding, fastpath=fastpath
)
return {name: var}

def to_pandas_index(self) -> pd.Index:
Expand Down Expand Up @@ -1122,7 +1135,7 @@ def reorder_levels(
return self._replace(index, level_coords_dtype=level_coords_dtype)

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False
) -> IndexVars:
from xarray.core.variable import IndexVariable

Expand Down Expand Up @@ -1772,6 +1785,37 @@ def check_variables():
return not not_equal


def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str):
# This function avoids the call to indexes.group_by_index
# which is really slow when repeatidly iterating through
# an array. However, it fails to return the correct ID for
# multi-index arrays
indexes_fast, coords = indexes._indexes, indexes._variables

new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()}
new_index_variables: dict[Hashable, Variable] = {}
for name, index in indexes_fast.items():
coord = coords[name]
if hasattr(coord, "_indexes"):
index_vars = {n: coords[n] for n in coord._indexes}
else:
index_vars = {name: coord}
index_dims = {d for var in index_vars.values() for d in var.dims}
index_args = {k: v for k, v in args.items() if k in index_dims}

if index_args:
new_index = getattr(index, func)(index_args)
if new_index is not None:
new_indexes.update({k: new_index for k in index_vars})
new_index_vars = new_index.create_variables(index_vars, fastpath=True)
new_index_variables.update(new_index_vars)
new_index_variables.update(new_index_vars)
else:
for k in index_vars:
new_indexes.pop(k, None)
return new_indexes, new_index_variables


def _apply_indexes(
indexes: Indexes[Index],
args: Mapping[Any, Any],
Expand Down Expand Up @@ -1800,7 +1844,10 @@ def isel_indexes(
indexes: Indexes[Index],
indexers: Mapping[Any, Any],
) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
return _apply_indexes(indexes, indexers, "isel")
if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()):
return _apply_indexes(indexes, indexers, "isel")
else:
return _apply_indexes_fast(indexes, indexers, "isel")


def roll_indexes(
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1662,10 +1662,13 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):

__slots__ = ("array", "_dtype")

def __init__(self, array: pd.Index, dtype: DTypeLike = None):
def __init__(self, array: pd.Index, dtype: DTypeLike = None, *, fastpath=False):
from xarray.core.indexes import safe_cast_to_index

self.array = safe_cast_to_index(array)
if fastpath:
self.array = array
else:
self.array = safe_cast_to_index(array)

if dtype is None:
self._dtype = get_valid_numpy_dtype(array)
Expand Down

0 comments on commit 1a659bc

Please sign in to comment.