Skip to content

Commit ff9d10f

Browse files
Don't rewind past already-read streams
1 parent 772d193 commit ff9d10f

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
129129
)
130130
yield binding_serializer
131131
content_length = payload.tell()
132+
payload.seek(0)
132133
else:
133134
payload = BytesIO()
134135
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -154,8 +155,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
154155
binding_matcher,
155156
)
156157
yield binding_serializer
158+
payload.seek(0)
157159

158-
self._seek(payload, 0)
159160
headers = binding_serializer.header_serializer.headers
160161
if content_type is not None:
161162
headers.append(("content-type", content_type))
@@ -165,7 +166,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
165166

166167
fields = tuples_to_fields(headers)
167168
if content_length_required and "content-length" not in fields:
168-
content_length = self._compute_content_length(payload)
169+
content_length = _compute_content_length(payload)
169170
if content_length is None:
170171
raise SerializationError(
171172
"This operation requires the the content length of the input "
@@ -187,17 +188,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
187188
body=payload,
188189
)
189190

190-
def _seek(self, payload: Any, pos: int, whence: int = 0) -> None:
191-
if (
192-
seek := getattr(payload, "seek", None)
193-
) is not None and not iscoroutinefunction(seek):
194-
seek(pos, whence)
195191

196-
def _compute_content_length(self, payload: Any) -> int | None:
197-
content_length = self._seek(payload, 0, 2)
198-
if content_length is not None:
199-
self._seek(payload, 0, 0)
192+
def _compute_content_length(payload: Any) -> int | None:
193+
if (tell := getattr(payload, "tell", None)) is not None and not iscoroutinefunction(
194+
tell
195+
):
196+
start: int = tell()
197+
if (end := _seek(payload, 0, 2)) is not None:
198+
content_length: int = end - start
199+
_seek(payload, start, 0)
200200
return content_length
201+
return None
202+
203+
204+
def _seek(payload: Any, pos: int, whence: int = 0) -> None:
205+
if (seek := getattr(payload, "seek", None)) is not None and not iscoroutinefunction(
206+
seek
207+
):
208+
seek(pos, whence)
201209

202210

203211
class HTTPRequestBindingSerializer(InterceptingSerializer):
@@ -303,6 +311,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
303311
)
304312
yield binding_serializer
305313
content_length = payload.tell()
314+
payload.seek(0)
306315
else:
307316
payload = BytesIO()
308317
payload_serializer = self._payload_codec.create_serializer(payload)
@@ -323,8 +332,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
323332
binding_matcher,
324333
)
325334
yield binding_serializer
335+
payload.seek(0)
326336

327-
self._seek(payload, 0)
328337
headers = binding_serializer.header_serializer.headers
329338
if content_type is not None:
330339
headers.append(("content-type", content_type))
@@ -334,7 +343,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
334343

335344
fields = tuples_to_fields(headers)
336345
if content_length_required and "content-length" not in fields:
337-
content_length = self._compute_content_length(payload)
346+
content_length = _compute_content_length(payload)
338347
if content_length is None:
339348
raise SerializationError(
340349
"This operation requires the the content length of the input "
@@ -355,18 +364,6 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
355364
status=status,
356365
)
357366

358-
def _seek(self, payload: Any, pos: int, whence: int = 0) -> int | None:
359-
if (
360-
seek := getattr(payload, "seek", None)
361-
) is not None and not iscoroutinefunction(seek):
362-
return seek(pos, whence)
363-
364-
def _compute_content_length(self, payload: Any) -> int | None:
365-
content_length = self._seek(payload, 0, 2)
366-
if content_length is not None:
367-
self._seek(payload, 0, 0)
368-
return content_length
369-
370367

371368
class HTTPResponseBindingSerializer(InterceptingSerializer):
372369
"""Delegates HTTP response bindings to binding-location-specific serializers."""

0 commit comments

Comments
 (0)