diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index 6f0210f270b8..2eeb37492ccc 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -1604,6 +1604,53 @@ litellm.vertex_location = "us-central1 # Your Location
 | gemini-2.5-flash-preview-09-2025 | `completion('gemini-2.5-flash-preview-09-2025', messages)`, `completion('vertex_ai/gemini-2.5-flash-preview-09-2025', messages)` |
 | gemini-2.5-flash-lite-preview-09-2025 | `completion('gemini-2.5-flash-lite-preview-09-2025', messages)`, `completion('vertex_ai/gemini-2.5-flash-lite-preview-09-2025', messages)` |
 
+## Private Service Connect (PSC) Endpoints
+
+LiteLLM supports Vertex AI models deployed to Private Service Connect (PSC) endpoints, allowing you to use a custom `api_base` URL for private deployments.
+
+### Usage
+
+```python
+from litellm import completion
+
+# Use PSC endpoint with custom api_base
+response = completion(
+    model="vertex_ai/1234567890",  # Numeric endpoint ID
+    messages=[{"role": "user", "content": "Hello!"}],
+    api_base="http://10.96.32.8",  # Your PSC endpoint
+    vertex_project="my-project-id",
+    vertex_location="us-central1"
+)
+```
+
+**Key Features:**
+- Supports both numeric endpoint IDs and custom model names
+- Works with both completion and embedding endpoints
+- Automatically constructs the full PSC URL: `{api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}`
+- Compatible with streaming requests
+
+### Configuration
+
+Add PSC endpoints to your `config.yaml`:
+
+```yaml
+model_list:
+  - model_name: psc-gemini
+    litellm_params:
+      model: vertex_ai/1234567890  # Numeric endpoint ID
+      api_base: "http://10.96.32.8"  # Your PSC endpoint
+      vertex_project: "my-project-id"
+      vertex_location: "us-central1"
+      vertex_credentials: "/path/to/service_account.json"
+  - model_name: psc-embedding
+    litellm_params:
+      model: vertex_ai/text-embedding-004
+      api_base: "http://10.96.32.8"  # Your PSC endpoint
+      vertex_project: "my-project-id"
+      vertex_location: "us-central1"
+      vertex_credentials: "/path/to/service_account.json"
+```
+
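+Once the proxy is running, call the PSC-backed deployment like any other model. A minimal sketch, assuming the proxy listens on `http://0.0.0.0:4000` with the `psc-gemini` entry above (the API key value is illustrative):
+
+```python
+import openai
+
+# Point the OpenAI client at the LiteLLM proxy
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+response = client.chat.completions.create(
+    model="psc-gemini",  # model_name from config.yaml
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+print(response.choices[0].message.content)
+```
+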
 ## Fine-tuned Models
 
 You can call fine-tuned Vertex AI Gemini models through LiteLLM
diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py
index 7932881f482c..b40f0a72a50d 100644
--- a/litellm/llms/vertex_ai/batches/handler.py
+++ b/litellm/llms/vertex_ai/batches/handler.py
@@ -61,6 +61,10 @@ def create_batch(
             stream=None,
             auth_header=None,
             url=default_api_base,
+            model=None,
+            vertex_project=vertex_project or project_id,
+            vertex_location=vertex_location or "us-central1",
+            vertex_api_version="v1",
         )
 
         headers = {
@@ -166,6 +170,10 @@ def retrieve_batch(
             stream=None,
             auth_header=None,
             url=default_api_base,
+            model=None,
+            vertex_project=vertex_project or project_id,
+            vertex_location=vertex_location or "us-central1",
+            vertex_api_version="v1",
         )
 
         headers = {
diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index a750b5e985f5..8708e450e06c 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -57,6 +57,9 @@ def get_vertex_ai_model_route(model: str, litellm_params: Optional[dict] = None)
     >>> get_vertex_ai_model_route("openai/gpt-oss-120b")
     VertexAIModelRoute.MODEL_GARDEN
+
+    >>> get_vertex_ai_model_route("1234567890", {"api_base": "http://10.96.32.8"})
+    VertexAIModelRoute.GEMINI  # Numeric endpoints with api_base use the HTTP path
     """
     from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
         VertexAIPartnerModels,
     )
@@ -67,6 +70,11 @@ def get_vertex_ai_model_route(model: str, litellm_params: Optional[dict] = None)
         if "gemini" in litellm_params["base_model"]:
             return VertexAIModelRoute.GEMINI
 
+    # Check if numeric endpoint ID with custom api_base (PSC endpoint)
+    # Route to GEMINI (HTTP path) to support PSC endpoints properly
+    if model.isdigit() and litellm_params and litellm_params.get("api_base"):
+        return VertexAIModelRoute.GEMINI
+
     # Check for partner models (llama, mistral, claude, etc.)
     if VertexAIPartnerModels.is_vertex_partner_model(model=model):
         return VertexAIModelRoute.PARTNER_MODELS
diff --git a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
index 70b068b5a4d9..dabc620a6da3 100644
--- a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
+++ b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
@@ -79,6 +79,10 @@ def _get_token_and_url_context_caching(
         stream=None,
         auth_header=auth_header,
         url=url,
+        model=None,
+        vertex_project=vertex_project,
+        vertex_location=vertex_location,
+        vertex_api_version="v1beta1" if custom_llm_provider == "vertex_ai_beta" else "v1",
     )
 
 def check_cache(
diff --git a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py
index 97af558041d3..caaf00e199e3 100644
--- a/litellm/llms/vertex_ai/vertex_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai/vertex_embeddings/transformation.py
@@ -167,6 +167,9 @@ def _transform_openai_request_to_fine_tuned_embedding_request(
         vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
             **optional_params
         )
+        # Remove 'shared_session' from parameters if present
+        if vertex_request["parameters"] is not None and "shared_session" in vertex_request["parameters"]:
+            del vertex_request["parameters"]["shared_session"]  # type: ignore[typeddict-item]
 
     return vertex_request
diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py
index 8f7c846bd7d7..2a7f081c58a0 100644
--- a/litellm/llms/vertex_ai/vertex_llm_base.py
+++ b/litellm/llms/vertex_ai/vertex_llm_base.py
@@ -241,6 +241,9 @@ def get_complete_vertex_url(
             auth_header=None,
             url=default_api_base,
             model=model,
+            vertex_project=vertex_project or project_id,
+            vertex_location=vertex_location or "us-central1",
+            vertex_api_version="v1",  # Partner models typically use v1
         )
 
         return api_base
@@ -289,9 +292,18 @@ def _check_custom_proxy(
         auth_header: Optional[str],
         url: str,
         model: Optional[str] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_api_version: Optional[Literal["v1", "v1beta1"]] = None,
     ) -> Tuple[Optional[str], str]:
         """
         for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+
+        Handles custom api_base for:
+        1. Gemini (Google AI Studio) - constructs /models/{model}:{endpoint}
+        2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
+        3. Vertex AI with PSC endpoints - constructs the full path structure:
+           {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
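+
+        Example (PSC case; values are illustrative):
+            api_base="http://10.96.32.8", vertex_project="my-project",
+            vertex_location="us-central1", model="1234567890", endpoint="predict"
+            -> "http://10.96.32.8/v1/projects/my-project/locations/us-central1/endpoints/1234567890:predict"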
 
         ## Returns
         - (auth_header, url) - Tuple[Optional[str], str]
@@ -312,8 +324,34 @@ def _check_custom_proxy(
                     gemini_api_key  # cloudflare expects api key as bearer token
                 )
         else:
-            url = "{}:{}".format(api_base, endpoint)
-
+            # For Vertex AI
+            # Check if this is a PSC endpoint or custom deployment
+            # PSC/custom endpoints need the full path structure
+            if vertex_project and vertex_location and model:
+                # Check if model is numeric (endpoint ID) or if api_base doesn't contain googleapis.com
+                # These are indicators of PSC/custom endpoints
+                is_psc_or_custom = (
+                    "googleapis.com" not in api_base.lower() or model.isdigit()
+                )
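+                # NOTE: any api_base that does not contain googleapis.com takes the
+                # full-path branch below, even when the model name is not numeric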
+""" + +import pytest +import sys +import os + +# Add the litellm package to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../..")) + +from litellm.llms.vertex_ai.vertex_llm_base import VertexBase + + +class TestVertexAIPSCEndpointSupport: + """Test cases for PSC endpoint URL construction""" + + def test_psc_endpoint_url_construction_basic(self): + """Test basic PSC endpoint URL construction for predict endpoint""" + vertex_base = VertexBase() + psc_api_base = "http://10.96.32.8" + endpoint_id = "1234567890" + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header="test-token", + url="", # This will be replaced + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_psc_endpoint_url_construction_with_streaming(self): + """Test PSC endpoint URL construction with streaming enabled""" + vertex_base = VertexBase() + psc_api_base = "http://10.96.32.8" + endpoint_id = "1234567890" + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="streamGenerateContent", + stream=True, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:streamGenerateContent?alt=sse" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_psc_endpoint_url_construction_v1beta1(self): + """Test PSC endpoint URL construction with v1beta1 API version""" + vertex_base = VertexBase() + psc_api_base = "http://10.96.32.8" + endpoint_id = "1234567890" + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1beta1", + ) + + expected_url = f"{psc_api_base}/v1beta1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_psc_endpoint_url_with_https(self): + """Test PSC endpoint URL construction with HTTPS""" + vertex_base = VertexBase() + psc_api_base = "https://10.96.32.8" + endpoint_id = "1234567890" + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got 
{url}" + + def test_psc_endpoint_with_trailing_slash(self): + """Test that trailing slashes in api_base are handled correctly""" + vertex_base = VertexBase() + psc_api_base = "http://10.96.32.8/" + endpoint_id = "1234567890" + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + # rstrip('/') should remove the trailing slash + expected_url = f"{psc_api_base.rstrip('/')}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_standard_proxy_with_googleapis(self): + """Test that standard proxies with googleapis.com in URL use simple format""" + vertex_base = VertexBase() + proxy_api_base = "https://my-proxy.googleapis.com" + endpoint_id = "gemini-pro" # Not numeric + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=proxy_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="generateContent", + stream=False, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + # Should use simple format: api_base:endpoint + expected_url = f"{proxy_api_base}:generateContent" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_custom_proxy_with_numeric_model(self): + """Test that numeric model IDs trigger PSC-style URL construction""" + vertex_base = VertexBase() + proxy_api_base = "https://my-custom-proxy.example.com" + endpoint_id = "9876543210" # Numeric endpoint ID + project_id = "test-project" + location = "us-central1" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=proxy_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header="test-token", + url="", + model=endpoint_id, + vertex_project=project_id, + vertex_location=location, + vertex_api_version="v1", + ) + + # Numeric model should trigger full path construction + expected_url = f"{proxy_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict" + assert ( + url == expected_url + ), f"Expected {expected_url}, but got {url}" + + def test_no_api_base_returns_original_url(self): + """Test that when api_base is None, the original URL is returned""" + vertex_base = VertexBase() + original_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test/locations/us-central1/publishers/google/models/gemini-pro:generateContent" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=None, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="generateContent", + stream=False, + auth_header="test-token", + url=original_url, + model="gemini-pro", + vertex_project="test-project", + vertex_location="us-central1", + vertex_api_version="v1", + ) + + # When api_base is None, original URL should be returned unchanged + assert url == original_url, f"Expected {original_url}, but got {url}" + + def test_auth_header_preserved(self): + """Test that auth_header is properly preserved""" + vertex_base = VertexBase() + psc_api_base = "http://10.96.32.8" + 
test_auth_header = "Bearer test-token-12345" + + auth_header, url = vertex_base._check_custom_proxy( + api_base=psc_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="predict", + stream=False, + auth_header=test_auth_header, + url="", + model="1234567890", + vertex_project="test-project", + vertex_location="us-central1", + vertex_api_version="v1", + ) + + assert ( + auth_header == test_auth_header + ), f"Auth header should be preserved, got {auth_header}" +