diff --git a/chalice_spec/docs.py b/chalice_spec/docs.py index 9055adb..7b455ca 100644 --- a/chalice_spec/docs.py +++ b/chalice_spec/docs.py @@ -16,7 +16,13 @@ class Response: and an optional description. """ - def __init__(self, model: type, code: int = 200, description: str = "Success", content_type: str = DEFAULT_CONTENT_TYPE): + def __init__( + self, + model: type, + code: int = 200, + description: str = "Success", + content_type: str = DEFAULT_CONTENT_TYPE + ): self.model = model self.code = code self.description = description @@ -64,7 +70,9 @@ def _populate_response(self, response: Union[Response, type]): self.responses = {response.code: {DEFAULT_CONTENT_TYPE: response}} else: # If not, we will use sensible defaults - self.responses = {DEFAULT_CODE: {DEFAULT_CONTENT_TYPE: Response(model=response)}} + self.responses = { + DEFAULT_CODE: {DEFAULT_CONTENT_TYPE: Response(model=response)} + } def _populate_responses(self, responses: List[Response]): self.responses = {} @@ -72,7 +80,9 @@ def _populate_responses(self, responses: List[Response]): if response.code not in self.responses: self.responses[response.code] = {} if response.content_type in self.responses[response.code]: - raise TypeError(f"Multiple responses defined for {response.code} — {response.content_type}") + raise TypeError( + f"Multiple responses defined for {response.code} — {response.content_type}" + ) self.responses[response.code][response.content_type] = response diff --git a/tests/test_chalice.py b/tests/test_chalice.py index 33edcdd..358a6a5 100644 --- a/tests/test_chalice.py +++ b/tests/test_chalice.py @@ -432,7 +432,13 @@ def test_content_types(): "/posts", methods=["POST"], content_types=["multipart/form-data"], - docs=Docs(request=TestSchema, responses=[Resp(model=AnotherSchema, content_type="application/json"), Resp(model=AnotherSchema, content_type="application/xml")]), + docs=Docs( + request=TestSchema, + responses=[ + Resp(model=AnotherSchema, content_type="application/json"), + Resp(model=AnotherSchema, content_type="application/xml") + ], + ), ) def get_post(): pass