Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/httpcore2/httpcore2/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,16 @@ async def handle_async_request(self, request: Request) -> Response:
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
try:
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
except Exception:
# If TLS setup fails, close the underlying CONNECT
# connection so the pool can discard it instead of
# leaving it ACTIVE until max_connections is exhausted.
await self._connection.aclose()
raise

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
Expand Down
13 changes: 10 additions & 3 deletions src/httpcore2/httpcore2/_sync/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,16 @@ def handle_request(self, request: Request) -> Response:
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
try:
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
except Exception:
# If TLS setup fails, close the underlying CONNECT
# connection so the pool can discard it instead of
# leaving it ACTIVE until max_connections is exhausted.
self._connection.close()
raise

# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
Expand Down
45 changes: 45 additions & 0 deletions tests/httpcore2/_async/test_http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,48 @@ def test_proxy_headers() -> None:
auth=("username", "password"),
)
assert proxy.headers == [(b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=")]


class BrokenTLSStream(AsyncMockStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
raise OSError("TLS Failure")


class BrokenTLSBackend(AsyncMockBackend):
async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
return BrokenTLSStream(list(self._buffer))


@pytest.mark.anyio
async def test_proxy_tunneling_tls_error() -> None:
"""
Send an HTTPS request via a proxy where the TLS handshake fails after the
CONNECT tunnel is established. The CONNECT connection must be closed so it
doesn't leak from the pool.
"""
network_backend = BrokenTLSBackend(
[
b"HTTP/1.1 200 OK\r\n\r\n",
]
)

async with AsyncConnectionPool(
proxy=Proxy("http://localhost:8080/"),
network_backend=network_backend,
) as proxy:
with pytest.raises(OSError, match="TLS Failure"):
await proxy.request("GET", "https://example.com/")

assert not proxy.connections
45 changes: 45 additions & 0 deletions tests/httpcore2/_sync/test_http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,48 @@ def test_proxy_headers() -> None:
auth=("username", "password"),
)
assert proxy.headers == [(b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=")]


class BrokenTLSStream(MockStream):
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> NetworkStream:
raise OSError("TLS Failure")


class BrokenTLSBackend(MockBackend):
def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
return BrokenTLSStream(list(self._buffer))



def test_proxy_tunneling_tls_error() -> None:
"""
Send an HTTPS request via a proxy where the TLS handshake fails after the
CONNECT tunnel is established. The CONNECT connection must be closed so it
doesn't leak from the pool.
"""
network_backend = BrokenTLSBackend(
[
b"HTTP/1.1 200 OK\r\n\r\n",
]
)

with ConnectionPool(
proxy=Proxy("http://localhost:8080/"),
network_backend=network_backend,
) as proxy:
with pytest.raises(OSError, match="TLS Failure"):
proxy.request("GET", "https://example.com/")

assert not proxy.connections
Loading