diff --git a/src/httpcore2/httpcore2/_async/http_proxy.py b/src/httpcore2/httpcore2/_async/http_proxy.py index 5cde1dd4..b3c8ddac 100644 --- a/src/httpcore2/httpcore2/_async/http_proxy.py +++ b/src/httpcore2/httpcore2/_async/http_proxy.py @@ -294,9 +294,16 @@ async def handle_async_request(self, request: Request) -> Response: "server_hostname": self._remote_origin.host.decode("ascii"), "timeout": timeout, } - async with Trace("start_tls", logger, request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) - trace.return_value = stream + try: + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + except Exception: + # If TLS setup fails, close the underlying CONNECT + # connection so the pool can discard it instead of + # leaving it ACTIVE until max_connections is exhausted. + await self._connection.aclose() + raise # Determine if we should be using HTTP/1.1 or HTTP/2 ssl_object = stream.get_extra_info("ssl_object") diff --git a/src/httpcore2/httpcore2/_sync/http_proxy.py b/src/httpcore2/httpcore2/_sync/http_proxy.py index c0129f1a..46107bb9 100644 --- a/src/httpcore2/httpcore2/_sync/http_proxy.py +++ b/src/httpcore2/httpcore2/_sync/http_proxy.py @@ -294,9 +294,16 @@ def handle_request(self, request: Request) -> Response: "server_hostname": self._remote_origin.host.decode("ascii"), "timeout": timeout, } - with Trace("start_tls", logger, request, kwargs) as trace: - stream = stream.start_tls(**kwargs) - trace.return_value = stream + try: + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + except Exception: + # If TLS setup fails, close the underlying CONNECT + # connection so the pool can discard it instead of + # leaving it ACTIVE until max_connections is exhausted. + self._connection.close() + raise # Determine if we should be using HTTP/1.1 or HTTP/2 ssl_object = stream.get_extra_info("ssl_object") diff --git a/tests/httpcore2/_async/test_http_proxy.py b/tests/httpcore2/_async/test_http_proxy.py index 4754eee1..dd30b67c 100644 --- a/tests/httpcore2/_async/test_http_proxy.py +++ b/tests/httpcore2/_async/test_http_proxy.py @@ -240,3 +240,48 @@ def test_proxy_headers() -> None: auth=("username", "password"), ) assert proxy.headers == [(b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=")] + + +class BrokenTLSStream(AsyncMockStream): + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + raise OSError("TLS Failure") + + +class BrokenTLSBackend(AsyncMockBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return BrokenTLSStream(list(self._buffer)) + + +@pytest.mark.anyio +async def test_proxy_tunneling_tls_error() -> None: + """ + Send an HTTPS request via a proxy where the TLS handshake fails after the + CONNECT tunnel is established. The CONNECT connection must be closed so it + doesn't leak from the pool. + """ + network_backend = BrokenTLSBackend( + [ + b"HTTP/1.1 200 OK\r\n\r\n", + ] + ) + + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), + network_backend=network_backend, + ) as proxy: + with pytest.raises(OSError, match="TLS Failure"): + await proxy.request("GET", "https://example.com/") + + assert not proxy.connections diff --git a/tests/httpcore2/_sync/test_http_proxy.py b/tests/httpcore2/_sync/test_http_proxy.py index f183a1fe..2ef96180 100644 --- a/tests/httpcore2/_sync/test_http_proxy.py +++ b/tests/httpcore2/_sync/test_http_proxy.py @@ -240,3 +240,48 @@ def test_proxy_headers() -> None: auth=("username", "password"), ) assert proxy.headers == [(b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=")] + + +class BrokenTLSStream(MockStream): + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> NetworkStream: + raise OSError("TLS Failure") + + +class BrokenTLSBackend(MockBackend): + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return BrokenTLSStream(list(self._buffer)) + + + +def test_proxy_tunneling_tls_error() -> None: + """ + Send an HTTPS request via a proxy where the TLS handshake fails after the + CONNECT tunnel is established. The CONNECT connection must be closed so it + doesn't leak from the pool. + """ + network_backend = BrokenTLSBackend( + [ + b"HTTP/1.1 200 OK\r\n\r\n", + ] + ) + + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), + network_backend=network_backend, + ) as proxy: + with pytest.raises(OSError, match="TLS Failure"): + proxy.request("GET", "https://example.com/") + + assert not proxy.connections