3838import com .tencent .trpc .proto .http .common .RpcServerContextWithHttp ;
3939import com .tencent .trpc .proto .http .common .TrpcServletRequestWrapper ;
4040import com .tencent .trpc .proto .http .common .TrpcServletResponseWrapper ;
41+ import java .io .IOException ;
4142import java .lang .reflect .Type ;
4243import java .net .InetSocketAddress ;
4344import java .nio .charset .StandardCharsets ;
4445import java .util .Enumeration ;
4546import java .util .Map ;
47+ import java .util .concurrent .CompletableFuture ;
48+ import java .util .concurrent .TimeoutException ;
49+ import java .util .concurrent .atomic .AtomicBoolean ;
4650import java .util .concurrent .CompletionStage ;
4751import java .util .concurrent .CountDownLatch ;
4852import java .util .concurrent .TimeUnit ;
@@ -65,34 +69,41 @@ public abstract class AbstractHttpExecutor {
6569
6670 protected void execute (HttpServletRequest request , HttpServletResponse response ,
6771 RpcMethodInfoAndInvoker methodInfoAndInvoker ) {
68-
72+ AtomicBoolean responded = new AtomicBoolean ( false );
6973 try {
7074
7175 DefRequest rpcRequest = buildDefRequest (request , response , methodInfoAndInvoker );
7276
73- CountDownLatch countDownLatch = new CountDownLatch ( 1 );
77+ CompletableFuture < Void > completionFuture = new CompletableFuture <>( );
7478
7579 // use a thread pool for asynchronous processing
76- invokeRpcRequest (methodInfoAndInvoker .getInvoker (), rpcRequest , countDownLatch );
80+ invokeRpcRequest (methodInfoAndInvoker .getInvoker (), rpcRequest , completionFuture , responded );
7781
7882 // If the request carries a timeout, use this timeout to wait for the request to be processed.
7983 // If not carried, use the default timeout.
8084 long requestTimeout = rpcRequest .getMeta ().getTimeout ();
8185 if (requestTimeout <= 0 ) {
8286 requestTimeout = methodInfoAndInvoker .getInvoker ().getConfig ().getRequestTimeout ();
8387 }
84- if (requestTimeout > 0 && !countDownLatch .await (requestTimeout , TimeUnit .MILLISECONDS )) {
85- throw TRpcException .newFrameException (ErrorCode .TRPC_SERVER_TIMEOUT_ERR ,
86- "wait http request execute timeout" );
88+ if (requestTimeout > 0 ) {
89+ try {
90+ completionFuture .get (requestTimeout , TimeUnit .MILLISECONDS );
91+ } catch (TimeoutException ex ) {
92+ if (responded .compareAndSet (false , true )) {
93+ doErrorReply (request , response ,
94+ TRpcException .newFrameException (ErrorCode .TRPC_SERVER_TIMEOUT_ERR ,
95+ "wait http request execute timeout" ));
96+ }
97+ }
8798 } else {
88- countDownLatch . await ();
99+ completionFuture . get ();
89100 }
90-
91101 } catch (Exception ex ) {
92102 logger .error ("dispatch request [{}] error" , request , ex );
93- doErrorReply (request , response , ex );
103+ if (responded .compareAndSet (false , true )) {
104+ doErrorReply (request , response , ex );
105+ }
94106 }
95-
96107 }
97108
98109 /**
@@ -107,55 +118,83 @@ protected void execute(HttpServletRequest request, HttpServletResponse response,
107118 /**
108119 * Request processing
109120 *
110- * @param countDownLatch latch used to wait for the request processing
121+ * @param invoker the invoker
122+ * @param rpcRequest the rpc request
123+ * @param completionFuture the completion future
124+ * @param responded the responded flag
111125 */
112- private void invokeRpcRequest (ProviderInvoker <?> invoker , DefRequest rpcRequest , CountDownLatch countDownLatch ) {
126+ private void invokeRpcRequest (ProviderInvoker <?> invoker , DefRequest rpcRequest ,
127+ CompletableFuture <Void > completionFuture ,
128+ AtomicBoolean responded ) {
113129
114130 WorkerPool workerPool = invoker .getConfig ().getWorkerPoolObj ();
115131
116132 if (null == workerPool ) {
117133 logger .error ("dispatch rpcRequest [{}] error, workerPool is empty" , rpcRequest );
118- throw TRpcException .newFrameException (ErrorCode .TRPC_SERVER_NOSERVICE_ERR ,
119- "not found service, workerPool is empty" );
134+ completionFuture .completeExceptionally (TRpcException .newFrameException (ErrorCode .TRPC_SERVER_NOSERVICE_ERR ,
135+ "not found service, workerPool is empty" ));
136+ return ;
120137 }
121138
122139 workerPool .execute (() -> {
123-
124- // Get the original http response
125- HttpServletResponse response = getOriginalResponse (rpcRequest );
126-
127- // Invoke the routing implementation method to handle the request.
128- CompletionStage <Response > future = invoker .invoke (rpcRequest );
129- future .whenComplete ((result , t ) -> {
130- try {
131- // Throw the call exception, which will be handled uniformly by the exception handling program.
132- if (t != null ) {
133- throw t ;
134- }
135-
136- // Throw a business logic exception, which will be handled uniformly
137- // by the exception handling program.
138- Throwable ex = result .getException ();
139- if (ex != null ) {
140- throw ex ;
140+ try {
141+ // Get the original http response
142+ HttpServletResponse response = getOriginalResponse (rpcRequest );
143+ // Invoke the routing implementation method to handle the request.
144+ CompletionStage <Response > rpcFuture = invoker .invoke (rpcRequest );
145+
146+ rpcFuture .whenComplete ((result , throwable ) -> {
147+ try {
148+ if (responded .get ()) {
149+ return ;
150+ }
151+
152+ // Throw the call exception, which will be handled uniformly by the exception handling program.
153+ if (throwable != null ) {
154+ throw throwable ;
155+ }
156+
157+ // Throw a business logic exception, which will be handled uniformly
158+ // by the exception handling program.
159+ if (result .getException () != null ) {
160+ throw result .getException ();
161+ }
162+
163+ // normal response
164+ if (responded .compareAndSet (false , true )) {
165+ response .setStatus (HttpStatus .SC_OK );
166+ httpCodec .writeHttpResponse (response , result );
167+ response .flushBuffer ();
168+ }
169+
170+ completionFuture .complete (null );
171+ } catch (Throwable t ) {
172+ handleError (t , rpcRequest , response , responded , completionFuture );
141173 }
174+ });
142175
143- // normal response
144- response .setStatus (HttpStatus .SC_OK );
145- httpCodec .writeHttpResponse (response , result );
146- response .flushBuffer ();
147- } catch (Throwable e ) {
148- HttpServletRequest request = getOriginalRequest (rpcRequest );
149- logger .warn ("reply message error, channel: [{}], msg:[{}]" , request .getRemoteAddr (), request , e );
150- httpErrorReply (request , response ,
151- ErrorResponse .create (request , HttpStatus .SC_SERVICE_UNAVAILABLE , e ));
152- } finally {
153- countDownLatch .countDown ();
154- }
155- });
176+ } catch (Exception e ) {
177+ handleError (e , rpcRequest , getOriginalResponse (rpcRequest ), responded , completionFuture );
178+ }
156179 });
157180 }
158181
182+ /**
183+ * Handle error
184+ */
185+ private void handleError (Throwable t , DefRequest rpcRequest , HttpServletResponse response ,
186+ AtomicBoolean responded , CompletableFuture <Void > completionFuture ) {
187+ try {
188+ if (responded .compareAndSet (false , true )) {
189+ HttpServletRequest request = getOriginalRequest (rpcRequest );
190+ logger .warn ("reply message error, channel: [{}], msg:[{}]" , request .getRemoteAddr (), request , t );
191+ httpErrorReply (request , response , ErrorResponse .create (request , HttpStatus .SC_SERVICE_UNAVAILABLE , t ));
192+ }
193+ } finally {
194+ completionFuture .completeExceptionally (t );
195+ }
196+ }
197+
159198 /**
160199 * Build the context request.
161200 *
@@ -480,4 +519,4 @@ private String getString(String[] callInfos, int length, int cursor) {
480519 return callInfos .length < length ? StringUtils .EMPTY : callInfos [cursor ];
481520 }
482521
483- }
522+ }
0 commit comments