Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#

"""Tests for apache_beam.ml.base."""
import contextlib
import importlib
import math
import multiprocessing
import os
Expand All @@ -25,6 +27,7 @@
import tempfile
import time
import unittest
import unittest.mock
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
Expand Down Expand Up @@ -2319,6 +2322,264 @@ def test_batching_kwargs_none_values_omitted(self):
self.assertEqual(kwargs['min_batch_size'], 5)


class PaddingReportingStringModelHandler(base.ModelHandler[str, str,
                                                           FakeModel]):
  """Reports each element with the max length of the batch it ran in.

  Used to verify length-aware batching: elements batched together all get
  tagged with the same max length, so tests can assert batch membership.
  """
  def load_model(self):
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    """Return each input string tagged as '<s>:<max batch length>'."""
    # Tag with the longest string in this batch so the output reveals
    # which elements were grouped into the same batch.
    max_len = max(len(s) for s in batch)
    return [f'{s}:{max_len}' for s in batch]


class RunInferenceLengthAwareBatchingTest(unittest.TestCase):
  """End-to-end tests for PR2 length-aware batching in RunInference."""
  def test_run_inference_with_length_aware_batch_elements(self):
    # Bucket boundary at 5 splits inputs into short ('a', 'bb') and long
    # ('cccccc', 'ddddddd') buckets; each bucket forms its own batch of 2.
    model_handler = PaddingReportingStringModelHandler(
        min_batch_size=2,
        max_batch_size=2,
        max_batch_duration_secs=60,
        batch_length_fn=len,
        batch_bucket_boundaries=[5])

    inputs = ['a', 'cccccc', 'bb', 'ddddddd']
    with TestPipeline('FnApiRunner') as pipeline:
      predictions = (
          pipeline
          | beam.Create(inputs, reshuffle=False)
          | base.RunInference(model_handler))
      assert_that(
          predictions,
          equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7']))


class HandlerBucketingKwargsForwardingTest(unittest.TestCase):
  """Verify each concrete ModelHandler forwards batch_length_fn and
  batch_bucket_boundaries through to batch_elements_kwargs()."""

  # Bucketing kwargs passed to every handler's __init__. Each handler is
  # expected to surface them from batch_elements_kwargs() under the
  # framework-level names 'length_fn' / 'bucket_boundaries'.
  _BUCKETING_KWARGS = {
      'batch_length_fn': len,
      'batch_bucket_boundaries': [32],
  }

  def _assert_bucketing_kwargs_forwarded(self, handler):
    # The __init__ names are translated by each handler:
    # batch_length_fn -> length_fn, batch_bucket_boundaries -> bucket_boundaries.
    kwargs = handler.batch_elements_kwargs()
    self.assertIs(kwargs['length_fn'], len)
    self.assertEqual(kwargs['bucket_boundaries'], [32])

  def _load_handler_class(self, case):
    """Import the handler class for a case, skipping the (sub)test when the
    optional backing dependency is not installed."""
    try:
      module = importlib.import_module(case['module_name'])
    except ImportError:
      raise unittest.SkipTest(case['skip_message'])
    return getattr(module, case['class_name'])

  @contextlib.contextmanager
  def _handler_setup(self, case):
    """Per-case construction-time mocking.

    Currently only Vertex AI needs it: its constructor calls into the
    aiplatform SDK, so that module is patched out.
    """
    if not case.get('mock_aiplatform'):
      yield
      return

    with unittest.mock.patch(
        'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip:
      mock_aip.init.return_value = None
      mock_endpoint = unittest.mock.MagicMock()
      # list_models must be non-empty so the handler treats the endpoint
      # as having a deployed model.
      mock_endpoint.list_models.return_value = ['fake-model']
      mock_aip.Endpoint.return_value = mock_endpoint
      yield

  def _assert_handler_cases(self, cases):
    """Data-driven driver: construct each handler with the bucketing kwargs
    merged into its case-specific init kwargs, then assert forwarding."""
    for case in cases:
      with self.subTest(handler=case['name']):
        handler_cls = self._load_handler_class(case)
        init_kwargs = dict(case['init_kwargs'])
        init_kwargs.update(self._BUCKETING_KWARGS)

        with self._handler_setup(case):
          handler = handler_cls(**init_kwargs)

        self._assert_bucketing_kwargs_forwarded(handler)

  def test_pytorch_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'pytorch_tensor',
            'module_name': 'apache_beam.ml.inference.pytorch_inference',
            'class_name': 'PytorchModelHandlerTensor',
            'skip_message': 'PyTorch not available',
            'init_kwargs': {},
        },
        {
            'name': 'pytorch_keyed_tensor',
            'module_name': 'apache_beam.ml.inference.pytorch_inference',
            'class_name': 'PytorchModelHandlerKeyedTensor',
            'skip_message': 'PyTorch not available',
            'init_kwargs': {},
        },
    ))

  def test_huggingface_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'huggingface_keyed_tensor',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFaceModelHandlerKeyedTensor',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'model_uri': 'unused',
                'model_class': object,
                'framework': 'pt',
            },
        },
        {
            'name': 'huggingface_tensor',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFaceModelHandlerTensor',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'model_uri': 'unused',
                'model_class': object,
            },
        },
        {
            'name': 'huggingface_pipeline',
            'module_name': 'apache_beam.ml.inference.huggingface_inference',
            'class_name': 'HuggingFacePipelineModelHandler',
            'skip_message': 'HuggingFace transformers not available',
            'init_kwargs': {
                'task': 'text-classification',
            },
        },
    ))

  def test_sklearn_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'sklearn_numpy',
            'module_name': 'apache_beam.ml.inference.sklearn_inference',
            'class_name': 'SklearnModelHandlerNumpy',
            'skip_message': 'scikit-learn not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
        {
            'name': 'sklearn_pandas',
            'module_name': 'apache_beam.ml.inference.sklearn_inference',
            'class_name': 'SklearnModelHandlerPandas',
            'skip_message': 'scikit-learn not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
    ))

  def test_tensorflow_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'tensorflow_numpy',
            'module_name': 'apache_beam.ml.inference.tensorflow_inference',
            'class_name': 'TFModelHandlerNumpy',
            'skip_message': 'TensorFlow not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
        {
            'name': 'tensorflow_tensor',
            'module_name': 'apache_beam.ml.inference.tensorflow_inference',
            'class_name': 'TFModelHandlerTensor',
            'skip_message': 'TensorFlow not available',
            'init_kwargs': {
                'model_uri': 'unused',
            },
        },
    ))

  def test_onnx_handler(self):
    self._assert_handler_cases(({
        'name': 'onnx_numpy',
        'module_name': 'apache_beam.ml.inference.onnx_inference',
        'class_name': 'OnnxModelHandlerNumpy',
        'skip_message': 'ONNX Runtime not available',
        'init_kwargs': {
            'model_uri': 'unused',
        },
    }, ))

  def test_xgboost_handler(self):
    self._assert_handler_cases(({
        'name': 'xgboost_numpy',
        'module_name': 'apache_beam.ml.inference.xgboost_inference',
        'class_name': 'XGBoostModelHandlerNumpy',
        'skip_message': 'XGBoost not available',
        'init_kwargs': {
            'model_class': object,
            'model_state': 'unused',
        },
    }, ))

  def test_tensorrt_handler(self):
    self._assert_handler_cases(({
        'name': 'tensorrt_numpy',
        'module_name': 'apache_beam.ml.inference.tensorrt_inference',
        'class_name': 'TensorRTEngineHandlerNumPy',
        'skip_message': 'TensorRT not available',
        'init_kwargs': {
            'min_batch_size': 1,
            'max_batch_size': 8,
        },
    }, ))

  def test_vllm_handlers(self):
    self._assert_handler_cases((
        {
            'name': 'vllm_completions',
            'module_name': 'apache_beam.ml.inference.vllm_inference',
            'class_name': 'VLLMCompletionsModelHandler',
            'skip_message': 'vLLM not available',
            'init_kwargs': {
                'model_name': 'unused',
            },
        },
        {
            'name': 'vllm_chat',
            'module_name': 'apache_beam.ml.inference.vllm_inference',
            'class_name': 'VLLMChatModelHandler',
            'skip_message': 'vLLM not available',
            'init_kwargs': {
                'model_name': 'unused',
            },
        },
    ))

  def test_vertex_ai_handler(self):
    self._assert_handler_cases(({
        'name': 'vertex_ai',
        'module_name': 'apache_beam.ml.inference.vertex_ai_inference',
        'class_name': 'VertexAIModelHandlerJSON',
        'skip_message': 'Vertex AI SDK not available',
        'init_kwargs': {
            'endpoint_id': 'unused',
            'project': 'unused',
            'location': 'unused',
        },
        'mock_aiplatform': True,
    }, ))

  def test_gemini_handler(self):
    self._assert_handler_cases(({
        'name': 'gemini',
        'module_name': 'apache_beam.ml.inference.gemini_inference',
        'class_name': 'GeminiModelHandler',
        'skip_message': 'Google GenAI SDK not available',
        'init_kwargs': {
            'model_name': 'unused',
            'request_fn': lambda *args: None,
            'project': 'unused',
            'location': 'unused',
        },
    }, ))


class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def load_model(self):
return FakeModel()
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Google Gemini.
**NOTE:** This API and its implementation are under development and
Expand Down Expand Up @@ -158,6 +160,10 @@ def __init__(
max_batch_weight: optional. the maximum total weight of a batch.
element_size_fn: optional. a function that returns the size (weight)
of an element.
batch_length_fn: optional. a callable that returns the length of an
element for length-aware batching.
batch_bucket_boundaries: optional. a sorted list of positive boundary
values for length-aware batching buckets.
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
Expand All @@ -171,6 +177,10 @@ def __init__(
self._batching_kwargs["max_batch_weight"] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
if batch_length_fn is not None:
self._batching_kwargs['length_fn'] = batch_length_fn
if batch_bucket_boundaries is not None:
self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries

self.model_name = model_name
self.request_fn = request_fn
Expand Down
24 changes: 24 additions & 0 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
Expand Down Expand Up @@ -266,6 +268,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -278,6 +284,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -411,6 +419,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for HuggingFace with
Expand Down Expand Up @@ -448,6 +458,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -460,6 +474,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down Expand Up @@ -576,6 +592,8 @@ def __init__(
model_copies: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
batch_length_fn: Optional[Callable[[Any], int]] = None,
batch_bucket_boundaries: Optional[list[int]] = None,
**kwargs):
"""
Implementation of the ModelHandler interface for Hugging Face Pipelines.
Expand Down Expand Up @@ -621,6 +639,10 @@ def __init__(
GPU capacity and want to maximize resource utilization.
max_batch_weight: the maximum total weight of a batch.
element_size_fn: a function that returns the size (weight) of an element.
batch_length_fn: a callable that returns the length of an element for
length-aware batching.
batch_bucket_boundaries: a sorted list of positive boundary values for
length-aware batching buckets.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.

Expand All @@ -633,6 +655,8 @@ def __init__(
max_batch_duration_secs=max_batch_duration_secs,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn,
batch_length_fn=batch_length_fn,
batch_bucket_boundaries=batch_bucket_boundaries,
large_model=large_model,
model_copies=model_copies,
**kwargs)
Expand Down
Loading
Loading