Skip to content

Commit d693c36

Browse files
Compute content length
1 parent 95e3f04 commit d693c36

File tree

3 files changed

+278
-66
lines changed

3 files changed

+278
-66
lines changed

packages/smithy-core/src/smithy_core/traits.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ def __post_init__(self):
154154
assert self.document_value is None
155155

156156

157+
@dataclass(init=False, frozen=True)
158+
class RequiresLengthTrait(Trait, id=ShapeID("smithy.api#requiresLength")):
159+
def __post_init__(self):
160+
assert self.document_value is None
161+
162+
157163
@dataclass(init=False, frozen=True)
158164
class UnitTypeTrait(Trait, id=ShapeID("smithy.api#UnitTypeTrait")):
159165
def __post_init__(self):

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from smithy_core import URI
1111
from smithy_core.codecs import Codec
12+
from smithy_core.exceptions import SerializationError
1213
from smithy_core.schemas import Schema
1314
from smithy_core.serializers import (
1415
InterceptingSerializer,
@@ -24,12 +25,13 @@
2425
HTTPQueryTrait,
2526
HTTPTrait,
2627
MediaTypeTrait,
28+
RequiresLengthTrait,
2729
TimestampFormatTrait,
2830
)
2931
from smithy_core.types import PathPattern, TimestampFormat
3032
from smithy_core.utils import serialize_float
3133

32-
from . import tuples_to_fields
34+
from . import Field, tuples_to_fields
3335
from .aio import HTTPRequest as _HTTPRequest
3436
from .aio import HTTPResponse as _HTTPResponse
3537
from .aio.interfaces import HTTPRequest, HTTPResponse
@@ -43,6 +45,7 @@
4345
__all__ = ["HTTPRequestSerializer", "HTTPResponseSerializer"]
4446

4547

48+
# TODO: refactor this to share code with response serializer
4649
class HTTPRequestSerializer(SpecificShapeSerializer):
4750
"""Binds a serializable shape to an HTTP request.
4851
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8285
host_prefix = self._endpoint_trait.host_prefix
8386

8487
content_type = self._payload_codec.media_type
88+
content_length: int | None = None
89+
content_length_required = False
90+
8591
binding_matcher = RequestBindingMatcher(schema)
8692
if (payload_member := binding_matcher.payload_member) is not None:
93+
content_length_required = RequiresLengthTrait in payload_member
8794
if payload_member.shape_type in (
8895
ShapeType.BLOB,
8996
ShapeType.STRING,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105112
)
106113
yield binding_serializer
107114
payload = payload_serializer.payload
115+
try:
116+
content_length = len(payload)
117+
except TypeError:
118+
pass
108119
else:
109120
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
110121
content_type = media_type.value
@@ -117,6 +128,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117128
binding_matcher,
118129
)
119130
yield binding_serializer
131+
content_length = payload.tell()
132+
payload.seek(0)
120133
else:
121134
payload = BytesIO()
122135
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -131,25 +144,36 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131144
binding_matcher,
132145
)
133146
yield binding_serializer
147+
content_length = payload.tell()
134148
else:
135149
content_type = None
150+
content_length = 0
136151
binding_serializer = HTTPRequestBindingSerializer(
137152
payload_serializer,
138153
self._http_trait.path,
139154
host_prefix,
140155
binding_matcher,
141156
)
142157
yield binding_serializer
143-
144-
if (
145-
seek := getattr(payload, "seek", None)
146-
) is not None and not iscoroutinefunction(seek):
147-
seek(0)
158+
payload.seek(0)
148159

149160
headers = binding_serializer.header_serializer.headers
150161
if content_type is not None:
151162
headers.append(("content-type", content_type))
152163

164+
if content_length is not None:
165+
headers.append(("content-length", str(content_length)))
166+
167+
fields = tuples_to_fields(headers)
168+
if content_length_required and "content-length" not in fields:
169+
content_length = _compute_content_length(payload)
170+
if content_length is None:
171+
raise SerializationError(
172+
"This operation requires the the content length of the input "
173+
"stream, but it was not provided and was unable to be computed."
174+
)
175+
fields.set_field(Field(name="content-length", values=[str(content_length)]))
176+
153177
self.result = _HTTPRequest(
154178
method=self._http_trait.method,
155179
destination=URI(
@@ -160,11 +184,30 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160184
prefix=self._http_trait.query or "",
161185
),
162186
),
163-
fields=tuples_to_fields(headers),
187+
fields=fields,
164188
body=payload,
165189
)
166190

167191

192+
def _compute_content_length(payload: Any) -> int | None:
193+
if (tell := getattr(payload, "tell", None)) is not None and not iscoroutinefunction(
194+
tell
195+
):
196+
start: int = tell()
197+
if (end := _seek(payload, 0, 2)) is not None:
198+
content_length: int = end - start
199+
_seek(payload, start, 0)
200+
return content_length
201+
return None
202+
203+
204+
def _seek(payload: Any, pos: int, whence: int = 0) -> None:
205+
if (seek := getattr(payload, "seek", None)) is not None and not iscoroutinefunction(
206+
seek
207+
):
208+
seek(pos, whence)
209+
210+
168211
class HTTPRequestBindingSerializer(InterceptingSerializer):
169212
"""Delegates HTTP request bindings to binding-location-specific serializers."""
170213

@@ -235,8 +278,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235278
binding_serializer: HTTPResponseBindingSerializer
236279

237280
content_type: str | None = self._payload_codec.media_type
281+
content_length: int | None = None
282+
content_length_required = False
283+
238284
binding_matcher = ResponseBindingMatcher(schema)
239285
if (payload_member := binding_matcher.payload_member) is not None:
286+
content_length_required = RequiresLengthTrait in payload_member
240287
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
241288
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
242289
content_type = media_type.value
@@ -250,6 +297,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250297
)
251298
yield binding_serializer
252299
payload = payload_serializer.payload
300+
try:
301+
content_length = len(payload)
302+
except TypeError:
303+
pass
253304
else:
254305
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
255306
content_type = media_type.value
@@ -259,6 +310,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259310
payload_serializer, binding_matcher
260311
)
261312
yield binding_serializer
313+
content_length = payload.tell()
314+
payload.seek(0)
262315
else:
263316
payload = BytesIO()
264317
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -270,23 +323,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270323
body_serializer, binding_matcher
271324
)
272325
yield binding_serializer
326+
content_length = payload.tell()
273327
else:
274328
content_type = None
329+
content_length = 0
275330
binding_serializer = HTTPResponseBindingSerializer(
276331
payload_serializer,
277332
binding_matcher,
278333
)
279334
yield binding_serializer
280-
281-
if (
282-
seek := getattr(payload, "seek", None)
283-
) is not None and not iscoroutinefunction(seek):
284-
seek(0)
335+
payload.seek(0)
285336

286337
headers = binding_serializer.header_serializer.headers
287338
if content_type is not None:
288339
headers.append(("content-type", content_type))
289340

341+
if content_length is not None:
342+
headers.append(("content-length", str(content_length)))
343+
344+
fields = tuples_to_fields(headers)
345+
if content_length_required and "content-length" not in fields:
346+
content_length = _compute_content_length(payload)
347+
if content_length is None:
348+
raise SerializationError(
349+
"This operation requires the the content length of the input "
350+
"stream, but it was not provided and was unable to be computed."
351+
)
352+
fields.set_field(Field(name="content-length", values=[str(content_length)]))
353+
290354
status = binding_serializer.response_code_serializer.response_code
291355
if status is None:
292356
if binding_matcher.response_status > 0:

0 commit comments

Comments
 (0)