diff --git a/TODO.md b/TODO.md index d3b3b4f..7ed3139 100644 --- a/TODO.md +++ b/TODO.md @@ -1,8 +1,7 @@ -- test the loading and resolution of aliases and tags +- Schema.for_url(base_url) -> static class + -- caches of the raw schema; both public ones and the private one -- cache of the reduced schema - role assignments for columns, so that our tools that access roles (via a special syntax) instead of actual column names diff --git a/sgschema/__init__.py b/sgschema/__init__.py index e69de29..3ce1da1 100644 --- a/sgschema/__init__.py +++ b/sgschema/__init__.py @@ -0,0 +1 @@ +from .schema import Schema diff --git a/sgschema/schema.py b/sgschema/schema.py index 692ba15..081046e 100644 --- a/sgschema/schema.py +++ b/sgschema/schema.py @@ -2,13 +2,14 @@ import json import os import re +import copy import requests import yaml from .entity import Entity from .field import Field -from .utils import cached_property +from .utils import cached_property, merge_update class Schema(object): @@ -82,17 +83,6 @@ def _reduce_raw(self): field = entity._get_or_make_field(field_name) field._reduce_raw(self, raw_field) - - def _dump_prep(self, value): - if isinstance(value, unicode): - return value.encode("utf8") - elif isinstance(value, dict): - return {self._dump_prep(k): self._dump_prep(v) for k, v in value.iteritems()} - elif isinstance(value, (tuple, list)): - return [self._dump_prep(x) for x in value] - else: - return value - def dump(self, path, raw=False): if raw: with open(path, 'w') as fh: @@ -130,26 +120,36 @@ def load_raw(self, path): self._reduce_raw() - def load(self, path): + def load(self, input_): - encoded = open(path).read() - raw = json.loads(encoded) - #raw = ast.literal_eval(encoded) + if isinstance(input_, basestring): + encoded = open(input_).read() + raw_schema = json.loads(encoded) + elif isinstance(input_, dict): + raw_schema = copy.deepcopy(input_) + else: + raise TypeError('require str path or dict schema') # If it is a dictionary of entity types, pretend it is in an "entities" key. - title_cased = sum(int(k[:1].isupper()) for k in raw) + title_cased = sum(int(k[:1].isupper()) for k in raw_schema) if title_cased: - if len(raw) != title_cased: + if len(raw_schema) != title_cased: raise ValueError('mix of direct and indirect entity specifications') - raw = {'entities': raw} + raw_schema = {'entities': raw_schema} # Load the two direct fields. - for type_name, value in raw.pop('entities', {}).iteritems(): + for type_name, value in raw_schema.pop('entities', {}).iteritems(): self._get_or_make_entity(type_name)._load(value) - self._entity_aliases.update(raw.pop('entity_aliases', {})) + + merge_update(self._entity_aliases, raw_schema.pop('entity_aliases', {})) + merge_update(self._entity_tags , raw_schema.pop('entity_tags', {})) + + if raw_schema: + raise ValueError('unknown keys: %s' % ', '.join(sorted(raw_schema))) + return # Load any indirect fields. - for key, values in raw.iteritems(): + for key, values in raw_schema.iteritems(): if key.startswith('entity_'): entity_attr = key[7:] for type_name, value in values.iteritems(): @@ -194,4 +194,3 @@ def load(self, path): print schema.entity_aliases['Publish'] print schema.entities['PublishEvent'].field_aliases['type'] print schema.entities['PublishEvent'].field_tags['identifier_column'] - \ No newline at end of file diff --git a/sgschema/utils.py b/sgschema/utils.py index f3b942d..4d7ef90 100644 --- a/sgschema/utils.py +++ b/sgschema/utils.py @@ -16,3 +16,21 @@ def __get__(self, obj, type=None): except KeyError: obj.__dict__[self.__name__] = value = self.func(obj) return value + + +def merge_update(dst, src): + + for k, v in src.iteritems(): + + if k not in dst: + dst[k] = v + continue + + e = dst[k] + if isinstance(e, dict): + merge_update(e, v) + elif isinstance(e, list): + e.extend(v) + else: + dst[k] = v + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d1005b7 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +from unittest import TestCase + +from sgschema import Schema diff --git a/tests/test_load.py b/tests/test_load.py new file mode 100644 index 0000000..205fed4 --- /dev/null +++ b/tests/test_load.py @@ -0,0 +1,76 @@ +from . import * + +class TestLoading(TestCase): + + def test_load_entity_tags(self): + + s = Schema() + s.load({ + 'entities': {'Entity': { + 'tags': ['a'], + }}, + 'entity_tags': {'b': ['Entity']}, + }) + + self.assertIn('a', s.entities['Entity'].tags) + self.assertIn('b', s.entities['Entity'].tags) + self.assertIn('Entity', s.entity_tags['a']) + self.assertIn('Entity', s.entity_tags['b']) + + def test_load_field_tags(self): + + s = Schema() + s.load({ + 'Entity': { + 'fields': { + 'sg_type': { + 'tags': ['a'], + }, + }, + 'field_tags': { + 'b': ['sg_type'], + }, + }, + }) + + self.assertIn('a', s.entities['Entity'].fields['sg_type'].tags) + self.assertIn('b', s.entities['Entity'].fields['sg_type'].tags) + self.assertIn('sg_type', s.entities['Entity'].field_tags['a']) + self.assertIn('sg_type', s.entities['Entity'].field_tags['b']) + + def test_load_entity_aliases(self): + + s = Schema() + s.load({ + 'entities': {'Entity': { + 'aliases': ['A'], + }}, + 'entity_aliases': {'B': 'Entity'}, + }) + + self.assertIn('A', s.entities['Entity'].aliases) + self.assertIn('B', s.entities['Entity'].aliases) + self.assertEqual('Entity', s.entity_aliases['A']) + self.assertEqual('Entity', s.entity_aliases['B']) + + def test_load_field_aliases(self): + + s = Schema() + s.load({ + 'Entity': { + 'fields': { + 'sg_type': { + 'aliases': ['a'], + }, + }, + 'field_aliases': { + 'b': 'sg_type', + }, + }, + }) + + self.assertIn('a', s.entities['Entity'].fields['sg_type'].aliases) + self.assertIn('b', s.entities['Entity'].fields['sg_type'].aliases) + self.assertEqual('sg_type', s.entities['Entity'].field_aliases['a']) + self.assertEqual('sg_type', s.entities['Entity'].field_aliases['b']) +