9
9
10
10
from smithy_core import URI
11
11
from smithy_core .codecs import Codec
12
+ from smithy_core .exceptions import SerializationError
12
13
from smithy_core .schemas import Schema
13
14
from smithy_core .serializers import (
14
15
InterceptingSerializer ,
24
25
HTTPQueryTrait ,
25
26
HTTPTrait ,
26
27
MediaTypeTrait ,
28
+ RequiresLengthTrait ,
27
29
TimestampFormatTrait ,
28
30
)
29
31
from smithy_core .types import PathPattern , TimestampFormat
30
32
from smithy_core .utils import serialize_float
31
33
32
- from . import tuples_to_fields
34
+ from . import Field , tuples_to_fields
33
35
from .aio import HTTPRequest as _HTTPRequest
34
36
from .aio import HTTPResponse as _HTTPResponse
35
37
from .aio .interfaces import HTTPRequest , HTTPResponse
43
45
__all__ = ["HTTPRequestSerializer" , "HTTPResponseSerializer" ]
44
46
45
47
48
+ # TODO: refactor this to share code with response serializer
46
49
class HTTPRequestSerializer (SpecificShapeSerializer ):
47
50
"""Binds a serializable shape to an HTTP request.
48
51
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
82
85
host_prefix = self ._endpoint_trait .host_prefix
83
86
84
87
content_type = self ._payload_codec .media_type
88
+ content_length : int | None = None
89
+ content_length_required = False
90
+
85
91
binding_matcher = RequestBindingMatcher (schema )
86
92
if (payload_member := binding_matcher .payload_member ) is not None :
93
+ content_length_required = RequiresLengthTrait in payload_member
87
94
if payload_member .shape_type in (
88
95
ShapeType .BLOB ,
89
96
ShapeType .STRING ,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105
112
)
106
113
yield binding_serializer
107
114
payload = payload_serializer .payload
115
+ try :
116
+ content_length = len (payload )
117
+ except TypeError :
118
+ pass
108
119
else :
109
120
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
110
121
content_type = media_type .value
@@ -117,6 +128,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117
128
binding_matcher ,
118
129
)
119
130
yield binding_serializer
131
+ content_length = payload .tell ()
132
+ payload .seek (0 )
120
133
else :
121
134
payload = BytesIO ()
122
135
payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -131,25 +144,36 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131
144
binding_matcher ,
132
145
)
133
146
yield binding_serializer
147
+ content_length = payload .tell ()
134
148
else :
135
149
content_type = None
150
+ content_length = 0
136
151
binding_serializer = HTTPRequestBindingSerializer (
137
152
payload_serializer ,
138
153
self ._http_trait .path ,
139
154
host_prefix ,
140
155
binding_matcher ,
141
156
)
142
157
yield binding_serializer
143
-
144
- if (
145
- seek := getattr (payload , "seek" , None )
146
- ) is not None and not iscoroutinefunction (seek ):
147
- seek (0 )
158
+ payload .seek (0 )
148
159
149
160
headers = binding_serializer .header_serializer .headers
150
161
if content_type is not None :
151
162
headers .append (("content-type" , content_type ))
152
163
164
+ if content_length is not None :
165
+ headers .append (("content-length" , str (content_length )))
166
+
167
+ fields = tuples_to_fields (headers )
168
+ if content_length_required and "content-length" not in fields :
169
+ content_length = _compute_content_length (payload )
170
+ if content_length is None :
171
+ raise SerializationError (
172
+ "This operation requires the the content length of the input "
173
+ "stream, but it was not provided and was unable to be computed."
174
+ )
175
+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
176
+
153
177
self .result = _HTTPRequest (
154
178
method = self ._http_trait .method ,
155
179
destination = URI (
@@ -160,11 +184,30 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160
184
prefix = self ._http_trait .query or "" ,
161
185
),
162
186
),
163
- fields = tuples_to_fields ( headers ) ,
187
+ fields = fields ,
164
188
body = payload ,
165
189
)
166
190
167
191
192
+ def _compute_content_length (payload : Any ) -> int | None :
193
+ if (tell := getattr (payload , "tell" , None )) is not None and not iscoroutinefunction (
194
+ tell
195
+ ):
196
+ start : int = tell ()
197
+ if (end := _seek (payload , 0 , 2 )) is not None :
198
+ content_length : int = end - start
199
+ _seek (payload , start , 0 )
200
+ return content_length
201
+ return None
202
+
203
+
204
+ def _seek (payload : Any , pos : int , whence : int = 0 ) -> None :
205
+ if (seek := getattr (payload , "seek" , None )) is not None and not iscoroutinefunction (
206
+ seek
207
+ ):
208
+ seek (pos , whence )
209
+
210
+
168
211
class HTTPRequestBindingSerializer (InterceptingSerializer ):
169
212
"""Delegates HTTP request bindings to binding-location-specific serializers."""
170
213
@@ -235,8 +278,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235
278
binding_serializer : HTTPResponseBindingSerializer
236
279
237
280
content_type : str | None = self ._payload_codec .media_type
281
+ content_length : int | None = None
282
+ content_length_required = False
283
+
238
284
binding_matcher = ResponseBindingMatcher (schema )
239
285
if (payload_member := binding_matcher .payload_member ) is not None :
286
+ content_length_required = RequiresLengthTrait in payload_member
240
287
if payload_member .shape_type in (ShapeType .BLOB , ShapeType .STRING ):
241
288
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
242
289
content_type = media_type .value
@@ -250,6 +297,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250
297
)
251
298
yield binding_serializer
252
299
payload = payload_serializer .payload
300
+ try :
301
+ content_length = len (payload )
302
+ except TypeError :
303
+ pass
253
304
else :
254
305
if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
255
306
content_type = media_type .value
@@ -259,6 +310,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259
310
payload_serializer , binding_matcher
260
311
)
261
312
yield binding_serializer
313
+ content_length = payload .tell ()
314
+ payload .seek (0 )
262
315
else :
263
316
payload = BytesIO ()
264
317
payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -270,23 +323,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270
323
body_serializer , binding_matcher
271
324
)
272
325
yield binding_serializer
326
+ content_length = payload .tell ()
273
327
else :
274
328
content_type = None
329
+ content_length = 0
275
330
binding_serializer = HTTPResponseBindingSerializer (
276
331
payload_serializer ,
277
332
binding_matcher ,
278
333
)
279
334
yield binding_serializer
280
-
281
- if (
282
- seek := getattr (payload , "seek" , None )
283
- ) is not None and not iscoroutinefunction (seek ):
284
- seek (0 )
335
+ payload .seek (0 )
285
336
286
337
headers = binding_serializer .header_serializer .headers
287
338
if content_type is not None :
288
339
headers .append (("content-type" , content_type ))
289
340
341
+ if content_length is not None :
342
+ headers .append (("content-length" , str (content_length )))
343
+
344
+ fields = tuples_to_fields (headers )
345
+ if content_length_required and "content-length" not in fields :
346
+ content_length = _compute_content_length (payload )
347
+ if content_length is None :
348
+ raise SerializationError (
349
+ "This operation requires the the content length of the input "
350
+ "stream, but it was not provided and was unable to be computed."
351
+ )
352
+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
353
+
290
354
status = binding_serializer .response_code_serializer .response_code
291
355
if status is None :
292
356
if binding_matcher .response_status > 0 :
0 commit comments