Skip to content

Commit 8b33328

Browse files
authored
Perf speed up pytest (#15951)
* perf: Skip sleep delays in base_mail.py during tests to improve test speed * perf: Mock datetime.now in parallel_request_limiter_v3.py to improve test speed * perf: Mock urllib system calls in test_aiohttp_transport.py to improve test speed * chore: add --durations=50 to visualize slowest tests * perf: reduce setup phase overhead by widening fixture scope in conftest.py * test: stabilize flaky tests * fix: minor issue
1 parent 5ad108b commit 8b33328

File tree

11 files changed

+116
-37
lines changed

11 files changed

+116
-37
lines changed

.github/workflows/test-litellm.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ jobs:
3333
poetry run pip install "google-genai==1.22.0"
3434
poetry run pip install "google-cloud-aiplatform>=1.38"
3535
poetry run pip install "fastapi-offline==1.7.3"
36+
poetry run pip install "python-multipart==0.0.18"
3637
- name: Setup litellm-enterprise as local package
3738
run: |
3839
cd enterprise
3940
python -m pip install -e .
4041
cd ..
4142
- name: Run tests
4243
run: |
43-
poetry run pytest tests/test_litellm --tb=short -vv --maxfail=10 -n 4
44+
poetry run pytest tests/test_litellm --tb=short -vv --maxfail=10 -n 4 --durations=50

litellm/proxy/hooks/dynamic_rate_limiter_v3.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
"""
44

55
import os
6-
from typing import Dict, List, Literal, Optional, Union
6+
from datetime import datetime
7+
from typing import Callable, Dict, List, Literal, Optional, Union
78

89
from fastapi import HTTPException
910

@@ -42,9 +43,15 @@ class _PROXY_DynamicRateLimitHandlerV3(CustomLogger):
4243
- When saturated: strict priority-based limits enforced (fair)
4344
- Uses v3 limiter's atomic Lua scripts for race-free increments
4445
"""
45-
def __init__(self, internal_usage_cache: DualCache):
46+
def __init__(
47+
self,
48+
internal_usage_cache: DualCache,
49+
time_provider: Optional[Callable[[], datetime]] = None,
50+
):
4651
self.internal_usage_cache = InternalUsageCache(dual_cache=internal_usage_cache)
47-
self.v3_limiter = _PROXY_MaxParallelRequestsHandler_v3(self.internal_usage_cache)
52+
self.v3_limiter = _PROXY_MaxParallelRequestsHandler_v3(
53+
self.internal_usage_cache, time_provider=time_provider
54+
)
4855

4956
def update_variables(self, llm_router: Router):
5057
self.llm_router = llm_router

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import (
1212
TYPE_CHECKING,
1313
Any,
14+
Callable,
1415
Dict,
1516
List,
1617
Literal,
@@ -137,8 +138,13 @@ class RateLimitResponseWithDescriptors(TypedDict):
137138

138139

139140
class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
140-
def __init__(self, internal_usage_cache: InternalUsageCache):
141+
def __init__(
142+
self,
143+
internal_usage_cache: InternalUsageCache,
144+
time_provider: Optional[Callable[[], datetime]] = None,
145+
):
141146
self.internal_usage_cache = internal_usage_cache
147+
self._time_provider = time_provider or datetime.now
142148
if self.internal_usage_cache.dual_cache.redis_cache is not None:
143149
self.batch_rate_limiter_script = (
144150
self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
@@ -156,6 +162,10 @@ def __init__(self, internal_usage_cache: InternalUsageCache):
156162

157163
self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
158164

165+
def _get_current_time(self) -> datetime:
166+
"""Return the current time for rate limiting calculations."""
167+
return self._time_provider()
168+
159169
def _is_redis_cluster(self) -> bool:
160170
"""
161171
Check if the dual cache is using Redis cluster.
@@ -425,7 +435,8 @@ async def should_rate_limit(
425435
read_only: If True, only check limits without incrementing counters
426436
"""
427437

428-
now = datetime.now().timestamp()
438+
current_time = self._get_current_time()
439+
now = current_time.timestamp()
429440
now_int = int(now) # Convert to integer for Redis Lua script
430441

431442
# Collect all keys and their metadata upfront
@@ -1090,7 +1101,7 @@ async def async_pre_call_hook(
10901101
descriptor = descriptors[floor(i / 2)]
10911102

10921103
# Calculate reset time (window_start + window_size)
1093-
now = datetime.now().timestamp()
1104+
now = self._get_current_time().timestamp()
10941105
reset_time = now + self.window_size # Conservative estimate
10951106
reset_time_formatted = datetime.fromtimestamp(
10961107
reset_time

tests/test_litellm/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def event_loop():
2626

2727

2828

29-
@pytest.fixture(scope="function", autouse=True)
29+
@pytest.fixture(scope="module", autouse=True)
3030
def setup_and_teardown():
3131
"""
3232
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.

tests/test_litellm/enterprise/enterprise_callbacks/send_emails/test_base_email.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import unittest.mock as mock
55
from unittest.mock import patch
66

7+
from enterprise.litellm_enterprise.enterprise_callbacks.send_emails.base_email import BaseEmailLogger
78
import pytest
89
from fastapi.testclient import TestClient
910

1011
sys.path.insert(0, os.path.abspath("../../.."))
11-
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
12-
BaseEmailLogger,
13-
)
1412
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
1513
EmailEvent,
1614
SendKeyCreatedEmailEvent,
@@ -20,6 +18,13 @@
2018
from litellm.proxy._types import Litellm_EntityType, WebhookEvent
2119

2220

21+
@pytest.fixture(autouse=True)
22+
def no_invitation_wait(monkeypatch):
23+
async def _noop(self):
24+
return None
25+
26+
monkeypatch.setattr(BaseEmailLogger, "_wait_for_invitation_creation", _noop)
27+
2328
@pytest.fixture
2429
def base_email_logger():
2530
return BaseEmailLogger()

tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ async def test_handle_async_request_uses_env_proxy(monkeypatch):
188188
monkeypatch.setenv("HTTPS_PROXY", proxy_url)
189189
monkeypatch.setenv("https_proxy", proxy_url)
190190
monkeypatch.delenv("DISABLE_AIOHTTP_TRUST_ENV", raising=False)
191+
monkeypatch.setattr("urllib.request.getproxies", lambda: {"http": proxy_url, "https": proxy_url})
192+
monkeypatch.setattr("urllib.request.proxy_bypass", lambda host: False)
191193

192194
captured = {}
193195

tests/test_litellm/llms/vertex_ai/vertex_gemma_models/test_vertex_gemma_transformation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
import litellm
1313

14+
@pytest.fixture(autouse=True)
15+
def _reset_litellm_http_client_cache():
16+
"""Ensure each test gets a fresh async HTTP client mock."""
17+
from litellm import in_memory_llm_clients_cache
18+
19+
in_memory_llm_clients_cache.flush_cache()
20+
1421

1522
class TestVertexGemmaCompletion:
1623
"""Test completion flow for Vertex AI Gemma models using litellm.acompletion()"""

tests/test_litellm/proxy/hooks/test_dynamic_rate_limiter_v3.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import sys
1010
import time
11+
from datetime import datetime, timedelta
1112
from unittest.mock import AsyncMock, patch
1213

1314
import pytest
@@ -22,6 +23,24 @@
2223
)
2324

2425

26+
class TimeController:
27+
def __init__(self):
28+
self._current = datetime.utcnow()
29+
30+
def now(self) -> datetime:
31+
return self._current
32+
33+
def advance(self, seconds: float) -> None:
34+
self._current += timedelta(seconds=seconds)
35+
36+
37+
@pytest.fixture
38+
def time_controller(monkeypatch):
39+
controller = TimeController()
40+
monkeypatch.setattr(time, "time", lambda: controller.now().timestamp())
41+
return controller
42+
43+
2544
@pytest.mark.asyncio
2645
async def test_priority_weight_allocation():
2746
"""
@@ -195,7 +214,7 @@ async def test_concurrent_priority_requests():
195214

196215

197216
@pytest.mark.asyncio
198-
async def test_100_concurrent_priority_requests():
217+
async def test_100_concurrent_priority_requests(time_controller):
199218
"""
200219
Stress test: 100 concurrent requests with mixed priorities over 10 seconds.
201220
@@ -211,7 +230,9 @@ async def test_100_concurrent_priority_requests():
211230
litellm.priority_reservation = {"high": 0.9, "low": 0.1}
212231

213232
dual_cache = DualCache()
214-
handler = DynamicRateLimitHandler(internal_usage_cache=dual_cache)
233+
handler = DynamicRateLimitHandler(
234+
internal_usage_cache=dual_cache, time_provider=time_controller.now
235+
)
215236

216237
model = "stress-test-model"
217238
total_tpm = 1000
@@ -307,7 +328,8 @@ async def test_user_descriptors(user_data):
307328

308329
# Add small delay between batches to spread over ~10 seconds
309330
if batch_idx < len(batches) - 1: # Don't sleep after last batch
310-
await asyncio.sleep(1.0) # 1 second between batches
331+
await asyncio.sleep(0)
332+
time_controller.advance(1.0) # simulate 1s passing between batches
311333

312334
end_time = time.time()
313335
total_duration = end_time - start_time

tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import asyncio
66
import os
77
import sys
8-
from datetime import datetime
8+
import time
9+
from datetime import datetime, timedelta
910
from typing import Any, Dict, List, Optional
1011

1112
import pytest
@@ -21,10 +22,27 @@
2122
from litellm.proxy.utils import InternalUsageCache, ProxyLogging, hash_token
2223
from litellm.types.utils import ModelResponse, Usage
2324

25+
class TimeController:
26+
def __init__(self):
27+
self._current = datetime.utcnow()
28+
29+
def now(self) -> datetime:
30+
return self._current
31+
32+
def advance(self, seconds: float) -> None:
33+
self._current += timedelta(seconds=seconds)
34+
35+
36+
@pytest.fixture
37+
def time_controller(monkeypatch):
38+
controller = TimeController()
39+
monkeypatch.setattr(time, "time", lambda: controller.now().timestamp())
40+
return controller
41+
2442

2543
@pytest.mark.flaky(reruns=3)
2644
@pytest.mark.asyncio
27-
async def test_sliding_window_rate_limit_v3(monkeypatch):
45+
async def test_sliding_window_rate_limit_v3(monkeypatch, time_controller):
2846
"""
2947
Test the sliding window rate limiting functionality
3048
"""
@@ -34,7 +52,8 @@ async def test_sliding_window_rate_limit_v3(monkeypatch):
3452
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, rpm_limit=3)
3553
local_cache = DualCache()
3654
parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
37-
internal_usage_cache=InternalUsageCache(local_cache)
55+
internal_usage_cache=InternalUsageCache(local_cache),
56+
time_provider=time_controller.now,
3857
)
3958

4059
# Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
@@ -103,7 +122,7 @@ async def mock_batch_rate_limiter(*args, **kwargs):
103122
assert "Rate limit exceeded" in str(exc_info.value.detail)
104123

105124
# Wait for window to expire (2 seconds)
106-
await asyncio.sleep(3)
125+
time_controller.advance(3)
107126

108127
print("WAITED 3 seconds")
109128

@@ -116,7 +135,7 @@ async def mock_batch_rate_limiter(*args, **kwargs):
116135

117136

118137
@pytest.mark.asyncio
119-
async def test_rate_limiter_script_return_values_v3(monkeypatch):
138+
async def test_rate_limiter_script_return_values_v3(monkeypatch, time_controller):
120139
"""
121140
Test that the rate limiter script returns both counter and window values correctly
122141
"""
@@ -126,7 +145,8 @@ async def test_rate_limiter_script_return_values_v3(monkeypatch):
126145
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, rpm_limit=3)
127146
local_cache = DualCache()
128147
parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
129-
internal_usage_cache=InternalUsageCache(local_cache)
148+
internal_usage_cache=InternalUsageCache(local_cache),
149+
time_provider=time_controller.now,
130150
)
131151

132152
# Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
@@ -199,7 +219,7 @@ async def mock_batch_rate_limiter(*args, **kwargs):
199219
assert new_counter_value == 2, "Counter should be 2 after second request"
200220

201221
# Wait for window to expire
202-
await asyncio.sleep(3)
222+
time_controller.advance(3)
203223

204224
# Make request after window expiry
205225
await parallel_request_handler.async_pre_call_hook(
@@ -226,7 +246,7 @@ async def mock_batch_rate_limiter(*args, **kwargs):
226246
)
227247
@pytest.mark.flaky(reruns=3)
228248
@pytest.mark.asyncio
229-
async def test_normal_router_call_tpm_v3(monkeypatch, rate_limit_object):
249+
async def test_normal_router_call_tpm_v3(monkeypatch, rate_limit_object, time_controller):
230250
"""
231251
Test normal router call with parallel request limiter v3 for TPM rate limiting
232252
"""
@@ -276,7 +296,8 @@ async def test_normal_router_call_tpm_v3(monkeypatch, rate_limit_object):
276296
)
277297
local_cache = DualCache()
278298
parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
279-
internal_usage_cache=InternalUsageCache(local_cache)
299+
internal_usage_cache=InternalUsageCache(local_cache),
300+
time_provider=time_controller.now,
280301
)
281302

282303
# Mock the batch_rate_limiter_script to simulate window expiry and use correct key construction
@@ -359,7 +380,8 @@ def get_value_for_key(rate_limit_object, user_api_key_dict, model_name):
359380
},
360381
mock_response="hello",
361382
)
362-
await asyncio.sleep(1) # success is done in a separate thread
383+
await asyncio.sleep(0)
384+
time_controller.advance(1)
363385

364386
# Verify the token count is tracked
365387
counter_value = await local_cache.async_get_cache(key=counter_key)
@@ -383,7 +405,7 @@ def get_value_for_key(rate_limit_object, user_api_key_dict, model_name):
383405
)
384406

385407
# Wait for window to expire
386-
await asyncio.sleep(3)
408+
time_controller.advance(3)
387409

388410
# Make request after window expiry
389411
await parallel_request_handler.async_pre_call_hook(

tests/test_litellm/proxy/management_endpoints/test_ui_sso.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,14 @@ def test_get_microsoft_callback_response():
125125
"surname": "User",
126126
}
127127

128-
future = asyncio.Future()
129-
future.set_result(mock_response)
130-
131128
with patch.dict(
132129
os.environ,
133130
{"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
134131
):
132+
mock_verify = AsyncMock(return_value=mock_response)
135133
with patch(
136134
"fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
137-
return_value=future,
135+
new=mock_verify,
138136
):
139137
# Act
140138
result = asyncio.run(
@@ -166,15 +164,14 @@ def test_get_microsoft_callback_response_raw_sso_response():
166164
"surname": "User",
167165
}
168166

169-
future = asyncio.Future()
170-
future.set_result(mock_response)
171167
with patch.dict(
172168
os.environ,
173169
{"MICROSOFT_CLIENT_SECRET": "mock_secret", "MICROSOFT_TENANT": "mock_tenant"},
174170
):
171+
mock_verify = AsyncMock(return_value=mock_response)
175172
with patch(
176173
"fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
177-
return_value=future,
174+
new=mock_verify,
178175
):
179176
# Act
180177
result = asyncio.run(
@@ -207,12 +204,10 @@ def test_get_google_callback_response():
207204
"family_name": "User",
208205
}
209206

210-
future = asyncio.Future()
211-
future.set_result(mock_response)
212-
213207
with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}):
208+
mock_verify = AsyncMock(return_value=mock_response)
214209
with patch(
215-
"fastapi_sso.sso.google.GoogleSSO.verify_and_process", return_value=future
210+
"fastapi_sso.sso.google.GoogleSSO.verify_and_process", new=mock_verify
216211
):
217212
# Act
218213
result = asyncio.run(
@@ -2072,4 +2067,3 @@ async def test_get_generic_sso_redirect_response_with_pkce(self):
20722067
assert "code_challenge=" in updated_location
20732068
assert "code_challenge_method=S256" in updated_location
20742069
assert f"state={test_state}" in updated_location
2075-

0 commit comments

Comments (0)