Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref resolver #120

Merged
merged 3 commits into from
Sep 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions swagger_py_codegen/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import getsource

from .base import Code, CodeGenerator
from .parser import schema_var_name
from .parser import RefNode


class Schema(Code):
Expand Down Expand Up @@ -89,10 +89,8 @@ def build_data(swagger):
scopes[(endpoint, method)] = list(security.values()).pop()
break

schemas = OrderedDict([(schema_var_name(path), swagger.get(path)) for path in swagger.definitions])

data = dict(
schemas=schemas,
definitions={'definitions':swagger.origin_data.get('definitions', {})},
validators=validators,
filters=filters,
scopes=scopes,
Expand All @@ -109,7 +107,7 @@ def _process(self):
yield Schema(build_data(self.swagger))


def merge_default(schema, value, get_first=True):
def merge_default(schema, value, get_first=True, resolver=None):
# TODO: more types support
type_defaults = {
'integer': 9573,
Expand All @@ -119,17 +117,17 @@ def merge_default(schema, value, get_first=True):
'boolean': False
}

results = normalize(schema, value, type_defaults)
results = normalize(schema, value, type_defaults, resolver=resolver)
if get_first:
return results[0]
return results


def build_default(schema):
return merge_default(schema, None)
def build_default(schema, resolver=None):
return merge_default(schema, None, resolver=resolver)


def normalize(schema, data, required_defaults=None):
def normalize(schema, data, required_defaults=None, resolver=None):
if required_defaults is None:
required_defaults = {}
errors = []
Expand Down Expand Up @@ -217,7 +215,7 @@ def _normalize_dict(schema, data):

def _normalize_list(schema, data):
result = []
if hasattr(data, '__iter__') and not isinstance(data, dict):
if hasattr(data, '__iter__') and not isinstance(data, (dict, RefNode)):
for item in data:
result.append(_normalize(schema.get('items'), item))
elif 'default' in schema:
Expand All @@ -230,6 +228,15 @@ def _normalize_default(schema, data):
else:
return data

def _normalize_ref(schema, data):
if resolver == None:
raise TypeError("resolver must be provided")
ref = schema.get(u"$ref")
scope, resolved = resolver.resolve(ref)
return _normalize(resolved, data)



def _normalize(schema, data):
if schema is True or schema == {}:
return data
Expand All @@ -239,10 +246,13 @@ def _normalize(schema, data):
'object': _normalize_dict,
'array': _normalize_list,
'default': _normalize_default,
'ref': _normalize_ref
}
type_ = schema.get('type', 'object')
if type_ not in funcs:
type_ = 'default'
if schema.get(u'$ref', None):
type_ = 'ref'

return funcs[type_](schema, data)

Expand Down
47 changes: 33 additions & 14 deletions swagger_py_codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,36 @@ def schema_var_name(path):
return ''.join(map(str.capitalize, map(str, path)))


class RefNode(dict):
class RefNode(object):

def __init__(self, data, ref):
self.ref = ref
super(RefNode, self).__init__(data)
self._data = data


def __getitem__(self, key):
return self._data.__getitem__(key)

def __setitem__(self, key, value):
return self._data.__setitem__(key, value)

def __getattr__(self, key):
return self._data.__getattribute__(key)

def __iter__(self):
return self._data.__iter__()

def __repr__(self):
return schema_var_name(self.ref)
return repr({'$ref':self.ref})

def __eq__(self, other):
if isinstance(other, RefNode):
return self._data == other._data and self.ref == other.ref
else:
return object.__eq__(other)

def copy(self):
return RefNode(self._data, self.ref)

class Swagger(object):

Expand All @@ -40,14 +61,10 @@ def _process_ref(self):
"""
resolve all references util no reference exists
"""
while 1:
li = list(self.search(['**', '$ref']))
if not li:
break
for path, ref in li:
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, data)
for path, ref in self.search(['**', '$ref']):
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, RefNode(data, ref))

def _resolve_definitions(self):
"""
Expand Down Expand Up @@ -76,17 +93,19 @@ def get_definition_refs():
while definition_refs:
ready = {
definition for definition, refs
in six.iteritems(definition_refs) if not refs
in six.iteritems(definition_refs)
}
if not ready:
msg = '$ref circular references found!\n'
raise ValueError(msg)
continue
#msg = '$ref circular references found!\n'
#raise ValueError(msg)
for definition in ready:
del definition_refs[definition]
for refs in six.itervalues(definition_refs):
refs.difference_update(ready)

self._definitions += ready
self._definitions.sort(key=lambda x :x[1])

def search(self, path):
for p, d in dpath.util.search(
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/falcon/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from werkzeug.datastructures import MultiDict, Headers
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, base_path, normalize)
validators, filters, scopes, resolver, security, base_path, normalize)


if six.PY3:
Expand Down Expand Up @@ -44,7 +44,7 @@ class JSONEncoder(json.JSONEncoder):
class FalconValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -87,7 +87,7 @@ class FalconValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = {e.path[0]: e.message for e in self.validator.iter_errors(value)}
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(req, resp, resource, params):
Expand Down Expand Up @@ -154,15 +154,15 @@ def response_filter(req, resp, resource):
'Not defined',
description='`%d` is not a defined status code.' % status)

_resp, errors = normalize(schemas['schema'], req.context['result'])
_resp, errors = normalize(schemas['schema'], req.context['result'], resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise falcon.HTTPInternalServerError(title='Expectation Failed',
description=errors)

if 'result' not in req.context:
return
resp.body = json.dumps(_resp)
resp.body = json.dumps(_resp)
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/flask/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from flask_restful.utils import unpack
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, merge_default, normalize)
validators, filters, scopes, resolver, security, merge_default, normalize)


class JSONEncoder(json.JSONEncoder):
Expand All @@ -29,7 +29,7 @@ class JSONEncoder(json.JSONEncoder):
class FlaskValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -72,7 +72,7 @@ class FlaskValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -136,10 +136,10 @@ def response_filter(view):
# return resp, status, headers
abort(500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
abort(500, message='Expectation Failed', errors=errors)
Expand Down
8 changes: 5 additions & 3 deletions swagger_py_codegen/templates/jsonschema/schemas.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-

import six
from jsonschema import RefResolver
from swagger_py_codegen.parser import RefNode

# TODO: datetime support


{% include '_do_not_change.tpl' %}

base_path = '{{base_path}}'

{% for name, value in schemas.items() %}
{{ name }} = {{ value }}
{%- endfor %}
definitions = {{ definitions }}

validators = {
{%- for name, value in validators.items() %}
Expand All @@ -30,6 +31,7 @@ scopes = {
{%- endfor %}
}

resolver = RefResolver.from_schema(definitions)

class Security(object):

Expand Down
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/sanic/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from sanic.request import RequestParameters
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, base_path, normalize, current)
validators, filters, scopes, security, resolver, base_path, normalize, current)


def unpack(value):
Expand Down Expand Up @@ -63,7 +63,7 @@ class JSONEncoder(json.JSONEncoder):
class SanicValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -106,7 +106,7 @@ class SanicValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -175,10 +175,10 @@ def response_filter(view):
# return resp, status, headers
raise ServerError('`%d` is not a defined status code.' % status, 500)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise ServerError('Expectation Failed', 500)
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/tornado/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import six
from functools import wraps
from jsonschema import Draft4Validator

from .schemas import validators, scopes, normalize, filters
from .schemas import validators, scopes, resolver, normalize, filters


class ValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -66,7 +66,7 @@ class ValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors

def request_validate(obj):
def _request_validate(view):
Expand Down Expand Up @@ -134,10 +134,10 @@ def response_filter(obj):
raise tornado.web.HTTPError(
500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise tornado.web.HTTPError(
Expand Down Expand Up @@ -167,4 +167,4 @@ def unpack(value):
except ValueError:
pass

return value, 200, {}
return value, 200, {}
Loading