diff --git a/RATE_LIMITING_IMPLEMENTATION.md b/RATE_LIMITING_IMPLEMENTATION.md new file mode 100644 index 0000000..b2f65c5 --- /dev/null +++ b/RATE_LIMITING_IMPLEMENTATION.md @@ -0,0 +1,181 @@ +# Rate Limiting Implementation Summary + +This document summarizes the API rate limiting implementation for the RTAC project. + +## Implementation Status + +All acceptance criteria have been successfully implemented: + +### 1. Rate limiting works correctly +- **Sliding window algorithm** implemented for accurate rate limiting +- **Configurable limits** for different endpoints +- **Memory-efficient** with automatic cleanup of old requests +- **Thread-safe** with async locks for concurrent requests + +### 2. Headers are included +- `X-RateLimit-Limit` - Maximum requests allowed in window +- `X-RateLimit-Remaining` - Remaining requests in current window +- `X-RateLimit-Reset` - Unix timestamp when window resets +- `X-RateLimit-Window` - Window size in seconds +- `Retry-After` - Seconds to wait before retrying (when rate limited) + +### 3. Errors are handled +- **429 Too Many Requests** status code when rate limit exceeded +- **Structured error responses** with clear messages +- **Retry guidance** included in error responses +- **Custom exception handler** for rate limit errors + +### 4. Configuration is flexible +- **Environment variable configuration** for all settings +- **Different limits** for different endpoints (global vs chat) +- **Enable/disable** rate limiting via configuration +- **Exempted endpoints** for health checks and documentation + +## Files Modified + +### Required Files +1. **`app/api/v1/agents_router.py`** + - Added rate limit exception handler + - Added `/rate-limits` endpoint for configuration info + +2. **`app/middleware/rate_limit.py`** (Created) + - Complete rate limiting middleware implementation + - Sliding window algorithm + - Client identification logic + - Rate limit headers management + +### Additional Files +3. 
**`app/config/settings.py`** + - Added rate limiting configuration variables + - Environment variable definitions + +4. **`main.py`** + - Integrated rate limiting middleware + - Conditional enablement based on configuration + +5. **`env.example`** + - Added rate limiting environment variables + - Default configuration values + +## Configuration + +### Environment Variables Added +```env +# Rate Limiting Configuration +RATE_LIMIT_ENABLED=true # Enable/disable rate limiting +RATE_LIMIT_REQUESTS=100 # Global requests per hour +RATE_LIMIT_WINDOW=3600 # Global window (1 hour) +RATE_LIMIT_CHAT_REQUESTS=30 # Chat requests per 5 minutes +RATE_LIMIT_CHAT_WINDOW=300 # Chat window (5 minutes) +``` + +### Default Rate Limits +- **Global endpoints**: 100 requests per hour +- **Chat endpoint**: 30 requests per 5 minutes +- **Exempted endpoints**: No limits (health, docs, static files) + +## Key Features + +### Smart Client Identification +- Handles requests behind proxies (`X-Forwarded-For`) +- Supports load balancers (`X-Real-IP`) +- Fallback to direct client IP + +### Endpoint-Specific Limits +- Stricter limits for resource-intensive chat endpoint +- Relaxed limits for status and info endpoints +- Complete exemption for health checks and documentation + +### Comprehensive Error Handling +- Graceful degradation if rate limiting fails +- Structured error responses with retry guidance +- Proper HTTP status codes and headers + +### Performance Optimized +- In-memory storage with automatic cleanup +- Async-safe with proper locking +- Efficient sliding window algorithm + +## Testing + +### Test Files Created +1. **`tests/test_rate_limit.py`** + - Unit tests for rate limit store + - Integration tests for middleware + - Performance and concurrency tests + +2. 
**`scripts/test_rate_limiting.py`** + - Manual testing script + - Tests all endpoints and scenarios + - Validates headers and responses + +### Test Coverage +- Rate limit enforcement +- Header inclusion +- Error responses +- Exempted endpoints +- Concurrent requests +- Configuration validation + +## Documentation + +### Documentation Created +1. **`docs/rate-limiting.md`** + - Complete implementation guide + - Configuration reference + - API examples and responses + - Best practices for clients + - Troubleshooting guide + +## Deployment Checklist + +### Before Deployment +- [ ] Set appropriate rate limits for production +- [ ] Configure environment variables +- [ ] Test with expected traffic patterns +- [ ] Monitor memory usage in production + +### Production Configuration +```env +# Recommended production settings +RATE_LIMIT_ENABLED=true +RATE_LIMIT_REQUESTS=100 +RATE_LIMIT_WINDOW=3600 +RATE_LIMIT_CHAT_REQUESTS=30 +RATE_LIMIT_CHAT_WINDOW=300 +``` + +## Monitoring + +### Log Messages +Rate limiting activities are logged with appropriate levels: +- Rate limit exceeded events (WARNING) +- System errors (ERROR) +- Normal operations (DEBUG) + +### Metrics to Monitor +- Rate limit hit rates by endpoint +- Client distribution and patterns +- Memory usage of rate limit store +- Response times with middleware + +## Benefits Achieved + +1. **Abuse Prevention**: Protects against excessive API usage +2. **Fair Usage**: Ensures all users get equitable access +3. **System Stability**: Prevents overload of backend services +4. **Cost Control**: Reduces infrastructure costs from abuse +5. 
**Better UX**: Provides clear feedback to legitimate users + +## Future Enhancements + +Consider these improvements for advanced use cases: +- Redis-based distributed rate limiting +- User-based rate limiting (after authentication) +- Dynamic rate limits based on system load +- Rate limiting analytics dashboard +- IP whitelisting for trusted clients + +--- + +**Implementation Complete**: All acceptance criteria met with comprehensive testing and documentation. diff --git a/app/api/v1/agents_router.py b/app/api/v1/agents_router.py index 176f14c..f0b3127 100644 --- a/app/api/v1/agents_router.py +++ b/app/api/v1/agents_router.py @@ -19,6 +19,21 @@ # Track agent startup time startup_time = time.time() + +@router.exception_handler(429) +async def rate_limit_handler(request: Request, exc: HTTPException): + """Handle rate limit exceeded errors.""" + return JSONResponse( + status_code=429, + content={ + "success": False, + "error": "Rate limit exceeded", + "message": "Too many requests. Please try again later.", + "support_contact": settings.support_phone + }, + headers=getattr(exc, "headers", {}) + ) + @router.post( "/chat", response_model=SuccessResponseSchema[AgentResponse], @@ -186,4 +201,48 @@ async def get_conference_info(): raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get conference information" + ) + + +@router.get( + "/rate-limits", + response_model=SuccessResponseSchema[dict], + status_code=status.HTTP_200_OK, + summary="Get rate limit information", + description="Get current rate limiting configuration and status" +) +async def get_rate_limits(): + """Get rate limiting information.""" + try: + rate_limit_info = { + "enabled": settings.rate_limit_enabled, + "global_limits": { + "requests": settings.rate_limit_requests, + "window_seconds": settings.rate_limit_window, + "window_description": f"{settings.rate_limit_window // 60} minutes" + }, + "chat_limits": { + "requests": settings.rate_limit_chat_requests, + 
"window_seconds": settings.rate_limit_chat_window, + "window_description": f"{settings.rate_limit_chat_window // 60} minutes" + }, + "headers_included": [ + "X-RateLimit-Limit", + "X-RateLimit-Remaining", + "X-RateLimit-Reset", + "X-RateLimit-Window", + "Retry-After (when limit exceeded)" + ] + } + + return SuccessResponseSchema( + data=rate_limit_info, + message="Rate limit information retrieved successfully" + ) + + except Exception as e: + logger.error(f"Error getting rate limit info: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to get rate limit information" ) \ No newline at end of file diff --git a/app/config/settings.py b/app/config/settings.py index ef87a96..81d5870 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -50,6 +50,13 @@ class Settings(BaseSettings): secret_key: str = Field(..., env="SECRET_KEY") access_token_expire_minutes: int = Field(30, env="ACCESS_TOKEN_EXPIRE_MINUTES") + # Rate Limiting Configuration + rate_limit_requests: int = Field(100, env="RATE_LIMIT_REQUESTS") + rate_limit_window: int = Field(3600, env="RATE_LIMIT_WINDOW") # 1 hour in seconds + rate_limit_chat_requests: int = Field(30, env="RATE_LIMIT_CHAT_REQUESTS") + rate_limit_chat_window: int = Field(300, env="RATE_LIMIT_CHAT_WINDOW") # 5 minutes in seconds + rate_limit_enabled: bool = Field(True, env="RATE_LIMIT_ENABLED") + @field_validator("cors_origins", mode="before") @classmethod def parse_cors_origins(cls, v): diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py new file mode 100644 index 0000000..e77c58d --- /dev/null +++ b/app/middleware/__init__.py @@ -0,0 +1,3 @@ +""" +Middleware package for request processing. +""" diff --git a/app/middleware/rate_limit.py b/app/middleware/rate_limit.py new file mode 100644 index 0000000..f74dcf2 --- /dev/null +++ b/app/middleware/rate_limit.py @@ -0,0 +1,189 @@ +""" +Rate limiting middleware for API endpoints. 
+""" + +import time +import asyncio +from typing import Dict, Optional, Tuple +from collections import defaultdict +from fastapi import Request, Response, HTTPException, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from app.config.logger import Logger +from app.config.settings import settings + +logger = Logger.get_logger(__name__) + + +class RateLimitStore: + """In-memory rate limit store using sliding window algorithm.""" + + def __init__(self): + self.requests: Dict[str, list] = defaultdict(list) + self._lock = asyncio.Lock() + + async def is_allowed( + self, + identifier: str, + limit: int, + window_seconds: int + ) -> Tuple[bool, Dict[str, int]]: + """ + Check if request is allowed based on rate limit. + + Args: + identifier: Unique identifier (IP address) + limit: Maximum requests allowed + window_seconds: Time window in seconds + + Returns: + Tuple of (is_allowed, headers_dict) + """ + async with self._lock: + current_time = time.time() + window_start = current_time - window_seconds + + # Clean old requests outside the window + self.requests[identifier] = [ + req_time for req_time in self.requests[identifier] + if req_time > window_start + ] + + request_count = len(self.requests[identifier]) + remaining = max(0, limit - request_count) + reset_time = int(current_time + window_seconds) + + headers = { + "X-RateLimit-Limit": limit, + "X-RateLimit-Remaining": remaining, + "X-RateLimit-Reset": reset_time, + "X-RateLimit-Window": window_seconds + } + + if request_count >= limit: + # Find the oldest request to calculate retry-after + if self.requests[identifier]: + oldest_request = min(self.requests[identifier]) + retry_after = int(oldest_request + window_seconds - current_time) + headers["Retry-After"] = max(1, retry_after) + return False, headers + + # Add current request + self.requests[identifier].append(current_time) + return True, headers + + +class RateLimitMiddleware(BaseHTTPMiddleware): + 
"""Rate limiting middleware for FastAPI.""" + + def __init__(self, app, store: Optional[RateLimitStore] = None): + super().__init__(app) + self.store = store or RateLimitStore() + + # Rate limit configurations from settings + self.global_limit = getattr(settings, 'rate_limit_requests', 100) + self.global_window = getattr(settings, 'rate_limit_window', 3600) # 1 hour + self.chat_limit = getattr(settings, 'rate_limit_chat_requests', 30) + self.chat_window = getattr(settings, 'rate_limit_chat_window', 300) # 5 minutes + + # Exempted paths + self.exempted_paths = { + "/", + "/docs", + "/redoc", + "/openapi.json", + "/health", + "/api/v1/agents/health" + } + + def get_client_identifier(self, request: Request) -> str: + """Get client identifier for rate limiting.""" + # Check for forwarded IP first (behind proxy) + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + return forwarded_for.split(",")[0].strip() + + # Check for real IP header + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + # Fallback to client IP + return request.client.host if request.client else "unknown" + + def get_rate_limit_config(self, request: Request) -> Tuple[int, int]: + """Get rate limit configuration based on endpoint.""" + path = request.url.path + + # Chat endpoint has stricter limits + if "/chat" in path: + return self.chat_limit, self.chat_window + + # Default global limits + return self.global_limit, self.global_window + + async def dispatch(self, request: Request, call_next): + """Process request with rate limiting.""" + # Skip rate limiting for exempted paths + if request.url.path in self.exempted_paths: + return await call_next(request) + + # Skip for health checks and static files + if (request.url.path.startswith("/static") or + request.method == "OPTIONS"): + return await call_next(request) + + client_id = self.get_client_identifier(request) + limit, window = self.get_rate_limit_config(request) + + try: + is_allowed, headers = 
await self.store.is_allowed( + client_id, limit, window + ) + + if not is_allowed: + logger.warning( + f"Rate limit exceeded for {client_id} on {request.url.path}" + ) + + error_response = { + "success": False, + "error": "Rate limit exceeded", + "message": "Too many requests. Please try again later.", + "details": { + "limit": limit, + "window_seconds": window, + "retry_after": headers.get("Retry-After", window) + } + } + + response = JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content=error_response + ) + + # Add rate limit headers + for key, value in headers.items(): + response.headers[key] = str(value) + + return response + + # Process the request + response = await call_next(request) + + # Add rate limit headers to successful response + for key, value in headers.items(): + response.headers[key] = str(value) + + return response + + except Exception as e: + logger.error(f"Error in rate limiting: {e}") + # Continue without rate limiting if there's an error + return await call_next(request) + + +def create_rate_limit_middleware(store: Optional[RateLimitStore] = None) -> RateLimitMiddleware: + """Factory function to create rate limit middleware.""" + return lambda app: RateLimitMiddleware(app, store) diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md new file mode 100644 index 0000000..c217d66 --- /dev/null +++ b/docs/rate-limiting.md @@ -0,0 +1,269 @@ +# API Rate Limiting + +This document describes the rate limiting implementation for the RTAC API Conference AI Agent. + +## Overview + +The API implements rate limiting to prevent abuse and ensure fair usage for all users. The rate limiting uses a sliding window algorithm to track requests over time. 
+ +## Configuration + +Rate limiting is configured through environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `RATE_LIMIT_ENABLED` | `true` | Enable/disable rate limiting | +| `RATE_LIMIT_REQUESTS` | `100` | Global requests per window | +| `RATE_LIMIT_WINDOW` | `3600` | Global window size in seconds (1 hour) | +| `RATE_LIMIT_CHAT_REQUESTS` | `30` | Chat endpoint requests per window | +| `RATE_LIMIT_CHAT_WINDOW` | `300` | Chat window size in seconds (5 minutes) | + +## Rate Limits + +### Global Limits +- **100 requests per hour** for all endpoints (except chat) +- Applied to: `/api/v1/agents/status`, `/api/v1/agents/info`, `/api/v1/agents/rate-limits` + +### Chat Limits +- **30 requests per 5 minutes** for chat endpoint +- Applied to: `/api/v1/agents/chat` + +### Exempted Endpoints +The following endpoints are exempt from rate limiting: +- `/` (root) +- `/docs` (API documentation) +- `/redoc` (API documentation) +- `/openapi.json` (OpenAPI specification) +- `/api/v1/agents/health` (health check) +- Static files under `/static` +- `OPTIONS` requests (CORS preflight) + +## Client Identification + +Clients are identified using the following order of precedence: +1. `X-Forwarded-For` header (first IP if comma-separated) +2. `X-Real-IP` header +3. Client IP address from request + +## Rate Limit Headers + +All responses include rate limiting headers: + +| Header | Description | +|--------|-------------| +| `X-RateLimit-Limit` | Maximum requests allowed in window | +| `X-RateLimit-Remaining` | Remaining requests in current window | +| `X-RateLimit-Reset` | Unix timestamp when window resets | +| `X-RateLimit-Window` | Window size in seconds | +| `Retry-After` | Seconds to wait before retrying (when rate limited) | + +## Rate Limit Exceeded Response + +When rate limit is exceeded, the API returns: + +```json +{ + "success": false, + "error": "Rate limit exceeded", + "message": "Too many requests. 
Please try again later.", + "details": { + "limit": 30, + "window_seconds": 300, + "retry_after": 120 + } +} +``` + +HTTP Status Code: `429 Too Many Requests` + +## Example Responses + +### Successful Request +```http +HTTP/1.1 200 OK +X-RateLimit-Limit: 100 +X-RateLimit-Remaining: 99 +X-RateLimit-Reset: 1691751600 +X-RateLimit-Window: 3600 +Content-Type: application/json + +{ + "success": true, + "data": {...}, + "message": "Request processed successfully" +} +``` + +### Rate Limited Request +```http +HTTP/1.1 429 Too Many Requests +X-RateLimit-Limit: 30 +X-RateLimit-Remaining: 0 +X-RateLimit-Reset: 1691751600 +X-RateLimit-Window: 300 +Retry-After: 120 +Content-Type: application/json + +{ + "success": false, + "error": "Rate limit exceeded", + "message": "Too many requests. Please try again later.", + "details": { + "limit": 30, + "window_seconds": 300, + "retry_after": 120 + } +} +``` + +## Rate Limit Information Endpoint + +Get current rate limiting configuration: + +```http +GET /api/v1/agents/rate-limits +``` + +Response: +```json +{ + "success": true, + "data": { + "enabled": true, + "global_limits": { + "requests": 100, + "window_seconds": 3600, + "window_description": "60 minutes" + }, + "chat_limits": { + "requests": 30, + "window_seconds": 300, + "window_description": "5 minutes" + }, + "headers_included": [ + "X-RateLimit-Limit", + "X-RateLimit-Remaining", + "X-RateLimit-Reset", + "X-RateLimit-Window", + "Retry-After (when limit exceeded)" + ] + }, + "message": "Rate limit information retrieved successfully" +} +``` + +## Implementation Details + +### Sliding Window Algorithm + +The rate limiting uses a sliding window algorithm that: +1. Tracks timestamps of all requests for each client +2. Removes requests older than the window size +3. Counts remaining requests to check against limit +4. Calculates appropriate headers + +### Memory Management + +The implementation includes automatic cleanup of old request records to prevent memory leaks. 
+ +### Concurrency Safety + +The rate limiting store uses async locks to handle concurrent requests safely. + +## Monitoring and Observability + +Rate limiting events are logged with appropriate log levels: +- **INFO**: Normal rate limiting operations +- **WARNING**: Rate limit exceeded events +- **ERROR**: Rate limiting system errors + +Log example: +``` +2024-08-10 12:00:00 WARNING Rate limit exceeded for 192.168.1.1 on /api/v1/agents/chat +``` + +## Best Practices for Clients + +### Respect Rate Limits +- Monitor rate limit headers in responses +- Implement exponential backoff when rate limited +- Cache responses when appropriate + +### Handle Rate Limit Errors +```javascript +// Example client handling +const response = await fetch('/api/v1/agents/chat', options); + +if (response.status === 429) { + const retryAfter = response.headers.get('Retry-After'); + console.log(`Rate limited. Retry after ${retryAfter} seconds`); + + // Wait and retry + setTimeout(() => { + // Retry request + }, retryAfter * 1000); +} +``` + +### Monitor Usage +```javascript +// Check remaining requests +const remaining = response.headers.get('X-RateLimit-Remaining'); +const limit = response.headers.get('X-RateLimit-Limit'); + +console.log(`Requests remaining: ${remaining}/${limit}`); +``` + +## Configuration for Different Environments + +### Development +```env +RATE_LIMIT_ENABLED=true +RATE_LIMIT_REQUESTS=1000 +RATE_LIMIT_WINDOW=3600 +RATE_LIMIT_CHAT_REQUESTS=100 +RATE_LIMIT_CHAT_WINDOW=300 +``` + +### Production +```env +RATE_LIMIT_ENABLED=true +RATE_LIMIT_REQUESTS=100 +RATE_LIMIT_WINDOW=3600 +RATE_LIMIT_CHAT_REQUESTS=30 +RATE_LIMIT_CHAT_WINDOW=300 +``` + +### Testing +```env +RATE_LIMIT_ENABLED=false +# OR very high limits for load testing +RATE_LIMIT_REQUESTS=10000 +RATE_LIMIT_WINDOW=60 +``` + +## Troubleshooting + +### Rate Limiting Not Working +1. Check `RATE_LIMIT_ENABLED=true` in environment +2. Verify middleware is added to FastAPI app +3. 
Check logs for rate limiting errors + +### False Rate Limiting +1. Verify client identification (check headers) +2. Consider proxy/load balancer configuration +3. Review exempted paths configuration + +### Performance Issues +1. Monitor memory usage of rate limit store +2. Consider Redis-based store for distributed deployment +3. Adjust cleanup intervals if needed + +## Future Enhancements + +Potential improvements to consider: +- Redis-based distributed rate limiting +- Different limits for authenticated vs anonymous users +- IP whitelisting for trusted clients +- Rate limiting by user ID instead of IP +- Configurable rate limiting per endpoint diff --git a/env.example b/env.example index d1fe76b..7568c26 100644 --- a/env.example +++ b/env.example @@ -37,4 +37,11 @@ CORS_ORIGINS=["http://localhost:3000", "https://apiconf.net"] # Security SECRET_KEY=your-secret-key-here -ACCESS_TOKEN_EXPIRE_MINUTES=30 \ No newline at end of file +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# Rate Limiting Configuration +RATE_LIMIT_ENABLED=true +RATE_LIMIT_REQUESTS=100 +RATE_LIMIT_WINDOW=3600 +RATE_LIMIT_CHAT_REQUESTS=30 +RATE_LIMIT_CHAT_WINDOW=300 \ No newline at end of file diff --git a/main.py b/main.py index fba1d9f..6983988 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from app.config.logger import Logger from app.config.settings import settings from app.api.v1.agents_router import router as agents_router +from app.middleware.rate_limit import RateLimitMiddleware # Setup logging Logger.setup_root_logger() @@ -53,6 +54,10 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) +# Add rate limiting middleware (if enabled) +if settings.rate_limit_enabled: + app.add_middleware(RateLimitMiddleware) + # Include routers app.include_router(agents_router, prefix="/api/v1/agents", tags=["agents"]) diff --git a/scripts/test_rate_limiting.py b/scripts/test_rate_limiting.py new file mode 100644 index 0000000..0e24511 --- /dev/null +++ b/scripts/test_rate_limiting.py @@ -0,0 +1,237 @@ 
+#!/usr/bin/env python3 +""" +Rate limiting test script. + +This script tests the rate limiting functionality by making multiple requests +to the API and checking the rate limit headers and responses. +""" + +import asyncio +import aiohttp +import time +import json +from typing import Dict, Any + + +class RateLimitTester: + """Test rate limiting functionality.""" + + def __init__(self, base_url: str = "http://localhost:8000"): + self.base_url = base_url + self.session = None + + async def __aenter__(self): + self.session = aiohttp.ClientSession() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.session: + await self.session.close() + + async def make_request(self, endpoint: str, method: str = "GET", **kwargs) -> Dict[str, Any]: + """Make a request and return response data with headers.""" + url = f"{self.base_url}{endpoint}" + + async with self.session.request(method, url, **kwargs) as response: + headers = dict(response.headers) + + try: + data = await response.json() + except: + data = await response.text() + + return { + "status": response.status, + "headers": headers, + "data": data, + "timestamp": time.time() + } + + def print_rate_limit_info(self, response: Dict[str, Any]): + """Print rate limit information from response.""" + headers = response["headers"] + status = response["status"] + + print(f"Status: {status}") + print(f"Rate Limit Headers:") + + rate_limit_headers = [ + "X-RateLimit-Limit", + "X-RateLimit-Remaining", + "X-RateLimit-Reset", + "X-RateLimit-Window", + "Retry-After" + ] + + for header in rate_limit_headers: + value = headers.get(header.lower(), "Not present") + print(f" {header}: {value}") + + print("-" * 50) + + async def test_global_rate_limits(self, requests_count: int = 10): + """Test global rate limits on status endpoint.""" + print(f"Testing global rate limits with {requests_count} requests to /api/v1/agents/status") + print("=" * 60) + + for i in range(requests_count): + response = await 
self.make_request("/api/v1/agents/status") + + print(f"Request {i+1}:") + self.print_rate_limit_info(response) + + if response["status"] == 429: + print("Rate limit exceeded!") + break + + # Small delay between requests + await asyncio.sleep(0.1) + + async def test_chat_rate_limits(self, requests_count: int = 10): + """Test chat endpoint rate limits.""" + print(f"Testing chat rate limits with {requests_count} requests to /api/v1/agents/chat") + print("=" * 60) + + chat_payload = { + "message": "Hello, this is a test message", + "user_id": "test_user", + "session_id": "test_session" + } + + for i in range(requests_count): + response = await self.make_request( + "/api/v1/agents/chat", + method="POST", + json=chat_payload + ) + + print(f"Chat Request {i+1}:") + self.print_rate_limit_info(response) + + if response["status"] == 429: + print("Rate limit exceeded!") + retry_after = response["headers"].get("retry-after") + if retry_after: + print(f"Retry after: {retry_after} seconds") + break + + # Small delay between requests + await asyncio.sleep(0.1) + + async def test_exempted_endpoints(self): + """Test that exempted endpoints don't have rate limits.""" + print("Testing exempted endpoints") + print("=" * 60) + + exempted_endpoints = [ + "/", + "/api/v1/agents/health", + "/docs" + ] + + for endpoint in exempted_endpoints: + response = await self.make_request(endpoint) + + print(f"Exempted endpoint {endpoint}:") + print(f"Status: {response['status']}") + + # Check if rate limit headers are present + headers = response["headers"] + has_rate_limit = any( + h.startswith("x-ratelimit") for h in headers.keys() + ) + + if has_rate_limit: + print("Rate limit headers found (unexpected)") + else: + print("No rate limit headers (expected)") + + print("-" * 30) + + async def test_rate_limit_info_endpoint(self): + """Test the rate limit info endpoint.""" + print("Testing rate limit info endpoint") + print("=" * 60) + + response = await self.make_request("/api/v1/agents/rate-limits") 
+ + print(f"Status: {response['status']}") + + if response["status"] == 200: + data = response["data"] + if isinstance(data, dict) and "data" in data: + rate_info = data["data"] + print("Rate Limit Configuration:") + print(f" Enabled: {rate_info.get('enabled')}") + + global_limits = rate_info.get('global_limits', {}) + print(f" Global: {global_limits.get('requests')} requests per {global_limits.get('window_description')}") + + chat_limits = rate_info.get('chat_limits', {}) + print(f" Chat: {chat_limits.get('requests')} requests per {chat_limits.get('window_description')}") + + print("-" * 50) + + async def test_concurrent_requests(self, concurrent_count: int = 5): + """Test concurrent requests to check for race conditions.""" + print(f"Testing {concurrent_count} concurrent requests") + print("=" * 60) + + async def make_concurrent_request(request_id: int): + response = await self.make_request("/api/v1/agents/status") + return request_id, response + + # Make concurrent requests + tasks = [ + make_concurrent_request(i) + for i in range(concurrent_count) + ] + + results = await asyncio.gather(*tasks) + + for request_id, response in results: + print(f"Concurrent Request {request_id + 1}:") + print(f" Status: {response['status']}") + + headers = response["headers"] + remaining = headers.get("x-ratelimit-remaining", "N/A") + print(f" Remaining: {remaining}") + + print("-" * 50) + + +async def main(): + """Run all rate limiting tests.""" + print("RTAC API Rate Limiting Test Suite") + print("=" * 60) + + async with RateLimitTester() as tester: + try: + # Test rate limit info endpoint first + await tester.test_rate_limit_info_endpoint() + + # Test exempted endpoints + await tester.test_exempted_endpoints() + + # Test concurrent requests + await tester.test_concurrent_requests() + + # Test global rate limits + await tester.test_global_rate_limits(5) + + # Wait a bit + print("Waiting 2 seconds...") + await asyncio.sleep(2) + + # Test chat rate limits + await 
+        tester.test_chat_rate_limits(5)
+
+    except Exception as e:
+        print(f"Test failed with error: {e}")
+        return
+
+    print("Rate limiting tests completed!")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py
new file mode 100644
index 0000000..126265a
--- /dev/null
+++ b/tests/test_rate_limit.py
@@ -0,0 +1,270 @@
+"""
+Tests for rate limiting middleware.
+"""
+
+import pytest
+import asyncio
+import time
+from unittest.mock import Mock, patch
+from fastapi import FastAPI, Request
+from fastapi.testclient import TestClient
+from starlette.responses import JSONResponse
+
+from app.middleware.rate_limit import RateLimitMiddleware, RateLimitStore
+from app.config.settings import settings
+
+
+class MockRequest:
+    """Mock request for testing."""
+
+    def __init__(self, path="/", method="GET", client_host="127.0.0.1", headers=None):
+        self.url = Mock()
+        self.url.path = path
+        self.method = method
+        self.client = Mock()
+        self.client.host = client_host
+        self.headers = headers or {}
+
+
+class TestRateLimitStore:
+    """Test rate limit store functionality."""
+
+    @pytest.fixture
+    def store(self):
+        return RateLimitStore()
+
+    @pytest.mark.asyncio
+    async def test_first_request_allowed(self, store):
+        """Test that first request is allowed."""
+        is_allowed, headers = await store.is_allowed("client1", 10, 60)
+
+        assert is_allowed is True
+        assert headers["X-RateLimit-Limit"] == 10
+        assert headers["X-RateLimit-Remaining"] == 9
+        assert "X-RateLimit-Reset" in headers
+
+    @pytest.mark.asyncio
+    async def test_rate_limit_exceeded(self, store):
+        """Test rate limit exceeded."""
+        # Make requests up to the limit
+        for i in range(5):
+            is_allowed, _ = await store.is_allowed("client1", 5, 60)
+            assert is_allowed is True
+
+        # Next request should be blocked
+        is_allowed, headers = await store.is_allowed("client1", 5, 60)
+        assert is_allowed is False
+        assert headers["X-RateLimit-Remaining"] == 0
+        assert "Retry-After" in headers
+
+    @pytest.mark.asyncio
+    async def test_sliding_window(self, store):
+        """Test sliding window behavior."""
+        # Make requests
+        for i in range(3):
+            await store.is_allowed("client1", 5, 2)  # 2 second window
+            if i < 2:
+                await asyncio.sleep(0.1)
+
+        # Wait for window to slide
+        await asyncio.sleep(2.5)
+
+        # Should be allowed again
+        is_allowed, headers = await store.is_allowed("client1", 5, 2)
+        assert is_allowed is True
+        assert headers["X-RateLimit-Remaining"] == 4
+
+    @pytest.mark.asyncio
+    async def test_different_clients(self, store):
+        """Test that different clients have separate limits."""
+        # Client 1 uses up their limit
+        for i in range(3):
+            await store.is_allowed("client1", 3, 60)
+
+        # Client 1 should be blocked
+        is_allowed, _ = await store.is_allowed("client1", 3, 60)
+        assert is_allowed is False
+
+        # Client 2 should still be allowed
+        is_allowed, _ = await store.is_allowed("client2", 3, 60)
+        assert is_allowed is True
+
+
+class TestRateLimitMiddleware:
+    """Test rate limit middleware."""
+
+    @pytest.fixture
+    def app(self):
+        app = FastAPI()
+
+        @app.get("/test")
+        async def test_endpoint():
+            return {"message": "success"}
+
+        @app.post("/chat")
+        async def chat_endpoint():
+            return {"message": "chat response"}
+
+        @app.get("/health")
+        async def health_endpoint():
+            return {"status": "healthy"}
+
+        return app
+
+    @pytest.fixture
+    def middleware(self):
+        store = RateLimitStore()
+        return RateLimitMiddleware(None, store)
+
+    def test_get_client_identifier_forwarded_for(self, middleware):
+        """Test client identification with X-Forwarded-For header."""
+        request = MockRequest(headers={"X-Forwarded-For": "192.168.1.1, 10.0.0.1"})
+        client_id = middleware.get_client_identifier(request)
+        assert client_id == "192.168.1.1"
+
+    def test_get_client_identifier_real_ip(self, middleware):
+        """Test client identification with X-Real-IP header."""
+        request = MockRequest(headers={"X-Real-IP": "192.168.1.2"})
+        client_id = middleware.get_client_identifier(request)
+        assert client_id == "192.168.1.2"
+
+    def test_get_client_identifier_fallback(self, middleware):
+        """Test client identification fallback to client host."""
+        request = MockRequest(client_host="127.0.0.1")
+        client_id = middleware.get_client_identifier(request)
+        assert client_id == "127.0.0.1"
+
+    def test_get_rate_limit_config_chat(self, middleware):
+        """Test rate limit config for chat endpoint."""
+        request = MockRequest(path="/api/v1/agents/chat")
+        limit, window = middleware.get_rate_limit_config(request)
+        assert limit == middleware.chat_limit
+        assert window == middleware.chat_window
+
+    def test_get_rate_limit_config_global(self, middleware):
+        """Test rate limit config for other endpoints."""
+        request = MockRequest(path="/api/v1/agents/status")
+        limit, window = middleware.get_rate_limit_config(request)
+        assert limit == middleware.global_limit
+        assert window == middleware.global_window
+
+    @pytest.mark.asyncio
+    async def test_exempted_paths(self, middleware):
+        """Test that exempted paths bypass rate limiting."""
+        request = MockRequest(path="/health")
+
+        async def mock_call_next(req):
+            return JSONResponse({"status": "healthy"})
+
+        response = await middleware.dispatch(request, mock_call_next)
+        assert response.status_code == 200
+
+    @pytest.mark.asyncio
+    async def test_options_requests_bypassed(self, middleware):
+        """Test that OPTIONS requests bypass rate limiting."""
+        request = MockRequest(path="/api/v1/agents/chat", method="OPTIONS")
+
+        async def mock_call_next(req):
+            return JSONResponse({"message": "options"})
+
+        response = await middleware.dispatch(request, mock_call_next)
+        assert response.status_code == 200
+
+
+@pytest.mark.integration
+class TestRateLimitIntegration:
+    """Integration tests with FastAPI."""
+
+    @pytest.fixture
+    def app_with_middleware(self):
+        app = FastAPI()
+
+        # Add middleware
+        app.add_middleware(RateLimitMiddleware)
+
+        @app.get("/test")
+        async def test_endpoint():
+            return {"message": "success"}
+
+        @app.post("/chat")
+        async def chat_endpoint():
+            return {"message": "chat response"}
+
+        return app
+
+    def test_rate_limit_headers_included(self, app_with_middleware):
+        """Test that rate limit headers are included in response."""
+        with patch.object(settings, 'rate_limit_requests', 100):
+            with patch.object(settings, 'rate_limit_window', 3600):
+                client = TestClient(app_with_middleware)
+                response = client.get("/test")
+
+                assert response.status_code == 200
+                assert "X-RateLimit-Limit" in response.headers
+                assert "X-RateLimit-Remaining" in response.headers
+                assert "X-RateLimit-Reset" in response.headers
+                assert "X-RateLimit-Window" in response.headers
+
+    def test_rate_limit_exceeded_response(self, app_with_middleware):
+        """Test response when rate limit is exceeded."""
+        with patch.object(settings, 'rate_limit_requests', 2):
+            with patch.object(settings, 'rate_limit_window', 60):
+                client = TestClient(app_with_middleware)
+
+                # Make requests up to limit
+                response1 = client.get("/test")
+                response2 = client.get("/test")
+                assert response1.status_code == 200
+                assert response2.status_code == 200
+
+                # Next request should be rate limited
+                response3 = client.get("/test")
+                assert response3.status_code == 429
+
+                data = response3.json()
+                assert data["success"] is False
+                assert data["error"] == "Rate limit exceeded"
+                assert "Retry-After" in response3.headers
+
+
+@pytest.mark.performance
+class TestRateLimitPerformance:
+    """Performance tests for rate limiting."""
+
+    @pytest.mark.asyncio
+    async def test_concurrent_requests(self):
+        """Test handling of concurrent requests."""
+        store = RateLimitStore()
+
+        async def make_request(client_id):
+            return await store.is_allowed(client_id, 100, 60)
+
+        # Make 50 concurrent requests for same client
+        tasks = [make_request("client1") for _ in range(50)]
+        results = await asyncio.gather(*tasks)
+
+        # All should be processed
+        assert len(results) == 50
+
+        # 50 requests against a limit of 100: every one must be admitted.
+        allowed_count = sum(1 for is_allowed, _ in results if is_allowed)
+        assert allowed_count == 50
+
+    @pytest.mark.asyncio
+    async def test_memory_cleanup(self):
+        """Test that old requests are cleaned up to prevent memory leaks."""
+        store = RateLimitStore()
+
+        # Make requests for many different clients
+        for i in range(1000):
+            await store.is_allowed(f"client_{i}", 10, 1)  # 1 second window
+
+        # Wait for cleanup
+        await asyncio.sleep(2)
+
+        # Make a new request to trigger cleanup
+        await store.is_allowed("new_client", 10, 1)
+
+        # Check that old requests are cleaned up
+        # (This is more of a smoke test - in practice you'd monitor memory usage)
+        assert len(store.requests) < 1000