Skip to content

Commit

Permalink
refactor swap_dims using atomic ops
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy committed Apr 5, 2024
1 parent 653fa40 commit 4102b9f
Showing 1 changed file with 11 additions and 28 deletions.
39 changes: 11 additions & 28 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4334,9 +4334,9 @@ def rename_dims(
f"cannot rename {k!r} because it is not found "
f"in the dimensions of this dataset {tuple(self.dims)}"
)
if v in self.dims or v in self:
if v in self.dims:
raise ValueError(
f"Cannot rename {k} to {v} because {v} already exists. "
f"Cannot rename dimension {k} to {v} because dimension {v} already exists. "
"Try using swap_dims instead."
)

Expand Down Expand Up @@ -4464,33 +4464,16 @@ def swap_dims(
f"variable along the old dimension {current_name!r}"
)

result_dims = {dims_dict.get(dim, dim) for dim in self.dims}
result = self.rename_dims(dims_dict)
result = result.drop_indexes(dims_dict.keys(), errors="ignore")
for dim in dims_dict.values():
if dim in result._variables:
if dim not in result._coord_names:
result = result.set_coords(dim)
if dim not in result._indexes:
result = result.set_xindex(dim)

coord_names = self._coord_names.copy()
coord_names.update({dim for dim in dims_dict.values() if dim in self.variables})

variables: dict[Hashable, Variable] = {}
indexes: dict[Hashable, Index] = {}
for current_name, current_variable in self.variables.items():
dims = tuple(dims_dict.get(dim, dim) for dim in current_variable.dims)
var: Variable
if current_name in result_dims:
var = current_variable.to_index_variable()
var.dims = dims
if current_name in self._indexes:
indexes[current_name] = self._indexes[current_name]
variables[current_name] = var
else:
index, index_vars = create_default_index_implicit(var)
indexes.update({name: index for name in index_vars})
variables.update(index_vars)
coord_names.update(index_vars)
else:
var = current_variable.to_base_variable()
var.dims = dims
variables[current_name] = var

return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
return result

def expand_dims(
self,
Expand Down

0 comments on commit 4102b9f

Please sign in to comment.