5 changes: 2 additions & 3 deletions .github/trigger_files/beam_PostCommit_Python.json
@@ -1,6 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "pr": "36271",
-  "modification": 36
+  "modification": 37
 }
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM python:3.10-slim
WORKDIR /app
RUN pip install flask gunicorn
COPY echo_server.py main.py
CMD ["gunicorn", "--bind", "0.0.0.0:8080", "main:app"]
103 changes: 103 additions & 0 deletions sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
@@ -0,0 +1,103 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->

# Vertex AI Custom Prediction Route Test Setup

To run the `test_vertex_ai_custom_prediction_route` in [vertex_ai_inference_it_test.py](../../vertex_ai_inference_it_test.py), you need a dedicated Vertex AI endpoint with an invoke-enabled model deployed.

## Resource Setup Steps

Run these commands in the `apache-beam-testing` project (or your own test project).

### 1. Build and Push Container

From this directory:

```bash
# on Linux
export PROJECT_ID="apache-beam-testing" # Or your project
export IMAGE_URI="gcr.io/${PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"

docker build -t ${IMAGE_URI} .
docker push ${IMAGE_URI}
```
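
Optionally, you can smoke-test the image locally before pushing. This is a minimal sketch (the container name `beam-echo-test` is arbitrary); it assumes Docker is running locally and port 8080 is free, and it exercises the same `/predict` and `/health` routes served by `echo_server.py`.

```bash
# Optional local smoke test before pushing (assumes port 8080 is free)
docker run --rm -d -p 8080:8080 --name beam-echo-test ${IMAGE_URI}

# The echo server should return the instances back under "predictions"
curl -s -X POST http://localhost:8080/predict \
  -H "Content-Type: application/json" \
  -d '{"instances": [{"value": 1}]}'
curl -s http://localhost:8080/health

docker stop beam-echo-test
```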

### 2. Upload Model and Deploy Endpoint

Use the Python SDK to deploy; it exposes the invoke-specific container flags more directly than gcloud.

```python
from google.cloud import aiplatform

PROJECT_ID = "apache-beam-testing"
REGION = "us-central1"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"

aiplatform.init(project=PROJECT_ID, location=REGION)

# 1. Upload Model with invoke route enabled
model = aiplatform.Model.upload(
    display_name="beam-invoke-echo-model",
    serving_container_image_uri=IMAGE_URI,
    serving_container_invoke_route_prefix="/*",  # <--- Critical for custom routes
    serving_container_health_route="/health",
    sync=True,
)

# 2. Create Dedicated Endpoint (required for invoke)
endpoint = aiplatform.Endpoint.create(
    display_name="beam-invoke-test-endpoint",
    dedicated_endpoint_enabled=True,
    sync=True,
)

# 3. Deploy Model
# NOTE: Set min_replica_count=0 to save costs when not testing
endpoint.deploy(
    model=model,
    traffic_percentage=100,
    machine_type="n1-standard-2",
    min_replica_count=0,
    max_replica_count=1,
    sync=True,
)

print("Deployment Complete!")
print(f"Endpoint ID: {endpoint.name}")
```
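
Before updating the test, you can optionally sanity-check the custom route. The sketch below reuses the `endpoint` object created above and the same `Endpoint.invoke()` call that the Beam model handler makes; it assumes the deployment succeeded and that you are authenticated against the project.

```python
import json

# Optional sanity check of the custom route (assumes the deployment above
# succeeded; `endpoint` is the object returned by aiplatform.Endpoint.create).
response = endpoint.invoke(
    request_path="/predict",
    body=json.dumps({"instances": [{"value": 1}]}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# The echo container should return the instances back under "predictions".
print(json.loads(bytes(response).decode("utf-8")))
```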

### 3. Update Test Configuration

1. Copy the **Endpoint ID** printed above (e.g., `1234567890`).
2. Update `_INVOKE_ENDPOINT_ID` in `apache_beam/ml/inference/vertex_ai_inference_it_test.py`, as sketched below.
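
After the update, the relevant constants in the test module should look roughly like this (the endpoint ID below is a placeholder, not a real deployment):

```python
# apache_beam/ml/inference/vertex_ai_inference_it_test.py (illustrative values)
_INVOKE_ENDPOINT_ID = "1234567890"  # placeholder; use the ID printed above
_INVOKE_ROUTE = "/predict"  # must match the route served by echo_server.py
```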

## Cleanup

To avoid costs, undeploy and delete resources when finished:

```bash
# Undeploy model from endpoint
gcloud ai endpoints undeploy-model <ENDPOINT_ID> --deployed-model-id <DEPLOYED_MODEL_ID> --region=us-central1

# Delete endpoint
gcloud ai endpoints delete <ENDPOINT_ID> --region=us-central1

# Delete model
gcloud ai models delete <MODEL_ID> --region=us-central1
```
41 changes: 41 additions & 0 deletions sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from flask import Flask, request, jsonify

app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
  data = request.get_json()
  # Echo back the instances
  return jsonify({
      "predictions": [{
          "echo": inst
      } for inst in data.get('instances', [])],
      "deployedModelId": "echo-model"
  })


@app.route('/health', methods=['GET'])
def health():
  return 'OK', 200


if __name__ == '__main__':
  app.run(host='0.0.0.0', port=8080)
70 changes: 67 additions & 3 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import json
import logging
from collections.abc import Iterable
from collections.abc import Mapping
@@ -63,6 +64,7 @@ def __init__(
      experiment: Optional[str] = None,
      network: Optional[str] = None,
      private: bool = False,
      invoke_route: Optional[str] = None,
      *,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@
      private: optional. if the deployed Vertex AI endpoint is
        private, set to true. Requires a network to be provided
        as well.
      invoke_route: optional. the custom route path to use when invoking
        endpoints with arbitrary prediction routes. When specified, uses
        `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
        should start with a forward slash, e.g., "/predict/v1".
        See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
        for more information.
      min_batch_size: optional. the minimum batch size to use when batching
        inputs.
      max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
self._invoke_route = invoke_route
if min_batch_size is not None:
self._batching_kwargs["min_batch_size"] = min_batch_size
if max_batch_size is not None:
@@ -203,9 +212,64 @@ def request(
    Returns:
      An iterable of Predictions.
    """
    prediction = model.predict(instances=list(batch), parameters=inference_args)
    return utils._convert_to_result(
        batch, prediction.predictions, prediction.deployed_model_id)
    if self._invoke_route:
      # Use invoke() for endpoints with custom prediction routes
      request_body: dict[str, Any] = {"instances": list(batch)}
      if inference_args:
        request_body["parameters"] = inference_args
      response = model.invoke(
          request_path=self._invoke_route,
          body=json.dumps(request_body).encode("utf-8"),
          headers={"Content-Type": "application/json"})
      return self._parse_invoke_response(batch, bytes(response))
    else:
      prediction = model.predict(
          instances=list(batch), parameters=inference_args)
      return utils._convert_to_result(
          batch, prediction.predictions, prediction.deployed_model_id)

  def _parse_invoke_response(self, batch: Sequence[Any],
                             response: bytes) -> Iterable[PredictionResult]:
    """Parses the response from Endpoint.invoke() into PredictionResults.

    Args:
      batch: the original batch of inputs.
      response: the raw bytes response from invoke().

    Returns:
      An iterable of PredictionResults.
    """
    try:
      response_json = json.loads(response.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
      LOGGER.warning(
          "Failed to decode invoke response as JSON, returning raw bytes: %s",
          e)
      # Return raw response for each batch item
      return [
          PredictionResult(example=example, inference=response)
          for example in batch
      ]

    # Handle standard Vertex AI response format with "predictions" key
    if isinstance(response_json, dict) and "predictions" in response_json:
      predictions = response_json["predictions"]
      model_id = response_json.get("deployedModelId")
      return utils._convert_to_result(batch, predictions, model_id)

    # Handle response as a list of predictions (one per input)
    if isinstance(response_json, list) and len(response_json) == len(batch):
      return utils._convert_to_result(batch, response_json, None)

    # Handle single prediction response
    if len(batch) == 1:
      return [PredictionResult(example=batch[0], inference=response_json)]

    # Fallback: return the full response for each batch item
    return [
        PredictionResult(example=example, inference=response_json)
        for example in batch
    ]

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    return self._batching_kwargs
sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
@@ -23,12 +23,15 @@

import pytest

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import RunInference
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=ungrouped-imports
try:
from apache_beam.examples.inference import vertex_ai_image_classification
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
except ImportError as e:
  raise unittest.SkipTest(
      "Vertex AI model handler dependencies are not installed")
@@ -42,6 +45,13 @@
# pylint: disable=line-too-long
_SUBNETWORK = "https://www.googleapis.com/compute/v1/projects/apache-beam-testing/regions/us-central1/subnetworks/beam-test-vpc"

# Constants for custom prediction routes (invoke) test
# Follow beam/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
# to get endpoint ID after deploying invoke-enabled model
_INVOKE_ENDPOINT_ID = "6890840581900075008"
_INVOKE_ROUTE = "/predict"
_INVOKE_OUTPUT_DIR = "gs://apache-beam-ml/testing/outputs/vertex_invoke"


class VertexAIInference(unittest.TestCase):
  @pytest.mark.vertex_ai_postcommit
@@ -63,6 +73,43 @@ def test_vertex_ai_run_flower_image_classification(self):
        test_pipeline.get_full_options_as_args(**extra_opts))
    self.assertEqual(FileSystems().exists(output_file), True)

  @pytest.mark.vertex_ai_postcommit
  @unittest.skipIf(
      not _INVOKE_ENDPOINT_ID,
      "Invoke endpoint not configured. Set _INVOKE_ENDPOINT_ID.")
  def test_vertex_ai_custom_prediction_route(self):
    """Test custom prediction routes using invoke_route parameter.

    This test verifies that VertexAIModelHandlerJSON correctly uses
    Endpoint.invoke() instead of Endpoint.predict() when invoke_route
    is specified, enabling custom prediction routes.
    """
    output_file = '/'.join(
        [_INVOKE_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt'])

    test_pipeline = TestPipeline(is_integration_test=True)

    model_handler = VertexAIModelHandlerJSON(
        endpoint_id=_INVOKE_ENDPOINT_ID,
        project=_ENDPOINT_PROJECT,
        location=_ENDPOINT_REGION,
        invoke_route=_INVOKE_ROUTE)

    # Test inputs - simple data to echo back
    test_inputs = [{"value": 1}, {"value": 2}, {"value": 3}]

    with test_pipeline as p:
      results = (
          p
          | "CreateInputs" >> beam.Create(test_inputs)
          | "RunInference" >> RunInference(model_handler)
          | "ExtractResults" >>
          beam.Map(lambda result: f"{result.example}:{result.inference}"))
      _ = results | "WriteOutput" >> beam.io.WriteToText(
          output_file, shard_name_template='')

    self.assertTrue(FileSystems().exists(output_file))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.DEBUG)