diff --git a/sgschema/schema.py b/sgschema/schema.py index d287e9f..fb92712 100644 --- a/sgschema/schema.py +++ b/sgschema/schema.py @@ -436,7 +436,7 @@ def has_field(self, entity_type, field_spec, **kwargs): else: raise - def resolve_structure(self, x, entity_type=None, _seen=None, **kwargs): + def resolve_structure(self, x, entity_type=None, _memo=None, **kwargs): """Traverse a nested structure resolving names in entities. Recurses into ``list``, ``tuple`` and ``dict``, looking for ``dicts`` @@ -450,30 +450,43 @@ def resolve_structure(self, x, entity_type=None, _seen=None, **kwargs): """ # Protect from infinite recursion. - if _seen is None: - _seen = set() + if _memo is None: + _memo = {} id_ = id(x) - if id_ in _seen: - return - _seen.add(id_) + if id_ in _memo: + return _memo[id_] - if isinstance(x, (list, tuple)): - return type(x)(self.resolve_structure(x, None, _seen, **kwargs) for x in x) + # Setup a dumb recursion block. _memo should be updated with the new + # collection before recursing to resolve that collection. + _memo[id_] = None + + _memo[id_] = res = self._resolve_structure(x, entity_type, _memo, kwargs) + return res + + def _resolve_structure(self, x, entity_type, _memo, kwargs): + + if isinstance(x, list): + _memo[id(x)] = new = type(x)() # For recursion. + new.extend(self.resolve_structure(x, None, _memo, **kwargs) for x in x) + return new + + # Tuples dont need to be cached. + if isinstance(x, tuple): + return type(x)(self.resolve_structure(x, None, _memo, **kwargs) for x in x) elif isinstance(x, dict): + _memo[id(x)] = new = {} # For recursion. entity_type = entity_type or x.get('type') if entity_type and entity_type in self.entities: - new_values = {} + # Entities resolve their keys. for field_spec, value in x.iteritems(): - value = self.resolve_structure(value) + value = self.resolve_structure(value, None, _memo) for field in self.resolve_field(entity_type, field_spec, **kwargs): - new_values[field] = value - return new_values + new[field] = value else: - return { - k: self.resolve_structure(v, None, _seen, **kwargs) - for k, v in x.iteritems() - } + for k, v in x.iteritems(): + new[k] = self.resolve_structure(v, None, _memo, **kwargs) + return new else: return x diff --git a/tests/test_structures.py b/tests/test_structures.py index e1d055c..9951d2f 100644 --- a/tests/test_structures.py +++ b/tests/test_structures.py @@ -60,3 +60,31 @@ def test_entity_list(self): } ]) + def test_recursion(self): + + entity = { + 'type': 'Entity', + 'version': 1 + } + entity['self'] = entity + + out = self.s.resolve_structure([entity, entity]) + + # Both items are the same entity. + self.assertEqual(len(out), 2) + entity = out[0] + self.assertIs(entity, out[1]) + + # It contains itself. + self.assertIs(entity, entity['self']) + + entity.pop('self') # just remove it for comparison + + # Simple values. + self.assertEqual(entity, { + 'type': 'Entity', + 'sg_version': 1, + }) + + +