diff --git a/flask_apispec/wrapper.py b/flask_apispec/wrapper.py index 067261f..f04e6d0 100644 --- a/flask_apispec/wrapper.py +++ b/flask_apispec/wrapper.py @@ -1,15 +1,18 @@ # -*- coding: utf-8 -*- +try: + from collections.abc import Mapping +except ImportError: # Python 2 + from collections import Mapping -from six.moves import http_client as http import flask - +import marshmallow as ma import werkzeug +from six.moves import http_client as http from webargs import flaskparser from flask_apispec import utils -import marshmallow as ma MARSHMALLOW_VERSION_INFO = tuple( [int(part) for part in ma.__version__.split('.') if part.isdigit()] @@ -43,8 +46,11 @@ def call_view(self, *args, **kwargs): parsed = parser.parse(schema, locations=option['kwargs']['locations']) if getattr(schema, 'many', False): args += tuple(parsed) - else: + elif isinstance(parsed, Mapping): kwargs.update(parsed) + else: + args += (parsed, ) + return self.func(*args, **kwargs) def marshal_result(self, unpacked, status_code): diff --git a/tests/test_views.py b/tests/test_views.py index 6ae9aa6..a192cf7 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,7 +3,7 @@ import json from flask import make_response -from marshmallow import fields, Schema +from marshmallow import fields, Schema, post_load from flask_apispec.utils import Ref from flask_apispec.views import MethodResource @@ -30,6 +30,31 @@ def view(**kwargs): res = client.get('/', {'name': 'freddie'}) assert res.json == {'name': 'freddie'} + def test_use_kwargs_schema_with_post_load(self, app, client): + class User: + def __init__(self, name): + self.name = name + + def update(self, name): + self.name = name + + class ArgSchema(Schema): + name = fields.Str() + + @post_load + def make_object(self, data): + return User(**data) + + @app.route('/', methods=('POST', )) + @use_kwargs(ArgSchema()) + def view(user): + assert isinstance(user, User) + return {'name': user.name} + + data = {'name': 'freddie'} + res = client.post('/', data) + assert res.json == data + def test_use_kwargs_schema_many(self, app, client): class ArgSchema(Schema): name = fields.Str()