diff --git a/sgschema/schema.py b/sgschema/schema.py index b73c405..0bf4c7f 100644 --- a/sgschema/schema.py +++ b/sgschema/schema.py @@ -255,7 +255,7 @@ def _resolve_field(self, entity_spec, field_spec, auto_prefix=True, implicit_ali return [field_spec] - def resolve_field(self, entity_type, field_spec=None, auto_prefix=True, implicit_aliases=True, strict=False): + def resolve_field(self, entity_type, field_spec, auto_prefix=True, implicit_aliases=True, strict=False): spec_parts = field_spec.split('.') @@ -285,14 +285,27 @@ def resolve_field(self, entity_type, field_spec=None, auto_prefix=True, implicit resolved_fields.append(field) return resolved_fields - def resolve_structure(self, x, **kwargs): + def resolve_one_field(self, entity_type, field_spec, **kwargs): + res = self.resolve_field(entity_type, field_spec, **kwargs) + if len(res) == 1: + return res[0] + else: + raise ValueError('%r returned %s %s fields' % (field_spec, len(res), entity_type)) + + def resolve_fields(self, entity_type, field_specs, **kwargs): + res = [] + for field_spec in field_specs: + res.extend(self.resolve_field(entity_type, field_spec, **kwargs)) + return res + + def resolve_structure(self, x, entity_type=None, **kwargs): if isinstance(x, (list, tuple)): return type(x)(self.resolve_structure(x, **kwargs) for x in x) elif isinstance(x, dict): - if 'type' in x and x['type'] in self.entities: - entity_type = x['type'] + entity_type = entity_type or x.get('type') + if entity_type and entity_type in self.entities: new_values = {} for field_spec, value in x.iteritems(): value = self.resolve_structure(value) diff --git a/tests/test_fields.py b/tests/test_fields.py index 83696e0..a07908c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -25,6 +25,7 @@ def setUp(self): }, 'field_tags': { 'y': ['attr'], + 'multi': ['multi_a', 'multi_b'], } } }, @@ -69,3 +70,15 @@ def test_missing(self): self.assertEqual(self.s.resolve_field('Entity', 'missing'), ['missing']) self.assertRaises(ValueError, self.s.resolve_field, 'Entity', 'missing', strict=True) + def test_one(self): + self.assertEqual(self.s.resolve_one_field('Entity', 'sg_type'), 'sg_type') + self.assertEqual(self.s.resolve_one_field('Entity', '$a'), 'attr') + self.assertEqual(self.s.resolve_one_field('Entity', '#x'), 'attr') + self.assertRaises(ValueError, self.s.resolve_one_field, 'Entity', '#missing') + self.assertRaises(ValueError, self.s.resolve_one_field, 'Entity', '#multi') + + def test_many(self): + self.assertEqual(self.s.resolve_fields('Entity', ['sg_type', 'version', '#x', '#multi']), [ + 'sg_type', 'sg_version', 'attr', 'multi_a', 'multi_b', + ]) +