Skip to content

Commit 7ff801d

Browse files
committed
Add DataTree.match_names to match node names
1 parent 25debff commit 7ff801d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

xarray/core/datatree.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,50 @@ def match(self, pattern: str) -> DataTree:
13491349
}
13501350
return DataTree.from_dict(matching_nodes, name=self.root.name)
13511351

1352+
def match_names(self, names: Iterable[str]) -> DataTree:
1353+
"""
1354+
Filter nodes by name.
1355+
1356+
Parameters
1357+
----------
1358+
names: Iterable[str]
1359+
The list of node names to retain.
1360+
1361+
Returns
1362+
-------
1363+
DataTree
1364+
1365+
See Also
1366+
--------
1367+
match
1368+
filter
1369+
pipe
1370+
map_over_subtree
1371+
1372+
Examples
1373+
--------
1374+
>>> dt = DataTree.from_dict(
1375+
... {
1376+
... "/a/A": None,
1377+
... "/a/B": None,
1378+
... "/a/C": None,
1379+
... "/C/D": None,
1380+
... "/E/F": None,
1381+
... }
1382+
... )
1383+
>>> dt.match_names(["A", "C"])
1384+
DataTree('None', parent=None)
1385+
├── DataTree('a')
1386+
│ └── DataTree('A')
1387+
│ └── DataTree('C')
1388+
└── DataTree('C')
1389+
"""
1390+
names = set(names)
1391+
matching_nodes = {
1392+
node.path: node.ds for node in self.subtree if node.name in names
1393+
}
1394+
return DataTree.from_dict(matching_nodes, name=self.root.name)
1395+
13521396
def map_over_subtree(
13531397
self,
13541398
func: Callable,

xarray/tests/test_datatree.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,27 @@ def test_match(self):
10251025
)
10261026
assert_identical(result, expected)
10271027

1028+
def test_match_names(self):
1029+
# TODO is this example going to cause problems with case sensitivity?
1030+
dt: DataTree = DataTree.from_dict(
1031+
{
1032+
"/a/A": None,
1033+
"/a/B": None,
1034+
"/a/C": None,
1035+
"/C/D": None,
1036+
"/E/F": None,
1037+
}
1038+
)
1039+
result = dt.match_names(["A", "C"])
1040+
expected = DataTree.from_dict(
1041+
{
1042+
"/a/A": None,
1043+
"/a/C": None,
1044+
"/C": None,
1045+
}
1046+
)
1047+
assert_identical(result, expected)
1048+
10281049
def test_filter(self):
10291050
simpsons: DataTree = DataTree.from_dict(
10301051
d={

0 commit comments

Comments
 (0)