diff --git a/sgschema/schema.py b/sgschema/schema.py index 081046e..fdf86d0 100644 --- a/sgschema/schema.py +++ b/sgschema/schema.py @@ -162,6 +162,78 @@ def load(self, input_): else: raise ValueError('unknown complex field %s' % key) + def resolve(self, entity_spec, field_spec=None, auto_prefix=True, implicit_aliases=True, strict=False): + + if field_spec is None: # We are resolving an entity. + + m = re.match(r'^([!#$]?)([\w:-]+)$', entity_spec) + if not m: + raise ValueError('%r cannot be an entity' % entity_spec) + operation, entity_spec = m.groups() + + if operation == '!': + return [entity_spec] + if operation == '#': + return list(self.entity_tags.get(entity_spec, ())) + if operation == '$': + try: + return [self.entity_aliases[entity_spec]] + except KeyError: + return [] + + if entity_spec in self.entities: + return [entity_spec] + + if implicit_aliases and entity_spec in self.entity_aliases: + return [self.entity_aliases[entity_spec]] + + if strict: + raise ValueError('%r is not an entity' % entity_spec) + + return [entity_spec] + + # When resolving a field, the entity must exist. + try: + entity = self.entities[entity_spec] + except KeyError: + raise ValueError('%r is not an entity' % entity_spec) + + m = re.match(r'^([!#$]?)([\w:-]+)$', field_spec) + if not m: + raise ValueError('%r cannot be a field' % field_spec) + operation, field_spec = m.groups() + + if operation == '!': + return [field_spec] + if operation == '#': + return list(entity.field_tags.get(field_spec, ())) + if operation == '$': + try: + return [entity.field_aliases[field_spec]] + except KeyError: + return [] + + if field_spec in entity.fields: + return [field_spec] + + if auto_prefix: + prefixed = 'sg_' + field_spec + if prefixed in entity.fields: + return [prefixed] + + if implicit_aliases and field_spec in entity.field_aliases: + return [entity.field_aliases[field_spec]] + + if strict: + raise ValueError('%r is not a field of %s' % (field_spec, entity_spec)) + + return [field_spec] + + + + + + if __name__ == '__main__': diff --git a/tests/test_resolve.py b/tests/test_resolve.py new file mode 100644 index 0000000..737edc3 --- /dev/null +++ b/tests/test_resolve.py @@ -0,0 +1,108 @@ +from . import * + + +class TestResolveEntities(TestCase): + + def setUp(self): + + self.s = s = Schema() + s.load({ + 'entities': { + 'Entity': { + 'aliases': ['A', 'with:Namespace'], + 'tags': ['X'], + } + }, + 'entity_aliases': { + 'B': 'Entity', + }, + 'entity_tags': { + 'Y': ['Entity'], + } + }) + + def test_explicit(self): + self.assertEqual(self.s.resolve('!Entity'), ['Entity']) + self.assertEqual(self.s.resolve('$A'), ['Entity']) + self.assertEqual(self.s.resolve('$B'), ['Entity']) + self.assertEqual(self.s.resolve('#X'), ['Entity']) + self.assertEqual(self.s.resolve('#Y'), ['Entity']) + + def test_namespace(self): + self.assertEqual(self.s.resolve('$with:Namespace'), ['Entity']) + + def test_implicit(self): + self.assertEqual(self.s.resolve('Entity'), ['Entity']) + self.assertEqual(self.s.resolve('A'), ['Entity']) + self.assertEqual(self.s.resolve('B'), ['Entity']) + + def test_missing(self): + self.assertEqual(self.s.resolve('#Missing'), []) + self.assertEqual(self.s.resolve('$Missing'), []) + self.assertEqual(self.s.resolve('!Missing'), ['Missing']) + self.assertEqual(self.s.resolve('Missing'), ['Missing']) + self.assertRaises(ValueError, self.s.resolve, 'Missing', strict=True) + +class TestResolveFields(TestCase): + + def setUp(self): + + self.s = s = Schema() + s.load({ + 'entities': { + 'Entity': { + 'fields': { + 'attr': { + 'aliases': ['a', 'with:namespace'], + 'tags': ['x'], + + }, + 'sg_type': {}, + 'name': {}, + 'sg_name': {}, + }, + 'field_aliases': { + 'b': 'attr', + }, + 'field_tags': { + 'y': ['attr'], + } + } + }, + }) + + def test_explicit(self): + self.assertEqual(self.s.resolve('Entity', '!attr'), ['attr']) + self.assertEqual(self.s.resolve('Entity', '$a'), ['attr']) + self.assertEqual(self.s.resolve('Entity', '$b'), ['attr']) + self.assertEqual(self.s.resolve('Entity', '#x'), ['attr']) + self.assertEqual(self.s.resolve('Entity', '#y'), ['attr']) + + def test_namespace(self): + self.assertEqual(self.s.resolve('Entity', '$with:namespace'), ['attr']) + self.assertEqual(self.s.resolve('Entity', 'with:namespace'), ['attr']) + + def test_implicit(self): + self.assertEqual(self.s.resolve('Entity', 'attr'), ['attr']) + self.assertEqual(self.s.resolve('Entity', 'a'), ['attr']) + self.assertEqual(self.s.resolve('Entity', 'b'), ['attr']) + + def test_prefix(self): + self.assertEqual(self.s.resolve('Entity', 'sg_type'), ['sg_type']) + self.assertEqual(self.s.resolve('Entity', 'type'), ['sg_type']) + self.assertEqual(self.s.resolve('Entity', '!type'), ['type']) + + self.assertEqual(self.s.resolve('Entity', 'sg_name'), ['sg_name']) + self.assertEqual(self.s.resolve('Entity', 'name'), ['name']) # different! + self.assertEqual(self.s.resolve('Entity', '!name'), ['name']) + + def test_missing_entity(self): + self.assertRaises(ValueError, self.s.resolve, 'Missing', 'field_name') + + def test_missing(self): + self.assertEqual(self.s.resolve('Entity', '$missing'), []) + self.assertEqual(self.s.resolve('Entity', '#missing'), []) + self.assertEqual(self.s.resolve('Entity', '!missing'), ['missing']) + self.assertEqual(self.s.resolve('Entity', 'missing'), ['missing']) + self.assertRaises(ValueError, self.s.resolve, 'Entity', 'missing', strict=True) +