Skip to content

Commit 4102b9f

Browse files
committed
refactor swap_dims using atomic ops
1 parent 653fa40 commit 4102b9f

File tree

1 file changed

+11
-28
lines changed

1 file changed

+11
-28
lines changed

xarray/core/dataset.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4334,9 +4334,9 @@ def rename_dims(
43344334
f"cannot rename {k!r} because it is not found "
43354335
f"in the dimensions of this dataset {tuple(self.dims)}"
43364336
)
4337-
if v in self.dims or v in self:
4337+
if v in self.dims:
43384338
raise ValueError(
4339-
f"Cannot rename {k} to {v} because {v} already exists. "
4339+
f"Cannot rename dimension {k} to {v} because dimension {v} already exists. "
43404340
"Try using swap_dims instead."
43414341
)
43424342

@@ -4464,33 +4464,16 @@ def swap_dims(
44644464
f"variable along the old dimension {current_name!r}"
44654465
)
44664466

4467-
result_dims = {dims_dict.get(dim, dim) for dim in self.dims}
4467+
result = self.rename_dims(dims_dict)
4468+
result = result.drop_indexes(dims_dict.keys(), errors="ignore")
4469+
for dim in dims_dict.values():
4470+
if dim in result._variables:
4471+
if dim not in result._coord_names:
4472+
result = result.set_coords(dim)
4473+
if dim not in result._indexes:
4474+
result = result.set_xindex(dim)
44684475

4469-
coord_names = self._coord_names.copy()
4470-
coord_names.update({dim for dim in dims_dict.values() if dim in self.variables})
4471-
4472-
variables: dict[Hashable, Variable] = {}
4473-
indexes: dict[Hashable, Index] = {}
4474-
for current_name, current_variable in self.variables.items():
4475-
dims = tuple(dims_dict.get(dim, dim) for dim in current_variable.dims)
4476-
var: Variable
4477-
if current_name in result_dims:
4478-
var = current_variable.to_index_variable()
4479-
var.dims = dims
4480-
if current_name in self._indexes:
4481-
indexes[current_name] = self._indexes[current_name]
4482-
variables[current_name] = var
4483-
else:
4484-
index, index_vars = create_default_index_implicit(var)
4485-
indexes.update({name: index for name in index_vars})
4486-
variables.update(index_vars)
4487-
coord_names.update(index_vars)
4488-
else:
4489-
var = current_variable.to_base_variable()
4490-
var.dims = dims
4491-
variables[current_name] = var
4492-
4493-
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
4476+
return result
44944477

44954478
def expand_dims(
44964479
self,

0 commit comments

Comments
 (0)