11 changes: 8 additions & 3 deletions litellm/llms/custom_httpx/http_handler.py
@@ -280,9 +280,14 @@ async def post(
                 files=files,
                 content=content,
             )
-            response = await self.client.send(req, stream=stream)
-            response.raise_for_status()
-            return response
+            try:
+                response = await self.client.send(req, stream=stream)
+                response.raise_for_status()
+                return response
+            except asyncio.CancelledError:
+                # If the request was cancelled, ensure we propagate the cancellation
+                # This will cause the HTTP connection to be closed, which downstream services can detect
+                raise
         except (httpx.RemoteProtocolError, httpx.ConnectError):
             # Retry the request with a new session if there is a connection error
             new_client = self.create_client(
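For context, a minimal standalone sketch (not part of the patch) of the behaviour this handler change relies on: cancelling the asyncio task that awaits an httpx request raises `CancelledError` inside the coroutine, and httpx tears down the underlying connection as the task unwinds, which the upstream server can observe as a client disconnect. The endpoint URL and timing below are placeholders.

```python
import asyncio

import httpx


async def slow_post() -> None:
    async with httpx.AsyncClient() as client:
        try:
            # Placeholder endpoint; imagine a long-running LLM request here.
            await client.post("https://example.com/v1/generate", json={"prompt": "hi"})
        except asyncio.CancelledError:
            # Re-raise so the event loop can finish cancelling the task;
            # swallowing it here would leave the task unfinished.
            raise


async def main() -> None:
    task = asyncio.create_task(slow_post())
    await asyncio.sleep(0.05)  # let the request get in flight
    task.cancel()              # e.g. the proxy's own client disconnected
    try:
        await task
    except asyncio.CancelledError:
        print("request task cancelled; connection torn down")


asyncio.run(main())
```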
5 changes: 5 additions & 0 deletions litellm/llms/custom_httpx/llm_http_handler.py
@@ -120,6 +120,7 @@ async def _make_common_async_call(
         response: Optional[httpx.Response] = None
         for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
             try:
+                # Make the HTTP request with proper cancellation support
                 response = await async_httpx_client.post(
                     url=api_base,
                     headers=headers,
@@ -132,6 +133,10 @@
                     stream=stream,
                     logging_obj=logging_obj,
                 )
+            except asyncio.CancelledError:
+                # If the request was cancelled, ensure we propagate the cancellation
+                # This will cause the HTTP connection to be closed, which downstream services can detect
+                raise
             except httpx.HTTPStatusError as e:
                 hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
                 should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
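A reduced sketch of the except-ordering this retry loop depends on. The function name, parameters, and the 422-only retry rule are illustrative stand-ins, not the real `_make_common_async_call` signature or `provider_config` API: the point is that cancellation is re-raised before any retry or error-translation logic, so a cancelled request is never retried.

```python
import asyncio
from typing import Optional

import httpx


async def post_with_retries(
    client: httpx.AsyncClient, url: str, max_retries: int = 2
) -> httpx.Response:
    last_error: Optional[Exception] = None
    for attempt in range(max(max_retries, 1)):
        try:
            response = await client.post(url, json={"attempt": attempt})
            response.raise_for_status()
            return response
        except asyncio.CancelledError:
            # A cancelled request must not be retried or converted into an
            # API error; let the cancellation propagate to the caller.
            raise
        except httpx.HTTPStatusError as e:
            last_error = e
            # Only retry "unprocessable entity" style failures; anything else
            # is raised immediately.
            if e.response.status_code != 422 or attempt + 1 >= max_retries:
                raise
    raise RuntimeError("unreachable: the retry loop always returns or raises")
```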
12 changes: 9 additions & 3 deletions litellm/llms/openai/completion/handler.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 from typing import Callable, List, Optional, Union
 
@@ -176,9 +177,14 @@ async def acompletion(
         else:
             openai_aclient = client
 
-        raw_response = await openai_aclient.completions.with_raw_response.create(
-            **data
-        )
+        try:
+            raw_response = await openai_aclient.completions.with_raw_response.create(
+                **data
+            )
+        except asyncio.CancelledError:
+            # If the request was cancelled, ensure we propagate the cancellation
+            # This will cause the HTTP connection to be closed, which downstream services can detect
+            raise
         response = raw_response.parse()
         response_json = response.model_dump()
 
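Worth noting for readers of these re-raise clauses: since Python 3.8, `asyncio.CancelledError` derives from `BaseException`, so a broad `except Exception` does not swallow it; the explicit clause makes the intent visible and guards older exception-handling patterns. A small self-contained illustration (not part of the patch):

```python
import asyncio


async def worker() -> None:
    try:
        await asyncio.sleep(10)
    except Exception:
        # Not reached on cancellation: CancelledError is a BaseException
        # on Python 3.8+, so a broad `except Exception` does not catch it.
        print("ordinary exception")
        raise


async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("cancellation propagated past the `except Exception` handler")


asyncio.run(main())
```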
14 changes: 10 additions & 4 deletions litellm/llms/openai/openai.py
@@ -1,3 +1,4 @@
+import asyncio
 import time
 import types
 from typing import (
@@ -432,11 +433,16 @@ async def make_openai_chat_completion_request(
     """
     start_time = time.time()
     try:
-        raw_response = (
-            await openai_aclient.chat.completions.with_raw_response.create(
-                **data, timeout=timeout
-            )
-        )
+        try:
+            raw_response = (
+                await openai_aclient.chat.completions.with_raw_response.create(
+                    **data, timeout=timeout
+                )
+            )
+        except asyncio.CancelledError:
+            # If the request was cancelled, ensure we propagate the cancellation
+            # This will cause the HTTP connection to be closed, which downstream services can detect
+            raise
         end_time = time.time()
 
         if hasattr(raw_response, "headers"):
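The comments above refer to the calling side that triggers the cancellation. A hedged sketch of what such a caller might look like: a FastAPI/Starlette-style route that cancels the in-flight `litellm.acompletion` task when its own client disconnects, so the `CancelledError` travels down through these handlers and the upstream HTTP connection is closed. The route path, polling interval, and stub error payload are illustrative and assume a non-streaming request; this is not taken from the litellm proxy.

```python
import asyncio

import litellm
from fastapi import FastAPI, Request

app = FastAPI()


@app.post("/v1/chat/completions")
async def proxy_chat(request: Request) -> dict:
    payload = await request.json()
    task = asyncio.create_task(litellm.acompletion(**payload))

    # Poll for a client disconnect while the upstream call is in flight.
    while not task.done():
        if await request.is_disconnected():
            task.cancel()  # propagates CancelledError through acompletion
            break
        await asyncio.sleep(0.1)

    try:
        response = await task
    except asyncio.CancelledError:
        # Nobody is listening any more; return a stub for completeness.
        return {"error": "client disconnected"}
    return response.model_dump()
```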
69 changes: 53 additions & 16 deletions litellm/main.py
@@ -560,14 +560,30 @@ async def acompletion(
         await asyncio.sleep(mock_delay)
 
     try:
-        # Use a partial function to pass your keyword arguments
-        func = partial(completion, **completion_kwargs, **kwargs)
-
-        # Add the context to the function
-        ctx = contextvars.copy_context()
-        func_with_context = partial(ctx.run, func)
-
-        init_response = await loop.run_in_executor(None, func_with_context)
+        # First, try calling completion directly to see if it returns a coroutine (native async)
+        # This allows providers with async implementations to be properly cancellable
+        init_response = completion(**completion_kwargs, **kwargs)
+
+        # If the provider returned a coroutine, await it directly (fully cancellable)
+        if asyncio.iscoroutine(init_response):
+            try:
+                init_response = await init_response
+            except asyncio.CancelledError:
+                # Re-raise the CancelledError to propagate cancellation
+                raise
+        # If it's not a coroutine, we need to handle it differently
+        elif not isinstance(init_response, (dict, ModelResponse, CustomStreamWrapper)):
+            # This shouldn't happen in normal cases, but fall back to the executor if needed
+            func = partial(completion, **completion_kwargs, **kwargs)
+            ctx = contextvars.copy_context()
+            func_with_context = partial(ctx.run, func)
+
+            try:
+                init_response = await loop.run_in_executor(None, func_with_context)
+            except asyncio.CancelledError:
+                # Re-raise the CancelledError to propagate cancellation
+                raise
+
         if isinstance(init_response, dict) or isinstance(
             init_response, ModelResponse
         ): ## CACHING SCENARIO
@@ -592,6 +608,9 @@
                 loop=loop
             ) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
         return response
+    except asyncio.CancelledError:
+        # Ensure CancelledError is properly propagated without being caught by the general exception handler
+        raise
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
@@ -4800,14 +4819,29 @@ async def atext_completion(
     kwargs["acompletion"] = True
     custom_llm_provider = None
     try:
-        # Use a partial function to pass your keyword arguments
-        func = partial(text_completion, *args, **kwargs)
-
-        # Add the context to the function
-        ctx = contextvars.copy_context()
-        func_with_context = partial(ctx.run, func)
-
-        init_response = await loop.run_in_executor(None, func_with_context)
+        # First, try calling text_completion directly to see if it returns a coroutine (native async)
+        # This allows providers with async implementations to be properly cancellable
+        init_response = text_completion(*args, **kwargs)
+
+        # If the provider returned a coroutine, await it directly (fully cancellable)
+        if asyncio.iscoroutine(init_response):
+            try:
+                init_response = await init_response
+            except asyncio.CancelledError:
+                # Re-raise the CancelledError to propagate cancellation
+                raise
+        # If it's not a coroutine, we need to handle it differently
+        elif not isinstance(init_response, (dict, TextCompletionResponse, CustomStreamWrapper)):
+            # This shouldn't happen in normal cases, but fall back to the executor if needed
+            func = partial(text_completion, *args, **kwargs)
+            ctx = contextvars.copy_context()
+            func_with_context = partial(ctx.run, func)
+
+            try:
+                init_response = await loop.run_in_executor(None, func_with_context)
+            except asyncio.CancelledError:
+                # Re-raise the CancelledError to propagate cancellation
+                raise
         if isinstance(init_response, dict) or isinstance(
             init_response, TextCompletionResponse
         ): ## CACHING SCENARIO
@@ -4850,6 +4884,9 @@
                 custom_llm_provider=custom_llm_provider,
             )
         return text_completion_response
+    except asyncio.CancelledError:
+        # Ensure CancelledError is properly propagated without being caught by the general exception handler
+        raise
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
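A reduced sketch of the dispatch pattern the rewritten `acompletion`/`atext_completion` blocks follow. The stand-in `completion_like` function and its `native_async` flag are invented for illustration: call the synchronous entry point once, await the result directly when it is a coroutine so `task.cancel()` reaches the provider, and fall back to a thread-pool executor only for unexpected blocking return paths.

```python
import asyncio
import contextvars
from functools import partial
from typing import Any, Dict


def completion_like(**kwargs: Any) -> Any:
    """Stand-in for a sync entry point that may return a coroutine or a dict."""
    async def _native() -> Dict[str, Any]:
        await asyncio.sleep(0.2)  # pretend network call
        return {"choices": ["native async result"]}

    if kwargs.get("native_async"):
        return _native()
    return {"choices": ["plain sync result"]}


async def acompletion_like(**kwargs: Any) -> Dict[str, Any]:
    loop = asyncio.get_running_loop()
    init_response = completion_like(**kwargs)

    if asyncio.iscoroutine(init_response):
        # Awaiting directly keeps the whole call chain cancellable.
        return await init_response

    if not isinstance(init_response, dict):
        # Unexpected blocking result: run the call in a worker thread,
        # preserving context variables, as the executor fallback does.
        ctx = contextvars.copy_context()
        func = partial(ctx.run, partial(completion_like, **kwargs))
        return await loop.run_in_executor(None, func)

    return init_response


print(asyncio.run(acompletion_like(native_async=True)))
```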
6 changes: 6 additions & 0 deletions litellm/router.py
@@ -1068,6 +1068,9 @@ async def acompletion(
             )
 
             return response
+        except asyncio.CancelledError:
+            # Ensure CancelledError is properly propagated without being caught by the general exception handler
+            raise
         except Exception as e:
             asyncio.create_task(
                 send_llm_exception_alert(
@@ -2506,6 +2509,9 @@ async def atext_completion(
             response = await self.async_function_with_fallbacks(**kwargs)
 
             return response
+        except asyncio.CancelledError:
+            # Ensure CancelledError is properly propagated without being caught by the general exception handler
+            raise
        except Exception as e:
             asyncio.create_task(
                 send_llm_exception_alert(
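Finally, a minimal sketch of the ordering these two router hunks establish, with `send_alert` standing in for `send_llm_exception_alert`: cancellation is re-raised before the broad handler, so a client disconnect is never turned into an alert task or reported as an LLM failure.

```python
import asyncio
from typing import Awaitable, Callable


async def send_alert(exc: Exception) -> None:
    # Stand-in for send_llm_exception_alert.
    print(f"alerting on: {exc!r}")


async def routed_call(make_request: Callable[[], Awaitable[object]]) -> object:
    try:
        return await make_request()
    except asyncio.CancelledError:
        # A cancelled request is not an error condition: propagate it untouched
        # so no alert task is spawned and the caller observes the cancellation.
        raise
    except Exception as e:
        asyncio.create_task(send_alert(e))
        raise
```

Cancelling the task that awaits `routed_call` therefore leaves no alert behind, while a genuine provider error still schedules one before re-raising.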