From 86ca70d4b2fa114b5134f6e81f3539eaf5e20957 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20M=C3=BCller?= Date: Sat, 24 Feb 2024 18:53:19 +0100 Subject: [PATCH] Make deep_map aware of dict_values (used in `MaterializationWrapper.__call__()` to fill `input_tables`) --- src/pydiverse/pipedag/util/deep_map.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/pydiverse/pipedag/util/deep_map.py b/src/pydiverse/pipedag/util/deep_map.py index 5c4aef9d..cb7efecd 100644 --- a/src/pydiverse/pipedag/util/deep_map.py +++ b/src/pydiverse/pipedag/util/deep_map.py @@ -8,6 +8,7 @@ from typing import Callable _nil = [] +_dict_values_class = type({}.values()) def deep_map(x, fn: Callable, memo=None): @@ -27,6 +28,8 @@ def deep_map(x, fn: Callable, memo=None): y = _deep_map_tuple(x, fn, memo) elif cls == dict: y = _deep_map_dict(x, fn, memo) + elif cls == _dict_values_class: + y = _deep_map_dict_values(x, fn, memo) else: y = fn(x) @@ -46,6 +49,23 @@ def _deep_map_list(x, fn, memo): return fn(y) +def _deep_map_dict_values(x, fn, memo): + y = [deep_map(a, fn, memo) for a in x] + # We're not going to put the dict_values in the memo, but it's still important we + # check for it, in case the tuple contains recursive mutable structures. + try: + return memo[id(x)] + except KeyError: + pass + for k, j in zip(x, y): + if k is not j: + y = {i: v for i, v in zip(range(len(y)), y)}.values() + break + else: + y = x + return fn(y) + + def _deep_map_tuple(x, fn, memo): y = [deep_map(a, fn, memo) for a in x] # We're not going to put the tuple in the memo, but it's still important we