Schema.load_entrypoints
mikeboers committed Oct 21, 2015
1 parent c44ef71 commit 235add6
Showing 2 changed files with 95 additions and 32 deletions.
2 changes: 1 addition & 1 deletion sgschema/entity.py
@@ -24,7 +24,7 @@ def _reduce_raw(self, schema, raw_entity):

    def __getstate__(self):
        return dict((k, v) for k, v in (
-           ('fields', dict((n, f.__getstate__()) for n, f in self.fields.iteritems())),
+           ('fields', self.fields),
            ('field_aliases', self.field_aliases),
            ('field_tags', self.field_tags),
        ) if v)
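This simplification leans on ``Schema.dump`` (below), which serializes with ``json.dumps(..., default=lambda x: x.__getstate__())``: the encoder reduces nested objects itself, so ``__getstate__`` no longer has to recurse by hand. A minimal sketch of the mechanism, with a stand-in ``Field`` class (not the real one):

    import json

    class Field(object):
        def __init__(self, data_type):
            self.data_type = data_type
        def __getstate__(self):
            return {'data_type': self.data_type}

    fields = {'sg_status_list': Field('status_list')}

    # The ``default`` hook fires for any object json can't encode natively,
    # so the nested Field is reduced on the fly.
    print json.dumps({'fields': fields}, default=lambda x: x.__getstate__())
    # -> {"fields": {"sg_status_list": {"data_type": "status_list"}}}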
125 changes: 94 additions & 31 deletions sgschema/schema.py
@@ -32,26 +32,21 @@ def from_cache(cls, base_url):
        if not isinstance(base_url, basestring):
            base_url = base_url.base_url

-       # Try to return a single instance.
        try:
-           return cls._cache_instances[base_url]
+           # Try to return a single instance.
+           schema = cls._cache_instances[base_url]

        except KeyError:
-           pass

-       import pkg_resources
-       for ep in pkg_resources.iter_entry_points('sgschema_cache'):
-           func = ep.load()
-           cache = func(base_url)
-           if cache is not None:
-               break
-       else:
-           raise ValueError('cannot find cache for %s' % base_url)
+           schema = cls()
+           schema.load_entry_points(base_url, group='sgschema_cache')

+           # Cache it so we only load it once.
+           cls._cache_instances[base_url] = schema

-       schema = cls()
-       schema.load(cache)
+       if not schema.entities:
+           raise ValueError('no data cached')

-       # Cache it so we only load it once.
-       cls._cache_instances[base_url] = schema
        return schema

    def __init__(self):
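``from_cache`` now delegates to ``load_entry_points`` with the ``sgschema_cache`` group, so a site advertises its cache through setuptools. A hypothetical registration (package name, entry point name, and paths are all illustrative, not part of the commit):

    # setup.py of a hypothetical "mystudio_schema" package:
    from setuptools import setup
    setup(
        name='mystudio_schema',
        entry_points={'sgschema_cache': [
            '100_site = mystudio_schema:load_cache',
        ]},
    )

    # mystudio_schema/__init__.py:
    def load_cache(schema, base_url):
        # Return anything Schema.load() accepts (a path, a dict, or a
        # list of them), or None to decline this server.
        if base_url == 'https://mystudio.shotgunstudio.com':
            return '/var/cache/sgschema.json'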
@@ -132,7 +127,7 @@ def dump_raw(self, path):

    def __getstate__(self):
        return dict((k, v) for k, v in (
-           ('entities', dict((n, e.__getstate__()) for n, e in self.entities.iteritems())),
+           ('entities', self.entities),
            ('entity_aliases', self.entity_aliases),
            ('entity_tags', self.entity_tags),
        ) if v)
@@ -146,13 +141,6 @@ def dump(self, path):
        with open(path, 'w') as fh:
            fh.write(json.dumps(self, indent=4, sort_keys=True, default=lambda x: x.__getstate__()))

-   def load_directory(self, dir_path):
-       """Load all ``.json`` files in the given directory."""
-       for file_name in os.listdir(dir_path):
-           if file_name.startswith('.') or not file_name.endswith('.json'):
-               continue
-           self.load(os.path.join(dir_path, file_name))

    def load_raw(self, path):
        """Load a JSON file containing a raw schema."""
        raw = json.loads(open(path).read())
@@ -171,26 +159,101 @@ def load_raw(self, path):

        self._reduce_raw()

-   def load(self, input_):
+   def load_directory(self, dir_path):
+       """Load all ``.json`` and ``.yaml`` files in the given directory."""
+       for file_name in os.listdir(dir_path):
+           if file_name.startswith('.'):
+               continue
+           if os.path.splitext(file_name)[1] not in ('.json', '.yaml'):
+               continue
+           self.load(os.path.join(dir_path, file_name))
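Usage sketch (directory and file names illustrative). Since ``os.listdir`` returns entries in arbitrary order, files within one directory should not rely on being applied in sequence:

    schema = Schema()
    schema.load_directory('/etc/sgschema')  # picks up e.g. raw.json, site.yaml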

+   def load_entry_points(self, base_url, group='sgschema_loaders', verbose=False):
+       """Call pkg_resources' entry points to get schema data.
+
+       This calls all entry points (sorted by name) in the given group,
+       passing the ``Schema`` object and the given ``base_url``. Any
+       returned values will be passed to :meth:`load`.
+
+       An entry point may raise ``StopIteration`` to force it to be the last.
+
+       By convention, names starting with ``000_`` should provide basic (e.g.
+       raw, or reduced-from-raw) data, ``100_`` package defaults, ``500_``
+       site data, and ``zzz_`` user overrides.
+
+       :param str base_url: To be passed to all entry points.
+       :param str group: The entry point group; defaults to ``sgschema_loaders``.
+           Another common group is ``sgschema_cache``, which is used by
+           :meth:`from_cache`.
+       :param bool verbose: Print entry points as we go.
+       """
+
+       import pkg_resources
+
+       entry_points = list(pkg_resources.iter_entry_points(group))
+       entry_points.sort(key=lambda ep: ep.name)
+
+       for ep in entry_points:
+           if verbose:
+               print 'loading from', ep.name
+           func = ep.load()
+           try:
+               data = func(self, base_url)
+           except StopIteration:
+               return
+           else:
+               if data:
+                   self.load(data)
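A sketch of the naming convention in practice; entry point names sort lexically, so a site might register (all names hypothetical):

    # entry_points={'sgschema_loaders': [
    #     '000_raw = mystudio_schema:load_raw_dump',
    #     '500_site = mystudio_schema:load_site_overrides',
    #     'zzz_user = mystudio_schema:load_user_overrides',
    # ]}

    def load_user_overrides(schema, base_url):
        import os
        path = os.path.expanduser('~/.sgschema.yaml')
        if os.path.exists(path):
            return path  # handed back to Schema.load()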

+   def load(self, input_, recurse=True):
        """Load a JSON file or ``dict`` containing schema structures.

-       If passed a string, we treat it as a path to a JSON file.
-       If passed a dict, it is handled directly.
+       If passed a string, we treat it as a path to a JSON or YAML file
+       (where JSON is preferred due to speed).
+
+       If passed a dict, it is handed off to :meth:`update`.
+
+       If passed another iterable, we recurse, passing every element back
+       into this method. This is useful for entry points yielding paths.
        """

        if isinstance(input_, basestring):
            encoded = open(input_).read()
-           raw_schema = json.loads(encoded)
+           if input_.endswith('.json'):
+               data = json.loads(encoded)
+           elif input_.endswith('.yaml'):
+               import yaml  # Delay as long as possible.
+               data = yaml.load(encoded)
+           else:
+               raise ValueError('unknown filetype %s' % os.path.splitext(input_)[1])
+           self.update(data)

        elif isinstance(input_, dict):
-           raw_schema = copy.deepcopy(input_)
+           self.update(input_)

+       elif recurse:
+           for x in input_:
+               self.load(x, recurse=False)
+
        else:
-           raise TypeError('require str path or dict schema')
+           raise TypeError('load needs str, dict, or list')

-       self.__setstate__(raw_schema)
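Taken together, ``load`` now accepts a path, a dict, or an iterable of either, so all of the following work (values illustrative):

    schema = Schema()
    schema.load('/etc/sgschema/site.json')                         # a path
    schema.load({'entity_aliases': {'shot': 'Shot'}})              # a dict
    schema.load(['/etc/sgschema/a.json', '/etc/sgschema/b.yaml'])  # paths

Note the recursion is only one level deep: a list of lists raises TypeError on the inner list.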
+   def update(self, *args, **kwargs):
+       for arg in args:
+           if not isinstance(arg, dict):
+               raise TypeError('Schema.update needs dict')
+           self.__setstate__(arg)
+       if kwargs:
+           self.__setstate__(kwargs)
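``update`` loosely mirrors ``dict.update``: each positional dict, and then any keyword arguments, is applied in order via ``__setstate__``. E.g. (keys illustrative):

    schema.update({'entity_aliases': {'shot': 'Shot'}})
    schema.update(entity_aliases={'shot': 'Shot'})  # equivalent via kwargs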

    def __setstate__(self, raw_schema):

+       # We mutate this object, and aren't sure how any pickler will feel
+       # about it.
+       raw_schema = copy.deepcopy(raw_schema)
+
        # If it is a dictionary of entity types, pretend it is in an "entities" key.
        title_cased = sum(int(k[:1].isupper()) for k in raw_schema)
        if title_cased:
