diff --git a/sgschema/entity.py b/sgschema/entity.py
index 687d0d4..2203bbb 100644
--- a/sgschema/entity.py
+++ b/sgschema/entity.py
@@ -24,7 +24,7 @@ def _reduce_raw(self, schema, raw_entity):
 
     def __getstate__(self):
         return dict((k, v) for k, v in (
-            ('fields', dict((n, f.__getstate__()) for n, f in self.fields.iteritems())),
+            ('fields', self.fields),
             ('field_aliases', self.field_aliases),
             ('field_tags', self.field_tags),
         ) if v)
diff --git a/sgschema/schema.py b/sgschema/schema.py
index 56676af..ed95a8a 100644
--- a/sgschema/schema.py
+++ b/sgschema/schema.py
@@ -32,26 +32,21 @@ def from_cache(cls, base_url):
         if not isinstance(base_url, basestring):
             base_url = base_url.base_url
 
-        # Try to return a single instance.
         try:
-            return cls._cache_instances[base_url]
+            # Try to return a single instance.
+            schema = cls._cache_instances[base_url]
+
         except KeyError:
-            pass
 
-        import pkg_resources
-        for ep in pkg_resources.iter_entry_points('sgschema_cache'):
-            func = ep.load()
-            cache = func(base_url)
-            if cache is not None:
-                break
-        else:
-            raise ValueError('cannot find cache for %s' % base_url)
+            schema = cls()
+            schema.load_entry_points(base_url, group='sgschema_cache')
+
+            # Cache it so we only load it once.
+            cls._cache_instances[base_url] = schema
 
-        schema = cls()
-        schema.load(cache)
+        if not schema.entities:
+            raise ValueError('no data cached')
 
-        # Cache it so we only load it once.
-        cls._cache_instances[base_url] = schema
         return schema
 
     def __init__(self):
@@ -132,7 +127,7 @@ def dump_raw(self, path):
 
     def __getstate__(self):
         return dict((k, v) for k, v in (
-            ('entities', dict((n, e.__getstate__()) for n, e in self.entities.iteritems())),
+            ('entities', self.entities),
             ('entity_aliases', self.entity_aliases),
             ('entity_tags', self.entity_tags),
         ) if v)
@@ -146,13 +141,6 @@ def dump(self, path):
         with open(path, 'w') as fh:
             fh.write(json.dumps(self, indent=4, sort_keys=True, default=lambda x: x.__getstate__()))
 
-    def load_directory(self, dir_path):
-        """Load all ``.json`` files in the given directory."""
-        for file_name in os.listdir(dir_path):
-            if file_name.startswith('.') or not file_name.endswith('.json'):
-                continue
-            self.load(os.path.join(dir_path, file_name))
-
     def load_raw(self, path):
         """Load a JSON file containing a raw schema."""
         raw = json.loads(open(path).read())
@@ -171,26 +159,101 @@ def load_raw(self, path):
 
         self._reduce_raw()
 
-    def load(self, input_):
+    def load_directory(self, dir_path):
+        """Load all ``.json`` and ``.yaml`` files in the given directory."""
+        for file_name in os.listdir(dir_path):
+            if file_name.startswith('.'):
+                continue
+            if os.path.splitext(file_name)[1] not in ('.json', '.yaml'):
+                continue
+            self.load(os.path.join(dir_path, file_name))
+
+    def load_entry_points(self, base_url, group='sgschema_loaders', verbose=False):
+        """Call ``pkg_resources`` entry points to get schema data.
+
+        This calls all entry points (sorted by name) in the given group,
+        passing the ``Schema`` object and the given ``base_url``. Any
+        returned values will be passed to :meth:`load`.
+
+        An entry point may raise ``StopIteration`` to stop further loading.
+
+        By convention, names starting with ``000_`` should provide basic (e.g.
+        raw, or reduced from raw) data, ``100_`` package defaults, ``500_``
+        site data, and ``zzz_`` user overrides.
+
+        :param str base_url: To be passed to all entry points.
+        :param str group: The entry point group; defaults to ``sgschema_loaders``.
+            Another common group is ``sgschema_cache``, which is used by
+            :meth:`from_cache`.
+        :param bool verbose: Print entry points as we go.
+
+        """
+
+        import pkg_resources
+
+        entry_points = list(pkg_resources.iter_entry_points(group))
+        entry_points.sort(key=lambda ep: ep.name)
+
+        for ep in entry_points:
+            if verbose:
+                print 'loading from', ep.name
+            func = ep.load()
+            try:
+                data = func(self, base_url)
+            except StopIteration:
+                return
+            else:
+                if data:
+                    self.load(data)
+
+    def load(self, input_, recurse=True):
         """Load a JSON file or ``dict`` containing schema structures.
 
-        If passed a string, we treat is as a path to a JSON file.
-        If passed a dict, it is handled directly.
+        If passed a string, we treat it as a path to a JSON or YAML file
+        (JSON is preferred, as it is faster to parse).
+
+        If passed a dict, it is handed off to :meth:`update`.
+
+        If passed any other iterable, we recurse, passing each element back
+        into this method. This is useful for entry points that yield paths.
 
         """
 
-        if isinstance(input_, basestring):
+        if isinstance(input_, basestring):
             encoded = open(input_).read()
-            raw_schema = json.loads(encoded)
+            if input_.endswith('.json'):
+                data = json.loads(encoded)
+            elif input_.endswith('.yaml'):
+                import yaml  # Delay as long as possible.
+                data = yaml.load(encoded)
+            else:
+                raise ValueError('unknown filetype %s' % os.path.splitext(input_)[1])
+            self.update(data)
+
         elif isinstance(input_, dict):
-            raw_schema = copy.deepcopy(input_)
+            self.update(input_)
+
+        elif recurse:
+            for x in input_:
+                self.load(x, recurse=False)
+
         else:
-            raise TypeError('require str path or dict schema')
+            raise TypeError('load needs str, dict, or list')
 
-        self.__setstate__(raw_schema)
+    def update(self, *args, **kwargs):
+        for arg in args:
+            if not isinstance(arg, dict):
+                raise TypeError('Schema.update needs dict')
+            self.__setstate__(arg)
+        if kwargs:
+            self.__setstate__(kwargs)
 
     def __setstate__(self, raw_schema):
 
+        # We mutate this data, and aren't sure how any pickler would feel
+        # about that, so work on a deep copy.
+        raw_schema = copy.deepcopy(raw_schema)
+
         # If it is a dictionary of entity types, pretend it is in an "entities" key.
         title_cased = sum(int(k[:1].isupper()) for k in raw_schema)
         if title_cased:
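
For reviewers, a minimal sketch of the loader protocol this diff introduces (not
part of the diff itself; the package name, entry point name, and cache path are
hypothetical, but the ``sgschema_cache`` group and the ``(schema, base_url)``
signature are what the new ``load_entry_points`` calls):

    # setup.py of a hypothetical site package:
    #
    #   entry_points={
    #       'sgschema_cache': [
    #           '500_site = mystudio_sgschema:load_cache',
    #       ],
    #   },

    # mystudio_sgschema/__init__.py
    import os

    def load_cache(schema, base_url):
        # load_entry_points() calls us with (schema, base_url); whatever we
        # return is handed to Schema.load(), so a path, a dict, or a list of
        # paths all work. Returning None (falsy) skips this loader.
        path = '/var/cache/sgschema/%s.json' % base_url.split('//')[-1]
        return path if os.path.exists(path) else None

    # Consumers then hit the memoized cache:
    #
    #   from sgschema import Schema
    #   schema = Schema.from_cache('https://example.shotgunstudio.com')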