2
2
3
3
use std:: { error:: Error , fmt, string:: FromUtf8Error , sync:: Arc } ;
4
4
5
+ use http_body_util:: BodyExt as _;
5
6
use hyper:: {
7
+ body,
6
8
header:: { self , HeaderValue } ,
7
- Body , Method , Request , Response , StatusCode ,
9
+ Method , Request , Response , StatusCode ,
8
10
} ;
9
11
use juniper:: {
10
12
http:: { GraphQLBatchRequest , GraphQLRequest as JuniperGraphQLRequest , GraphQLRequest } ,
@@ -16,7 +18,7 @@ use url::form_urlencoded;
16
18
pub async fn graphql_sync < CtxT , QueryT , MutationT , SubscriptionT , S > (
17
19
root_node : Arc < RootNode < ' static , QueryT , MutationT , SubscriptionT , S > > ,
18
20
context : Arc < CtxT > ,
19
- req : Request < Body > ,
21
+ req : Request < body :: Incoming > ,
20
22
) -> Response < String >
21
23
where
22
24
QueryT : GraphQLType < S , Context = CtxT > ,
37
39
pub async fn graphql < CtxT , QueryT , MutationT , SubscriptionT , S > (
38
40
root_node : Arc < RootNode < ' static , QueryT , MutationT , SubscriptionT , S > > ,
39
41
context : Arc < CtxT > ,
40
- req : Request < Body > ,
42
+ req : Request < body :: Incoming > ,
41
43
) -> Response < String >
42
44
where
43
45
QueryT : GraphQLTypeAsync < S , Context = CtxT > ,
56
58
}
57
59
58
60
async fn parse_req < S : ScalarValue > (
59
- req : Request < Body > ,
61
+ req : Request < body :: Incoming > ,
60
62
) -> Result < GraphQLBatchRequest < S > , Response < String > > {
61
63
match * req. method ( ) {
62
64
Method :: GET => parse_get_req ( req) ,
@@ -77,7 +79,7 @@ async fn parse_req<S: ScalarValue>(
77
79
}
78
80
79
81
fn parse_get_req < S : ScalarValue > (
80
- req : Request < Body > ,
82
+ req : Request < body :: Incoming > ,
81
83
) -> Result < GraphQLBatchRequest < S > , GraphQLRequestError > {
82
84
req. uri ( )
83
85
. query ( )
@@ -90,27 +92,29 @@ fn parse_get_req<S: ScalarValue>(
90
92
}
91
93
92
94
async fn parse_post_json_req < S : ScalarValue > (
93
- body : Body ,
95
+ body : body :: Incoming ,
94
96
) -> Result < GraphQLBatchRequest < S > , GraphQLRequestError > {
95
- let chunk = hyper:: body:: to_bytes ( body)
97
+ let chunk = body
98
+ . collect ( )
96
99
. await
97
100
. map_err ( GraphQLRequestError :: BodyHyper ) ?;
98
101
99
- let input = String :: from_utf8 ( chunk. iter ( ) . cloned ( ) . collect ( ) )
102
+ let input = String :: from_utf8 ( chunk. to_bytes ( ) . iter ( ) . cloned ( ) . collect ( ) )
100
103
. map_err ( GraphQLRequestError :: BodyUtf8 ) ?;
101
104
102
105
serde_json:: from_str :: < GraphQLBatchRequest < S > > ( & input)
103
106
. map_err ( GraphQLRequestError :: BodyJSONError )
104
107
}
105
108
106
109
async fn parse_post_graphql_req < S : ScalarValue > (
107
- body : Body ,
110
+ body : body :: Incoming ,
108
111
) -> Result < GraphQLBatchRequest < S > , GraphQLRequestError > {
109
- let chunk = hyper:: body:: to_bytes ( body)
112
+ let chunk = body
113
+ . collect ( )
110
114
. await
111
115
. map_err ( GraphQLRequestError :: BodyHyper ) ?;
112
116
113
- let query = String :: from_utf8 ( chunk. iter ( ) . cloned ( ) . collect ( ) )
117
+ let query = String :: from_utf8 ( chunk. to_bytes ( ) . iter ( ) . cloned ( ) . collect ( ) )
114
118
. map_err ( GraphQLRequestError :: BodyUtf8 ) ?;
115
119
116
120
Ok ( GraphQLBatchRequest :: Single ( GraphQLRequest :: new (
@@ -306,18 +310,19 @@ impl Error for GraphQLRequestError {
306
310
307
311
#[ cfg( test) ]
308
312
mod tests {
309
- use hyper:: {
310
- server:: Server ,
311
- service:: { make_service_fn, service_fn} ,
312
- Method , Response , StatusCode ,
313
+ use std:: {
314
+ convert:: Infallible , error:: Error , net:: SocketAddr , panic, sync:: Arc , time:: Duration ,
313
315
} ;
316
+
317
+ use hyper:: { server:: conn:: http1, service:: service_fn, Method , Response , StatusCode } ;
318
+ use hyper_util:: rt:: TokioIo ;
314
319
use juniper:: {
315
320
http:: tests as http_tests,
316
321
tests:: fixtures:: starwars:: schema:: { Database , Query } ,
317
322
EmptyMutation , EmptySubscription , RootNode ,
318
323
} ;
319
- use reqwest:: { self , blocking:: Response as ReqwestResponse } ;
320
- use std :: { convert :: Infallible , net:: SocketAddr , sync :: Arc , thread , time:: Duration } ;
324
+ use reqwest:: blocking:: Response as ReqwestResponse ;
325
+ use tokio :: { net:: TcpListener , task , time:: sleep } ;
321
326
322
327
struct TestHyperIntegration {
323
328
port : u16 ,
@@ -373,7 +378,7 @@ mod tests {
373
378
374
379
async fn run_hyper_integration ( is_sync : bool ) {
375
380
let port = if is_sync { 3002 } else { 3001 } ;
376
- let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , port) . into ( ) ;
381
+ let addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , port) ) ;
377
382
378
383
let db = Arc :: new ( Database :: new ( ) ) ;
379
384
let root_node = Arc :: new ( RootNode :: new (
@@ -382,59 +387,74 @@ mod tests {
382
387
EmptySubscription :: < Database > :: new ( ) ,
383
388
) ) ;
384
389
385
- let new_service = make_service_fn ( move |_| {
386
- let root_node = root_node. clone ( ) ;
387
- let ctx = db. clone ( ) ;
390
+ let server: task:: JoinHandle < Result < ( ) , Box < dyn Error + Send + Sync > > > =
391
+ task:: spawn ( async move {
392
+ let listener = TcpListener :: bind ( addr) . await ?;
393
+
394
+ loop {
395
+ let ( stream, _) = listener. accept ( ) . await ?;
396
+ let io = TokioIo :: new ( stream) ;
388
397
389
- async move {
390
- Ok :: < _ , hyper:: Error > ( service_fn ( move |req| {
391
398
let root_node = root_node. clone ( ) ;
392
- let ctx = ctx. clone ( ) ;
393
- let matches = {
394
- let path = req. uri ( ) . path ( ) ;
395
- match req. method ( ) {
396
- & Method :: POST | & Method :: GET => {
397
- path == "/graphql" || path == "/graphql/"
398
- }
399
- _ => false ,
399
+ let db = db. clone ( ) ;
400
+
401
+ _ = task:: spawn ( async move {
402
+ let root_node = root_node. clone ( ) ;
403
+ let db = db. clone ( ) ;
404
+
405
+ if let Err ( e) = http1:: Builder :: new ( )
406
+ . serve_connection (
407
+ io,
408
+ service_fn ( move |req| {
409
+ let root_node = root_node. clone ( ) ;
410
+ let db = db. clone ( ) ;
411
+ let matches = {
412
+ let path = req. uri ( ) . path ( ) ;
413
+ match req. method ( ) {
414
+ & Method :: POST | & Method :: GET => {
415
+ path == "/graphql" || path == "/graphql/"
416
+ }
417
+ _ => false ,
418
+ }
419
+ } ;
420
+ async move {
421
+ Ok :: < _ , Infallible > ( if matches {
422
+ if is_sync {
423
+ super :: graphql_sync ( root_node, db, req) . await
424
+ } else {
425
+ super :: graphql ( root_node, db, req) . await
426
+ }
427
+ } else {
428
+ let mut resp = Response :: new ( String :: new ( ) ) ;
429
+ * resp. status_mut ( ) = StatusCode :: NOT_FOUND ;
430
+ resp
431
+ } )
432
+ }
433
+ } ) ,
434
+ )
435
+ . await
436
+ {
437
+ eprintln ! ( "server error: {e}" ) ;
400
438
}
401
- } ;
402
- async move {
403
- Ok :: < _ , Infallible > ( if matches {
404
- if is_sync {
405
- super :: graphql_sync ( root_node, ctx, req) . await
406
- } else {
407
- super :: graphql ( root_node, ctx, req) . await
408
- }
409
- } else {
410
- let mut resp = Response :: new ( String :: new ( ) ) ;
411
- * resp. status_mut ( ) = StatusCode :: NOT_FOUND ;
412
- resp
413
- } )
414
- }
415
- } ) )
416
- }
417
- } ) ;
418
-
419
- let ( shutdown_fut, shutdown) = futures:: future:: abortable ( async {
420
- tokio:: time:: sleep ( Duration :: from_secs ( 60 ) ) . await ;
421
- } ) ;
422
-
423
- let server = Server :: bind ( & addr)
424
- . serve ( new_service)
425
- . with_graceful_shutdown ( async {
426
- shutdown_fut. await . unwrap_err ( ) ;
439
+ } ) ;
440
+ }
427
441
} ) ;
428
442
429
- tokio:: task:: spawn_blocking ( move || {
430
- thread:: sleep ( Duration :: from_millis ( 10 ) ) ; // wait 10ms for server to bind
443
+ sleep ( Duration :: from_secs ( 10 ) ) . await ; // wait 10ms for `server` to bind
444
+
445
+ match task:: spawn_blocking ( move || {
431
446
let integration = TestHyperIntegration { port } ;
432
447
http_tests:: run_http_test_suite ( & integration) ;
433
- shutdown. abort ( ) ;
434
- } ) ;
448
+ } )
449
+ . await
450
+ {
451
+ Err ( f) if f. is_panic ( ) => panic:: resume_unwind ( f. into_panic ( ) ) ,
452
+ Ok ( ( ) ) | Err ( _) => { }
453
+ }
435
454
436
- if let Err ( e) = server. await {
437
- eprintln ! ( "server error: {e}" ) ;
455
+ server. abort ( ) ;
456
+ if let Ok ( Err ( e) ) = server. await {
457
+ panic ! ( "server failed: {e}" ) ;
438
458
}
439
459
}
440
460
0 commit comments