diff --git a/docs/content.zh/docs/connectors/models/triton.md b/docs/content.zh/docs/connectors/models/triton.md new file mode 100644 index 0000000000000..7d481cfd087bd --- /dev/null +++ b/docs/content.zh/docs/connectors/models/triton.md @@ -0,0 +1,482 @@ +--- +title: "Triton" +weight: 2 +type: docs +--- + + +# Triton + +The Triton Model Function allows Flink SQL to call [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server) for real-time model inference tasks. + +## Overview + +The function supports calling remote Triton Inference Server via Flink SQL for prediction/inference tasks. Triton Inference Server is a high-performance inference serving solution that supports multiple machine learning frameworks including TensorFlow, PyTorch, ONNX, and more. + +Key features: +* **High Performance**: Optimized for low-latency and high-throughput inference +* **Multi-Framework Support**: Works with models from various ML frameworks +* **Asynchronous Processing**: Non-blocking inference requests for better resource utilization +* **Flexible Configuration**: Comprehensive configuration options for different use cases +* **Resource Management**: Efficient HTTP client pooling and automatic resource cleanup + +## Usage Examples + +The following example creates a Triton model for text classification and uses it to analyze sentiment in movie reviews. + +First, create the Triton model with the following SQL statement: + +```sql +CREATE MODEL triton_sentiment_classifier +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'text-classification', + 'model-version' = '1', + 'timeout' = '10000', + 'max-retries' = '3' +); +``` + +Suppose the following data is stored in a table named `movie_reviews`, and the prediction result is to be stored in a table named `classified_reviews`: + +```sql +CREATE TEMPORARY VIEW movie_reviews(id, movie_name, user_review, actual_sentiment) +AS VALUES + (1, 'Great Movie', 'This movie was absolutely fantastic! Great acting and storyline.', 'positive'), + (2, 'Boring Film', 'I fell asleep halfway through. Very disappointing.', 'negative'), + (3, 'Average Show', 'It was okay, nothing special but not terrible either.', 'neutral'); + +CREATE TEMPORARY TABLE classified_reviews( + id BIGINT, + movie_name VARCHAR, + predicted_sentiment VARCHAR, + actual_sentiment VARCHAR +) WITH ( + 'connector' = 'print' +); +``` + +Then the following SQL statement can be used to classify sentiment for movie reviews: + +```sql +INSERT INTO classified_reviews +SELECT id, movie_name, output as predicted_sentiment, actual_sentiment +FROM ML_PREDICT( + TABLE movie_reviews, + MODEL triton_sentiment_classifier, + DESCRIPTOR(user_review) +); +``` + +### Advanced Configuration Example + +For production environments with authentication and custom headers: + +```sql +CREATE MODEL triton_advanced_model +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'https://triton.example.com/v2/models', + 'model-name' = 'advanced-nlp-model', + 'model-version' = 'latest', + 'timeout' = '15000', + 'max-retries' = '5', + 'batch-size' = '4', + 'priority' = '100', + 'auth-token' = 'Bearer your-auth-token-here', + 'custom-headers' = '{"X-Custom-Header": "custom-value", "X-Request-ID": "req-123"}', + 'compression' = 'gzip' +); +``` + +### Array Type Inference Example + +For models that accept array inputs (e.g., vector embeddings, image features): + +```sql +-- Create model with array input +CREATE MODEL triton_vector_model +INPUT (input_vector ARRAY) +OUTPUT (output_vector ARRAY) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'vector-transform', + 'model-version' = '1', + 'flatten-batch-dim' = 'true' -- If model doesn't expect batch dimension +); + +-- Use the model for inference +CREATE TEMPORARY TABLE vector_input ( + id BIGINT, + features ARRAY +) WITH ( + 'connector' = 'datagen', + 'fields.features.length' = '128' -- 128-dimensional vector +); + +SELECT id, output_vector +FROM ML_PREDICT( + TABLE vector_input, + MODEL triton_vector_model, + DESCRIPTOR(features) +); +``` + +### Stateful Model Example + +For stateful models that require sequence processing: + +```sql +CREATE MODEL triton_sequence_model +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'sequence-model', + 'model-version' = '1', + 'sequence-id' = 'seq-001', + 'sequence-start' = 'true', + 'sequence-end' = 'false' +); +``` + +## Model Options + +### Required Options + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
OptionRequiredDefaultTypeDescription
+
provider
+
required(none)StringSpecifies the model function provider to use, must be 'triton'.
+
endpoint
+
required(none)StringFull URL of the Triton Inference Server endpoint, e.g. http://localhost:8000/v2/models.
+
model-name
+
required(none)StringName of the model to invoke on Triton server.
+
model-version
+
requiredlatestStringVersion of the model to use. Defaults to 'latest'.
+ +### Optional Options + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
OptionRequiredDefaultTypeDescription
+
timeout
+
optional30000LongRequest timeout in milliseconds.
+
max-retries
+
optional3IntegerMaximum number of retries for failed requests.
+
batch-size
+
optional1IntegerBatch size for inference requests.
+
flatten-batch-dim
+
optionalfalseBooleanWhether to flatten the batch dimension for array inputs. When true, shape [1,N] becomes [N]. Defaults to false. Useful for Triton models that do not expect a batch dimension.
+
priority
+
optional(none)IntegerRequest priority level (0-255). Higher values indicate higher priority.
+
sequence-id
+
optional(none)StringSequence ID for stateful models.
+
sequence-start
+
optionalfalseBooleanWhether this is the start of a sequence for stateful models.
+
sequence-end
+
optionalfalseBooleanWhether this is the end of a sequence for stateful models.
+
binary-data
+
optionalfalseBooleanWhether to use binary data transfer. Defaults to false (JSON).
+
compression
+
optional(none)StringCompression algorithm to use (e.g., 'gzip').
+
auth-token
+
optional(none)StringAuthentication token for secured Triton servers.
+
custom-headers
+
optional(none)StringCustom HTTP headers in JSON format, e.g., {"X-Custom-Header":"value"}.
+ +## Schema Requirement + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Input TypeOutput TypeDescription
BOOLEAN, TINYINT, SMALLINT, INT, BIGINTBOOLEAN, TINYINT, SMALLINT, INT, BIGINTInteger type inference
FLOAT, DOUBLEFLOAT, DOUBLEFloating-point type inference
STRINGSTRINGText-to-text inference (classification, generation, etc.)
ARRAY<numeric types>ARRAY<numeric types>Array inference (vectors, tensors, etc.). Supports arrays of numeric types.
+ +**Note**: Input and output types must match the types defined in your Triton model configuration. + +## Triton Server Setup + +To use this integration, you need a running Triton Inference Server. Here's a basic setup guide: + +### Using Docker + +```bash +# Pull Triton server image +docker pull nvcr.io/nvidia/tritonserver:23.10-py3 + +# Run Triton server with your model repository +docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \ + -v /path/to/your/model/repository:/models \ + nvcr.io/nvidia/tritonserver:23.10-py3 \ + tritonserver --model-repository=/models +``` + +### Model Repository Structure + +Your model repository should follow this structure: + +``` +model_repository/ +├── text-classification/ +│ ├── config.pbtxt +│ └── 1/ +│ └── model.py # or model.onnx, model.plan, etc. +└── other-model/ + ├── config.pbtxt + └── 1/ + └── model.savedmodel/ +``` + +### Example Model Configuration + +Here's an example `config.pbtxt` for a text classification model: + +```protobuf +name: "text-classification" +platform: "python" +max_batch_size: 8 +input [ + { + name: "INPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +output [ + { + name: "OUTPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +``` + +## Performance Considerations + +1. **Connection Pooling**: HTTP clients are pooled and reused for efficiency +2. **Asynchronous Processing**: Non-blocking requests prevent thread starvation +3. **Batch Processing**: Configure batch size for optimal throughput +4. **Resource Management**: Automatic cleanup of HTTP resources +5. **Timeout Configuration**: Set appropriate timeout values based on model complexity +6. **Retry Strategy**: Configure retry attempts for handling transient failures + +## Error Handling + +The integration includes comprehensive error handling: + +- **Connection Errors**: Automatic retry with exponential backoff +- **Timeout Handling**: Configurable request timeouts +- **HTTP Errors**: Detailed error messages from Triton server +- **Serialization Errors**: JSON parsing and validation errors + +## Monitoring and Debugging + +Enable debug logging to monitor the integration: + +```properties +# In log4j2.properties +logger.triton.name = org.apache.flink.model.triton +logger.triton.level = DEBUG +``` + +This will provide detailed logs about: +- HTTP request/response details +- Client connection management +- Error conditions and retries +- Performance metrics + +## Dependencies + +To use the Triton model function, you need to include the following dependency in your Flink application: + +```xml + + org.apache.flink + flink-model-triton + ${flink.version} + +``` + +{{< top >}} diff --git a/docs/content.zh/docs/dev/table/sql/create.md b/docs/content.zh/docs/dev/table/sql/create.md index 9c18baf478cf9..40bc3ff55997e 100644 --- a/docs/content.zh/docs/dev/table/sql/create.md +++ b/docs/content.zh/docs/dev/table/sql/create.md @@ -968,4 +968,19 @@ WITH ( ); ``` +```sql +CREATE MODEL triton_text_classifier +INPUT (input STRING COMMENT '用于分类的输入文本') +OUTPUT (output STRING COMMENT '分类结果') +COMMENT '基于 Triton 的文本分类模型' +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'text-classification', + 'model-version' = '1', + 'timeout' = '10000', + 'max-retries' = '3' +); +``` + {{< top >}} diff --git a/docs/content.zh/docs/dev/table/sql/queries/model-inference.md b/docs/content.zh/docs/dev/table/sql/queries/model-inference.md index ef65593b1a9b8..e71d7e509b7ad 100644 --- a/docs/content.zh/docs/dev/table/sql/queries/model-inference.md +++ b/docs/content.zh/docs/dev/table/sql/queries/model-inference.md @@ -131,4 +131,11 @@ SELECT * FROM ML_PREDICT( - [模型创建]({{< ref "docs/dev/table/sql/create#create-model" >}}) - [模型修改]({{< ref "docs/dev/table/sql/alter#alter-model" >}}) +### 支持的模型提供者 + +Flink 目前支持以下模型提供者: + +- **OpenAI**:用于调用 OpenAI API 服务。详情请参见 [OpenAI 模型文档]({{< ref "docs/connectors/models/openai" >}})。 +- **Triton**:用于调用 NVIDIA Triton 推理服务器。详情请参见 [Triton 模型文档]({{< ref "docs/connectors/models/triton" >}})。 + {{< top >}} diff --git a/docs/content/docs/connectors/models/triton.md b/docs/content/docs/connectors/models/triton.md new file mode 100644 index 0000000000000..b7a4c20ac0962 --- /dev/null +++ b/docs/content/docs/connectors/models/triton.md @@ -0,0 +1,482 @@ +--- +title: "Triton" +weight: 2 +type: docs +--- + + +# Triton + +Triton 模型函数允许 Flink SQL 调用 [NVIDIA Triton 推理服务器](https://github.com/triton-inference-server/server) 进行实时模型推理任务。 + +## 概述 + +该函数支持通过 Flink SQL 调用远程 Triton 推理服务器进行预测/推理任务。Triton 推理服务器是一个高性能的推理服务解决方案,支持多种机器学习框架,包括 TensorFlow、PyTorch、ONNX 等。 + +主要特性: +* **高性能**:针对低延迟和高吞吐量推理进行优化 +* **多框架支持**:支持来自各种机器学习框架的模型 +* **异步处理**:非阻塞推理请求,提高资源利用率 +* **灵活配置**:为不同用例提供全面的配置选项 +* **资源管理**:高效的 HTTP 客户端池化和自动资源清理 + +## 使用示例 + +以下示例创建了一个用于文本分类的 Triton 模型,并使用它来分析电影评论中的情感。 + +首先,使用以下 SQL 语句创建 Triton 模型: + +```sql +CREATE MODEL triton_sentiment_classifier +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'text-classification', + 'model-version' = '1', + 'timeout' = '10000', + 'max-retries' = '3' +); +``` + +假设以下数据存储在名为 `movie_reviews` 的表中,预测结果将存储在名为 `classified_reviews` 的表中: + +```sql +CREATE TEMPORARY VIEW movie_reviews(id, movie_name, user_review, actual_sentiment) +AS VALUES + (1, '好电影', '这部电影绝对精彩!演技和故事情节都很棒。', 'positive'), + (2, '无聊的电影', '我看到一半就睡着了。非常失望。', 'negative'), + (3, '一般的电影', '还可以,没什么特别的,但也不算糟糕。', 'neutral'); + +CREATE TEMPORARY TABLE classified_reviews( + id BIGINT, + movie_name VARCHAR, + predicted_sentiment VARCHAR, + actual_sentiment VARCHAR +) WITH ( + 'connector' = 'print' +); +``` + +然后可以使用以下 SQL 语句对电影评论进行情感分类: + +```sql +INSERT INTO classified_reviews +SELECT id, movie_name, output as predicted_sentiment, actual_sentiment +FROM ML_PREDICT( + TABLE movie_reviews, + MODEL triton_sentiment_classifier, + DESCRIPTOR(user_review) +); +``` + +### 高级配置示例 + +对于需要身份验证和自定义头部的生产环境: + +```sql +CREATE MODEL triton_advanced_model +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'https://triton.example.com/v2/models', + 'model-name' = 'advanced-nlp-model', + 'model-version' = 'latest', + 'timeout' = '15000', + 'max-retries' = '5', + 'batch-size' = '4', + 'priority' = '100', + 'auth-token' = 'Bearer your-auth-token-here', + 'custom-headers' = '{"X-Custom-Header": "custom-value", "X-Request-ID": "req-123"}', + 'compression' = 'gzip' +); +``` + +### 数组类型推理示例 + +对于接受数组输入的模型(例如向量嵌入、图像特征等): + +```sql +-- 创建数组输入的模型 +CREATE MODEL triton_vector_model +INPUT (input_vector ARRAY) +OUTPUT (output_vector ARRAY) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'vector-transform', + 'model-version' = '1', + 'flatten-batch-dim' = 'true' -- 如果模型不期望批次维度 +); + +-- 使用模型进行推理 +CREATE TEMPORARY TABLE vector_input ( + id BIGINT, + features ARRAY +) WITH ( + 'connector' = 'datagen', + 'fields.features.length' = '128' -- 128维向量 +); + +SELECT id, output_vector +FROM ML_PREDICT( + TABLE vector_input, + MODEL triton_vector_model, + DESCRIPTOR(features) +); +``` + +### 有状态模型示例 + +对于需要序列处理的有状态模型: + +```sql +CREATE MODEL triton_sequence_model +INPUT (`input` STRING) +OUTPUT (`output` STRING) +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'sequence-model', + 'model-version' = '1', + 'sequence-id' = 'seq-001', + 'sequence-start' = 'true', + 'sequence-end' = 'false' +); +``` + +## 模型选项 + +### 必需选项 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
选项是否必需默认值类型描述
+
provider
+
必需(无)String指定要使用的模型函数提供者,必须为 'triton'。
+
endpoint
+
必需(无)StringTriton 推理服务器端点的完整 URL,例如 http://localhost:8000/v2/models
+
model-name
+
必需(无)String要在 Triton 服务器上调用的模型名称。
+
model-version
+
必需latestString要使用的模型版本。默认为 'latest'。
+ +### 可选选项 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
选项是否必需默认值类型描述
+
timeout
+
可选30000Long请求超时时间(毫秒)。
+
max-retries
+
可选3Integer失败请求的最大重试次数。
+
batch-size
+
可选1Integer推理请求的批处理大小。
+
flatten-batch-dim
+
可选falseBoolean是否扁平化数组输入的批次维度。当设置为 true 时,形状 [1,N] 会被转换为 [N]。默认为 false。适用于某些 Triton 模型不期望批次维度的情况。
+
priority
+
可选(无)Integer请求优先级(0-255)。数值越高表示优先级越高。
+
sequence-id
+
可选(无)String有状态模型的序列 ID。
+
sequence-start
+
可选falseBoolean对于有状态模型,是否为序列的开始。
+
sequence-end
+
可选falseBoolean对于有状态模型,是否为序列的结束。
+
binary-data
+
可选falseBoolean是否使用二进制数据传输。默认为 false(JSON)。
+
compression
+
可选(无)String要使用的压缩算法(例如 'gzip')。
+
auth-token
+
可选(无)String安全 Triton 服务器的身份验证令牌。
+
custom-headers
+
可选(无)StringJSON 格式的自定义 HTTP 头部,例如 {"X-Custom-Header":"value"}
+ +## 模式要求 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
输入类型输出类型描述
BOOLEAN, TINYINT, SMALLINT, INT, BIGINTBOOLEAN, TINYINT, SMALLINT, INT, BIGINT整数类型推理
FLOAT, DOUBLEFLOAT, DOUBLE浮点数类型推理
STRINGSTRING文本到文本推理(分类、生成等)
ARRAY<数值类型>ARRAY<数值类型>数组推理(向量、张量等)。支持数值类型数组。
+ +**注意**:输入和输出类型必须与 Triton 模型配置中定义的类型匹配。 + +## Triton 服务器设置 + +要使用此集成,您需要运行 Triton 推理服务器。以下是基本设置指南: + +### 使用 Docker + +```bash +# 拉取 Triton 服务器镜像 +docker pull nvcr.io/nvidia/tritonserver:23.10-py3 + +# 使用您的模型仓库运行 Triton 服务器 +docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \ + -v /path/to/your/model/repository:/models \ + nvcr.io/nvidia/tritonserver:23.10-py3 \ + tritonserver --model-repository=/models +``` + +### 模型仓库结构 + +您的模型仓库应遵循以下结构: + +``` +model_repository/ +├── text-classification/ +│ ├── config.pbtxt +│ └── 1/ +│ └── model.py # 或 model.onnx, model.plan 等 +└── other-model/ + ├── config.pbtxt + └── 1/ + └── model.savedmodel/ +``` + +### 示例模型配置 + +以下是文本分类模型的示例 `config.pbtxt`: + +```protobuf +name: "text-classification" +platform: "python" +max_batch_size: 8 +input [ + { + name: "INPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +output [ + { + name: "OUTPUT_TEXT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +``` + +## 性能考虑 + +1. **连接池**:HTTP 客户端被池化和重用以提高效率 +2. **异步处理**:非阻塞请求防止线程饥饿 +3. **批处理**:配置批处理大小以获得最佳吞吐量 +4. **资源管理**:HTTP 资源的自动清理 +5. **超时配置**:根据模型复杂性设置适当的超时值 +6. **重试策略**:配置重试次数以处理瞬态故障 + +## 错误处理 + +集成包括全面的错误处理: + +- **连接错误**:使用指数退避的自动重试 +- **超时处理**:可配置的请求超时 +- **HTTP 错误**:来自 Triton 服务器的详细错误消息 +- **序列化错误**:JSON 解析和验证错误 + +## 监控和调试 + +启用调试日志以监控集成: + +```properties +# 在 log4j2.properties 中 +logger.triton.name = org.apache.flink.model.triton +logger.triton.level = DEBUG +``` + +这将提供以下详细日志: +- HTTP 请求/响应详情 +- 客户端连接管理 +- 错误条件和重试 +- 性能指标 + +## 依赖项 + +要使用 Triton 模型函数,您需要在 Flink 应用程序中包含以下依赖项: + +```xml + + org.apache.flink + flink-model-triton + ${flink.version} + +``` + +{{< top >}} diff --git a/docs/content/docs/dev/table/sql/create.md b/docs/content/docs/dev/table/sql/create.md index 89e958e9a5d49..579a497f0b552 100644 --- a/docs/content/docs/dev/table/sql/create.md +++ b/docs/content/docs/dev/table/sql/create.md @@ -951,4 +951,19 @@ WITH ( ); ``` +```sql +CREATE MODEL triton_text_classifier +INPUT (input STRING COMMENT 'Input text for classification') +OUTPUT (output STRING COMMENT 'Classification result') +COMMENT 'A Triton-based text classification model' +WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'text-classification', + 'model-version' = '1', + 'timeout' = '10000', + 'max-retries' = '3' +); +``` + {{< top >}} diff --git a/docs/content/docs/dev/table/sql/queries/model-inference.md b/docs/content/docs/dev/table/sql/queries/model-inference.md index 0d52a1e27e51c..bd3c86cff1d70 100644 --- a/docs/content/docs/dev/table/sql/queries/model-inference.md +++ b/docs/content/docs/dev/table/sql/queries/model-inference.md @@ -131,4 +131,11 @@ The function will throw an exception in the following cases: - [Model Creation]({{< ref "docs/dev/table/sql/create#create-model" >}}) - [Model Alteration]({{< ref "docs/dev/table/sql/alter#alter-model" >}}) +### Supported Model Providers + +Flink currently supports the following model providers: + +- **OpenAI**: For calling OpenAI API services. See [OpenAI Model Documentation]({{< ref "docs/connectors/models/openai" >}}) for details. +- **Triton**: For calling NVIDIA Triton Inference Server. See [Triton Model Documentation]({{< ref "docs/connectors/models/triton" >}}) for details. + {{< top >}} diff --git a/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java b/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java index 0b75527211077..95dae776d7d83 100644 --- a/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java +++ b/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java @@ -90,7 +90,9 @@ public class ConfigurationOptionLocator { "flink-external-resources/flink-external-resource-gpu", "org.apache.flink.externalresource.gpu"), new OptionsClassLocation( - "flink-models/flink-model-openai", "org.apache.flink.model.openai") + "flink-models/flink-model-openai", "org.apache.flink.model.openai"), + new OptionsClassLocation( + "flink-models/flink-model-triton", "org.apache.flink.model.triton") }; private static final Set EXCLUSIONS = diff --git a/flink-models/flink-model-triton/README.md b/flink-models/flink-model-triton/README.md new file mode 100644 index 0000000000000..4dc3ee578695a --- /dev/null +++ b/flink-models/flink-model-triton/README.md @@ -0,0 +1,378 @@ +# Flink Triton Model Integration + +## ⚠️ Experimental / MVP Status + +**This module is currently experimental and designed for batch-oriented inference workloads.** + +### Scope and Positioning + +- **Primary Use Case**: Batch inference via `ML_PREDICT` on bounded tables +- **Stability**: Experimental - APIs may evolve in future releases +- **Target Scenarios**: Offline processing, batch scoring, model evaluation pipelines + +### Non-Goals (for v1) + +- **Streaming inference**: Real-time/low-latency async inference in streaming jobs (future scope) +- **Multi-input/output models**: Complex tensor schemas (ROW, MAP types - planned for v2+) +- **gRPC protocol**: Binary protocol support (HTTP/REST only in v1) + +### Why Batch-First? + +**Important**: The term "batch-first" refers to this module's **primary use case** (batch table processing via `ML_PREDICT`), NOT to request-level batching semantics. + +**Current Request Model (v1):** +- Each Flink record triggers **ONE HTTP inference request** (1:1 mapping) +- No Flink-side mini-batch aggregation in this version +- Batching efficiency comes from: + - **Triton server-side dynamic batching**: Configure in model's `config.pbtxt` to aggregate concurrent requests + - **Flink table-level parallelism**: Natural concurrency from parallel source reads + - **AsyncDataStream capacity**: Buffer size controls concurrent in-flight requests + +This initial version focuses on correctness and API compatibility with Triton's inference protocol. The batch-oriented **use case** allows us to: +- Validate type mappings and schema handling with simpler control flow +- Establish stable configuration patterns before adding streaming complexity +- Gather community feedback on API design before committing to streaming semantics + +**Future Enhancement (v2+):** Flink-side mini-batch buffer (N rows / T milliseconds) to reduce HTTP overhead for high-throughput scenarios. + +**For streaming use cases**, consider evaluating this module after v2 when async streaming patterns are stabilized. + +--- + +This module provides integration between Apache Flink and NVIDIA Triton Inference Server, enabling model inference within Flink batch applications. + +## Features + +- **REST API Integration**: Communicates with Triton Inference Server via HTTP/REST API +- **Batch Inference Support**: Designed for `ML_PREDICT` in batch table queries +- **Flexible Configuration**: Comprehensive configuration options for various use cases +- **Multi-Type Support**: Supports various input/output data types (STRING, INT, FLOAT, DOUBLE, ARRAY, etc.) +- **Error Handling**: Built-in retry mechanisms and error handling +- **Resource Management**: Efficient HTTP client pooling and resource management + +## Configuration Options + +### Required Options + +| Option | Type | Description | +|--------|------|-------------| +| `endpoint` | String | Base URL of the Triton Inference Server (e.g., `http://localhost:8000` or `http://localhost:8000/v2/models`). The integration will auto-complete to the full inference path. | +| `model-name` | String | Name of the model to invoke on Triton server | +| `model-version` | String | Version of the model to use (defaults to "latest") | + +### Optional Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `timeout` | Long | 30000 | HTTP request timeout in milliseconds (connect + read + write). | +| `max-retries` | Integer | 3 | Maximum retry attempts for connection failures (IOException). HTTP 4xx/5xx errors are NOT retried automatically. | +| `batch-size` | Integer | 1 | **Reserved for future use (v2+).** Currently has NO effect in v1 - each Flink record triggers one HTTP request. Future versions will support Flink-side mini-batch aggregation (buffer N records or T milliseconds). For batching efficiency in v1, configure Triton's `dynamic_batching` in model config and tune AsyncDataStream capacity. | +| `priority` | Integer | - | Request priority level (0-255, higher values = higher priority). *Triton-specific: See Triton docs for server support.* | +| `sequence-id` | String | - | Sequence ID for stateful models. *Triton-specific: For models with sequence/state handling.* | +| `sequence-start` | Boolean | false | Whether this is the start of a sequence for stateful models. *Triton-specific.* | +| `sequence-end` | Boolean | false | Whether this is the end of a sequence for stateful models. *Triton-specific.* | +| `binary-data` | Boolean | false | Whether to use binary data transfer. **Not implemented in v1** (reserved for future use, currently JSON-only). | +| `compression` | String | - | Compression algorithm to use (e.g., 'gzip') | +| `auth-token` | String | - | Authentication token for secured Triton servers | +| `custom-headers` | String | - | Custom HTTP headers in JSON format | +| `flatten-batch-dim` | Boolean | false | *Advanced/Triton-specific*: Remove leading batch dimension from input shape. Use when Triton model expects `[N]` but Flink provides `[1,N]`. | + +### Important Notes on Batching + +**Current v1 Behavior**: Each Flink record triggers **one HTTP request** (1:1 mapping). There is no Flink-side batching in the initial version. + +- The `batch-size` option is reserved for future use (Flink-side request aggregation) +- For server-side batching: Configure Triton's model `config.pbtxt` with `dynamic_batching` settings +- Batch inference workloads naturally benefit from table-level parallelism without explicit batching + +**Future enhancement** (v2+): Flink-side batching to reduce HTTP overhead for high-throughput scenarios. + +## Usage Example + +### Basic Text Processing + +```sql +CREATE MODEL my_triton_model ( + input STRING, + output STRING +) WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'text-classification', + 'model-version' = '1', + 'timeout' = '10000', + 'max-retries' = '5' +); +``` + +### Image Classification with Array Input + +```sql +CREATE MODEL image_classifier ( + image_data ARRAY, + predictions ARRAY +) WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'resnet50', + 'model-version' = '1' +); +``` + +### Numeric Prediction + +```sql +CREATE MODEL numeric_model ( + features ARRAY, + score FLOAT +) WITH ( + 'provider' = 'triton', + 'endpoint' = 'http://localhost:8000/v2/models', + 'model-name' = 'linear-regression', + 'model-version' = 'latest' +); +``` + +### Table API + +```java +// Create table environment +TableEnvironment tableEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode()); + +// Register the model +tableEnv.executeSql( + "CREATE MODEL my_triton_model (" + + " input STRING," + + " output STRING" + + ") WITH (" + + " 'provider' = 'triton'," + + " 'endpoint' = 'http://localhost:8000/v2/models'," + + " 'model-name' = 'text-classification'," + + " 'model-version' = '1'" + + ")" +); + +// Use the model for inference +Table result = tableEnv.sqlQuery( + "SELECT input, ML_PREDICT('my_triton_model', input) as prediction " + + "FROM input_table" +); +``` + +## Supported Data Types + +**Current Version (v1) Limitation**: Only single input column and single output column are supported per model. + +The Triton integration supports the following Flink data types: + +| Flink Type | Triton Type | Description | +|------------|-------------|-------------| +| `BOOLEAN` | `BOOL` | Boolean values | +| `TINYINT` | `INT8` | 8-bit signed integer | +| `SMALLINT` | `INT16` | 16-bit signed integer | +| `INT` | `INT32` | 32-bit signed integer | +| `BIGINT` | `INT64` | 64-bit signed integer | +| `FLOAT` | `FP32` | 32-bit floating point | +| `DOUBLE` | `FP64` | 64-bit floating point | +| `STRING` / `VARCHAR` | `BYTES` | String/text data | +| `ARRAY` | `TYPE[]` | Array of any supported type | + +### Multi-Tensor Models (Workarounds for v1) + +If your Triton model requires multiple input tensors, consider these approaches: + +1. **JSON Encoding**: Serialize multiple fields into a JSON STRING +2. **Array Packing**: Concatenate values into a single ARRAY +3. **Future Support**: ROW<...> and MAP<...> types are planned for future releases + +### Type Mapping Examples + +```sql +-- String input/output (text processing) +CREATE MODEL text_model ( + text STRING, + result STRING +) WITH ('provider' = 'triton', ...); + +-- Array input/output (image processing, embeddings) +CREATE MODEL embedding_model ( + text STRING, + embedding ARRAY +) WITH ('provider' = 'triton', ...); + +-- Numeric computation +CREATE MODEL regression_model ( + features ARRAY, + prediction DOUBLE +) WITH ('provider' = 'triton', ...); +``` + +## Advanced Configuration + +```sql +CREATE MODEL advanced_triton_model ( + input STRING, + output STRING +) WITH ( + 'provider' = 'triton', + 'endpoint' = 'https://triton.example.com/v2/models', + 'model-name' = 'advanced-nlp-model', + 'model-version' = 'latest', + 'timeout' = '15000', + 'max-retries' = '3', + 'batch-size' = '4', + 'priority' = '100', + 'auth-token' = 'your-auth-token-here', + 'custom-headers' = '{"X-Custom-Header": "custom-value"}', + 'compression' = 'gzip' +); +``` + +## Triton Server Setup + +To use this integration, you need a running Triton Inference Server. Here's a basic setup: + +### Using Docker + +```bash +# Pull Triton server image +docker pull nvcr.io/nvidia/tritonserver:23.10-py3 + +# Run Triton server with your model repository +docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \ + -v /path/to/your/model/repository:/models \ + nvcr.io/nvidia/tritonserver:23.10-py3 \ + tritonserver --model-repository=/models +``` + +### Model Repository Structure + +``` +model_repository/ +├── text-classification/ +│ ├── config.pbtxt +│ └── 1/ +│ └── model.py # or model.onnx, model.plan, etc. +└── other-model/ + ├── config.pbtxt + └── 1/ + └── model.savedmodel/ +``` + +## Error Handling + +The integration includes comprehensive error handling: + +- **Connection Errors**: Automatic retry with exponential backoff (OkHttp built-in) +- **Timeout Handling**: Configurable HTTP request timeout (default 30s) +- **HTTP Errors**: 4xx/5xx responses are NOT automatically retried + - 400 Bad Request: Usually indicates shape/type mismatch + - 404 Not Found: Model or version not available + - 500 Internal Server Error: Triton inference failure +- **Serialization Errors**: JSON parsing and type validation errors + +### Retry Behavior Matrix + +| Error Type | Trigger | Flink Behavior | Triton Behavior | +|------------|---------|----------------|-----------------| +| Connection Timeout | Network issue | Fails async operation | N/A | +| HTTP Timeout | Slow inference | Fails after `timeout` ms | N/A | +| Connection Failure (IOException) | Network error | Retries up to `max-retries` | N/A | +| HTTP 4xx | Client error (bad input/shape) | No retry, fails immediately | Returns error JSON | +| HTTP 5xx | Server error (inference crash) | No retry, fails immediately | Returns error JSON | +| JSON Parse Error | Invalid response | No retry, fails immediately | N/A | + +**Important**: Configure Flink's async timeout separately from HTTP timeout to avoid cascading failures: +```java +// Flink async timeout should be > HTTP timeout + retry overhead +AsyncDataStream.unorderedWait(stream, asyncFunc, 60000, TimeUnit.MILLISECONDS); +``` + +## Performance Considerations + +- **Connection Pooling**: HTTP clients are shared across function instances with the same timeout/retry configuration (reference-counted singleton per JVM) +- **Asynchronous Processing**: Non-blocking requests prevent thread starvation +- **Batch Processing**: + - **Triton-side**: Enable dynamic batching in Triton's model config for optimal throughput + - **Flink-side**: Configure AsyncDataStream capacity for concurrent request buffering +- **Resource Management**: Automatic cleanup of HTTP resources via reference counting + +### Performance Tuning Tips + +1. **Increase Async Capacity**: For high-throughput scenarios + ```java + AsyncDataStream.unorderedWait(stream, asyncFunc, timeout, TimeUnit.MILLISECONDS, 200); // capacity=200 + ``` + +2. **Enable Triton Dynamic Batching**: In model's `config.pbtxt` + ``` + dynamic_batching { + preferred_batch_size: [ 4, 8, 16 ] + max_queue_delay_microseconds: 100 + } + ``` + +3. **Tune Parallelism**: Match Flink parallelism to Triton server capacity + ```java + dataStream.map(...).setParallelism(10); // Adjust based on server resources + ``` + +## Monitoring and Debugging + +Enable debug logging to monitor the integration: + +```properties +# In log4j2.properties +logger.triton.name = org.apache.flink.model.triton +logger.triton.level = DEBUG +``` + +This will provide detailed logs about: +- HTTP request/response details +- Client connection management +- Error conditions and retries +- Performance metrics + +## Dependencies + +This module includes the following key dependencies: +- OkHttp for HTTP client functionality (with connection pooling) +- Jackson for JSON processing +- Flink Table API for model integration + +All dependencies are shaded to avoid conflicts with your application. + +## Limitations and Future Work + +### Current Limitations (v1) + +1. **Single Input/Output Only**: Each model must have exactly one input column and one output column +2. **REST API Only**: Uses HTTP/REST protocol. gRPC is not yet supported +3. **No Flink-Side Batching**: Each record triggers a separate HTTP request (relies on Triton's server-side batching) +4. **Binary Data Mode**: Declared but not fully implemented (JSON only) + +### Testing and Validation + +**Unit Test Coverage**: This module includes comprehensive unit tests for: +- Type mapping logic (`TritonTypeMapper`) +- HTTP request/response formatting +- Configuration validation +- Provider factory registration + +**Integration Testing**: End-to-end tests with a live Triton server are **not included** in this PR due to: +- CI environment constraints (no GPU/Triton infrastructure in Flink CI) +- Complexity of Docker-in-Docker setup for model serving +- Focus on **protocol correctness** rather than end-to-end deployment validation + +**Manual Validation**: The module has been manually tested with local Triton instances across various model types (text classification, embeddings, numeric regression). Users are encouraged to validate with their specific Triton deployments. + +**Note**: This testing approach is consistent with other `flink-models` providers (e.g., `flink-model-openai` tests protocol compliance without live API calls). + +### Planned Enhancements (v2+) + +- **Multi-Input/Output Support**: Using ROW<...> or MAP<...> types to map multiple Triton tensors +- **gRPC Protocol**: Native gRPC support for improved performance and streaming +- **Flink-Side Batching**: Optional aggregation of multiple records before sending to Triton +- **Binary Data Transfer**: Efficient binary serialization for large tensor data + +**Feedback Welcome**: Please share your use cases and requirements via JIRA or mailing lists to help prioritize these features. diff --git a/flink-models/flink-model-triton/pom.xml b/flink-models/flink-model-triton/pom.xml new file mode 100644 index 0000000000000..da88294c515c4 --- /dev/null +++ b/flink-models/flink-model-triton/pom.xml @@ -0,0 +1,168 @@ + + + + + 4.0.0 + + + org.apache.flink + flink-models + 2.3-SNAPSHOT + + + flink-model-triton + Flink : Models : Triton + + + 4.12.0 + 2.15.2 + 2.11.0 + + + + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + ${flink.markBundledAsOptional} + + + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + ${flink.markBundledAsOptional} + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + ${flink.markBundledAsOptional} + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + ${flink.markBundledAsOptional} + + + + + org.apache.flink + flink-core + ${project.version} + provided + + + + org.apache.flink + flink-table-api-java + ${project.version} + provided + + + + org.apache.flink + flink-table-common + ${project.version} + provided + + + + + org.apache.flink + flink-table-planner_${scala.binary.version} + ${project.version} + test + + + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + + + + org.apache.flink + flink-table-api-java-bridge + ${project.version} + test + + + + org.apache.flink + flink-clients + ${project.version} + test + + + + com.google.code.gson + gson + ${test.gson.version} + test + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + package + + shade + + + + + *:* + + + com.google.code.findbugs:jsr305 + + + + + com.fasterxml.jackson + org.apache.flink.model.triton.com.fasterxml.jackson + + + com.squareup + org.apache.flink.model.triton.com.squareup + + + + + + + + + diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java new file mode 100644 index 0000000000000..77f612369c1d4 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.configuration.ConfigOptions; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.configuration.description.Description; +import org.apache.flink.table.catalog.Column; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.factories.ModelProviderFactory; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; + +import okhttp3.OkHttpClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.configuration.description.TextElement.code; + +/** + * Abstract parent class for {@link AsyncPredictFunction}s for Triton Inference Server API. + * + *

This implementation uses REST-based HTTP communication with Triton Inference Server. Each + * Flink record triggers a separate HTTP request (no Flink-side batching). Triton's server-side + * dynamic batching can aggregate concurrent requests. + * + *

HTTP Client Lifecycle: A shared HTTP client pool is maintained per JVM with reference + * counting. Multiple function instances with identical timeout/retry settings share the same client + * instance to avoid resource exhaustion in high-parallelism scenarios. + * + *

Current Limitations (v1): + * + *

    + *
  • Only single input column and single output column are supported + *
  • REST API only; gRPC may be introduced in future versions + *
  • Binary data mode is declared but not fully implemented + *
+ * + *

Future Roadmap: Support for multi-input/multi-output models using ROW or MAP types, and + * native gRPC protocol for improved performance. + */ +public abstract class AbstractTritonModelFunction extends AsyncPredictFunction { + private static final Logger LOG = LoggerFactory.getLogger(AbstractTritonModelFunction.class); + + public static final ConfigOption ENDPOINT = + ConfigOptions.key("endpoint") + .stringType() + .noDefaultValue() + .withDescription( + Description.builder() + .text( + "Full URL of the Triton Inference Server endpoint, e.g., %s", + code("http://localhost:8000/v2/models")) + .build()); + + public static final ConfigOption MODEL_NAME = + ConfigOptions.key("model-name") + .stringType() + .noDefaultValue() + .withDescription("Name of the model to invoke on Triton server."); + + public static final ConfigOption MODEL_VERSION = + ConfigOptions.key("model-version") + .stringType() + .defaultValue("latest") + .withDescription("Version of the model to use. Defaults to 'latest'."); + + public static final ConfigOption TIMEOUT = + ConfigOptions.key("timeout") + .longType() + .defaultValue(30000L) + .withDescription( + "HTTP request timeout in milliseconds (connect + read + write). " + + "This applies per individual request and is separate from Flink's async timeout. " + + "Defaults to 30000ms (30 seconds)."); + + public static final ConfigOption MAX_RETRIES = + ConfigOptions.key("max-retries") + .intType() + .defaultValue(3) + .withDescription( + "Maximum number of retry attempts for failed HTTP requests. " + + "Retries are triggered on connection failures (IOException). " + + "HTTP errors (4xx/5xx) are NOT automatically retried. " + + "Defaults to 3 retries."); + + public static final ConfigOption BATCH_SIZE = + ConfigOptions.key("batch-size") + .intType() + .defaultValue(1) + .withDescription( + Description.builder() + .text( + "Reserved for future use (v2+). Currently has NO effect in v1. " + + "Each Flink record triggers one HTTP request regardless of this setting. " + + "Future versions will support Flink-side mini-batch aggregation " + + "(buffer N records or T milliseconds before sending). " + + "For batching efficiency in v1: " + + "1) Configure Triton model's dynamic_batching in config.pbtxt, " + + "2) Tune Flink AsyncDataStream capacity for concurrent requests, " + + "3) Increase Flink parallelism to create more concurrent requests. " + + "Defaults to 1.") + .build()); + + public static final ConfigOption FLATTEN_BATCH_DIM = + ConfigOptions.key("flatten-batch-dim") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether to flatten the batch dimension for array inputs. " + + "When true, shape [1,N] becomes [N]. Defaults to false."); + + public static final ConfigOption PRIORITY = + ConfigOptions.key("priority") + .intType() + .noDefaultValue() + .withDescription( + "Request priority level (0-255). Higher values indicate higher priority."); + + public static final ConfigOption SEQUENCE_ID = + ConfigOptions.key("sequence-id") + .stringType() + .noDefaultValue() + .withDescription("Sequence ID for stateful models."); + + public static final ConfigOption SEQUENCE_START = + ConfigOptions.key("sequence-start") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether this is the start of a sequence for stateful models."); + + public static final ConfigOption SEQUENCE_END = + ConfigOptions.key("sequence-end") + .booleanType() + .defaultValue(false) + .withDescription("Whether this is the end of a sequence for stateful models."); + + public static final ConfigOption BINARY_DATA = + ConfigOptions.key("binary-data") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether to use binary data transfer. Defaults to false (JSON)."); + + public static final ConfigOption COMPRESSION = + ConfigOptions.key("compression") + .stringType() + .noDefaultValue() + .withDescription("Compression algorithm to use (e.g., 'gzip')."); + + public static final ConfigOption AUTH_TOKEN = + ConfigOptions.key("auth-token") + .stringType() + .noDefaultValue() + .withDescription("Authentication token for secured Triton servers."); + + public static final ConfigOption CUSTOM_HEADERS = + ConfigOptions.key("custom-headers") + .stringType() + .noDefaultValue() + .withDescription( + "Custom HTTP headers in JSON format, e.g., '{\"X-Custom-Header\":\"value\"}'."); + + protected transient OkHttpClient httpClient; + + private final String endpoint; + private final String modelName; + private final String modelVersion; + private final long timeout; + private final int maxRetries; + private final int batchSize; + private final boolean flattenBatchDim; + private final Integer priority; + private final String sequenceId; + private final boolean sequenceStart; + private final boolean sequenceEnd; + private final boolean binaryData; + private final String compression; + private final String authToken; + private final String customHeaders; + + public AbstractTritonModelFunction( + ModelProviderFactory.Context factoryContext, ReadableConfig config) { + this.endpoint = config.get(ENDPOINT); + this.modelName = config.get(MODEL_NAME); + this.modelVersion = config.get(MODEL_VERSION); + this.timeout = config.get(TIMEOUT); + this.maxRetries = config.get(MAX_RETRIES); + this.batchSize = config.get(BATCH_SIZE); + this.flattenBatchDim = config.get(FLATTEN_BATCH_DIM); + this.priority = config.get(PRIORITY); + this.sequenceId = config.get(SEQUENCE_ID); + this.sequenceStart = config.get(SEQUENCE_START); + this.sequenceEnd = config.get(SEQUENCE_END); + this.binaryData = config.get(BINARY_DATA); + this.compression = config.get(COMPRESSION); + this.authToken = config.get(AUTH_TOKEN); + this.customHeaders = config.get(CUSTOM_HEADERS); + + // Validate input schema - support multiple types + validateInputSchema(factoryContext.getCatalogModel().getResolvedInputSchema()); + } + + @Override + public void open(FunctionContext context) throws Exception { + super.open(context); + LOG.debug("Creating Triton HTTP client."); + this.httpClient = TritonUtils.createHttpClient(timeout, maxRetries); + } + + @Override + public void close() throws Exception { + super.close(); + if (this.httpClient != null) { + LOG.debug("Releasing Triton HTTP client."); + TritonUtils.releaseHttpClient(this.httpClient); + httpClient = null; + } + } + + /** + * Validates the input schema. Subclasses can override for custom validation. + * + * @param schema The input schema to validate + */ + protected void validateInputSchema(ResolvedSchema schema) { + validateSingleColumnSchema(schema, null, "input"); + } + + /** + * Validates that the schema has exactly one physical column, optionally checking the type. + * + *

Version 1 Limitation: Only single input/single output models are supported. For + * models requiring multiple tensors, consider these workarounds: + * + *

    + *
  • Flatten inputs into a JSON STRING and parse server-side + *
  • Use ARRAY<T> to pack multiple values + *
  • Wait for future ROW<...> support (planned for v2) + *
+ * + * @param schema The schema to validate + * @param expectedType The expected type, or null to skip type checking + * @param inputOrOutput Description of whether this is input or output schema + */ + protected void validateSingleColumnSchema( + ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) { + List columns = schema.getColumns(); + if (columns.size() != 1) { + throw new IllegalArgumentException( + String.format( + "Model should have exactly one %s column, but actually has %s columns: %s. " + + "Current version only supports single input/output. " + + "For multi-tensor models, consider using JSON STRING encoding or ARRAY packing.", + inputOrOutput, + columns.size(), + columns.stream().map(Column::getName).collect(Collectors.toList()))); + } + + Column column = columns.get(0); + if (!column.isPhysical()) { + throw new IllegalArgumentException( + String.format( + "%s column %s should be a physical column, but is a %s.", + inputOrOutput, column.getName(), column.getClass())); + } + + if (expectedType != null && !expectedType.equals(column.getDataType().getLogicalType())) { + throw new IllegalArgumentException( + String.format( + "%s column %s should be %s, but is a %s.", + inputOrOutput, + column.getName(), + expectedType, + column.getDataType().getLogicalType())); + } + + // Validate that the type is supported by Triton + try { + TritonTypeMapper.toTritonDataType(column.getDataType().getLogicalType()); + } catch (IllegalArgumentException e) { + String suggestedType = getSuggestedTypeForTriton(column.getDataType().getLogicalType()); + throw new IllegalArgumentException( + String.format( + "%s column %s has unsupported type %s for Triton. %s%s", + inputOrOutput, + column.getName(), + column.getDataType().getLogicalType(), + e.getMessage(), + suggestedType.isEmpty() ? "" : "\nSuggestion: " + suggestedType)); + } + + // Enhanced validation for type compatibility + validateTritonTypeCompatibility( + column.getDataType().getLogicalType(), column.getName(), inputOrOutput); + } + + /** + * Validates Triton type compatibility with enhanced checks. + * + *

This method performs additional validation beyond basic type support: + * + *

    + *
  • Checks for nested arrays (multi-dimensional tensors not supported in v1) + *
  • Warns about STRING to BYTES mapping + *
  • Provides structured error messages with troubleshooting hints + *
+ * + * @param type The logical type to validate + * @param columnName The name of the column + * @param inputOrOutput Description of whether this is input or output + */ + private void validateTritonTypeCompatibility( + LogicalType type, String columnName, String inputOrOutput) { + + // Check for nested arrays (multi-dimensional tensors) + if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType) type; + LogicalType elementType = arrayType.getElementType(); + + // Reject nested arrays + if (elementType instanceof ArrayType) { + throw new IllegalArgumentException( + String.format( + "%s column '%s' has nested array type: %s\n" + + "Multi-dimensional tensors (ARRAY>) are not supported in v1.\n" + + "=== Supported Types ===\n" + + " • Scalars: INT, BIGINT, FLOAT, DOUBLE, BOOLEAN, STRING\n" + + " • 1-D Arrays: ARRAY, ARRAY, ARRAY, etc.\n" + + "=== Workarounds ===\n" + + " • Flatten to 1-D array: ARRAY with size = rows * cols\n" + + " • Use JSON STRING encoding for complex structures\n" + + " • Wait for v2+ which will support ROW<...> types", + inputOrOutput, columnName, type)); + } + + // Additional check: ensure element type is supported + if (elementType instanceof ArrayType) { + // This should have been caught above, but double-check + throw new IllegalArgumentException( + String.format( + "%s column '%s' has unsupported element type: %s", + inputOrOutput, columnName, elementType)); + } + } + + // Log info about STRING to BYTES mapping + if (type instanceof org.apache.flink.table.types.logical.VarCharType) { + LOG.info( + "{} column '{}' uses STRING type, which will be mapped to Triton BYTES dtype. " + + "Ensure your Triton model expects string/text inputs.", + inputOrOutput, + columnName); + } + } + + /** Provides user-friendly type suggestions for unsupported types. */ + private String getSuggestedTypeForTriton(LogicalType unsupportedType) { + String typeName = unsupportedType.getTypeRoot().name(); + + if (typeName.contains("ARRAY") && unsupportedType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) unsupportedType; + if (arrayType.getElementType() instanceof ArrayType) { + return "Flatten nested array to 1-D: ARRAY instead of ARRAY>"; + } + } + + if (typeName.contains("MAP")) { + return "Use ARRAY instead of MAP, or serialize to JSON STRING"; + } else if (typeName.contains("ROW") || typeName.contains("STRUCT")) { + return "Flatten ROW into single column, use ARRAY packing, or serialize to JSON STRING"; + } else if (typeName.contains("TIME") || typeName.contains("DATE")) { + return "Convert timestamp/date to BIGINT (epoch milliseconds) or STRING (ISO-8601)"; + } else if (typeName.contains("DECIMAL")) { + return "Use DOUBLE for numeric precision or STRING for exact decimal representation"; + } else if (typeName.contains("BINARY") || typeName.contains("VARBINARY")) { + return "Consider using STRING (VARCHAR) type, which maps to Triton BYTES"; + } + + return ""; + } + + // Getters for configuration values + protected String getEndpoint() { + return endpoint; + } + + protected String getModelName() { + return modelName; + } + + protected String getModelVersion() { + return modelVersion; + } + + protected long getTimeout() { + return timeout; + } + + protected int getMaxRetries() { + return maxRetries; + } + + protected int getBatchSize() { + return batchSize; + } + + protected boolean isFlattenBatchDim() { + return flattenBatchDim; + } + + protected Integer getPriority() { + return priority; + } + + protected String getSequenceId() { + return sequenceId; + } + + protected boolean isSequenceStart() { + return sequenceStart; + } + + protected boolean isSequenceEnd() { + return sequenceEnd; + } + + protected boolean isBinaryData() { + return binaryData; + } + + protected String getCompression() { + return compression; + } + + protected String getAuthToken() { + return authToken; + } + + protected String getCustomHeaders() { + return customHeaders; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java new file mode 100644 index 0000000000000..451b400f9ffdf --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +/** + * Enumeration of data types supported by Triton Inference Server. + * + *

These data types correspond to the types defined in the Triton Inference Server protocol. + * Reference: + * https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_model_configuration.md + */ +public enum TritonDataType { + /** Boolean type. */ + BOOL("BOOL"), + + /** 8-bit unsigned integer. */ + UINT8("UINT8"), + + /** 16-bit unsigned integer. */ + UINT16("UINT16"), + + /** 32-bit unsigned integer. */ + UINT32("UINT32"), + + /** 64-bit unsigned integer. */ + UINT64("UINT64"), + + /** 8-bit signed integer. */ + INT8("INT8"), + + /** 16-bit signed integer. */ + INT16("INT16"), + + /** 32-bit signed integer. */ + INT32("INT32"), + + /** 64-bit signed integer. */ + INT64("INT64"), + + /** 16-bit floating point (half precision). */ + FP16("FP16"), + + /** 32-bit floating point (single precision). */ + FP32("FP32"), + + /** 64-bit floating point (double precision). */ + FP64("FP64"), + + /** String/text data. */ + BYTES("BYTES"); + + private final String tritonName; + + TritonDataType(String tritonName) { + this.tritonName = tritonName; + } + + /** Returns the Triton protocol name for this data type. */ + public String getTritonName() { + return tritonName; + } + + /** Gets a TritonDataType from its Triton protocol name. */ + public static TritonDataType fromTritonName(String tritonName) { + for (TritonDataType type : values()) { + if (type.tritonName.equals(tritonName)) { + return type; + } + } + throw new IllegalArgumentException("Unknown Triton data type: " + tritonName); + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java new file mode 100644 index 0000000000000..e2564a396a70e --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.model.triton.exception.TritonClientException; +import org.apache.flink.model.triton.exception.TritonNetworkException; +import org.apache.flink.model.triton.exception.TritonSchemaException; +import org.apache.flink.model.triton.exception.TritonServerException; +import org.apache.flink.table.catalog.Column; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.binary.BinaryStringData; +import org.apache.flink.table.factories.ModelProviderFactory; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.VarCharType; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.MediaType; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * {@link AsyncPredictFunction} for Triton Inference Server generic inference task. + * + *

Request Model (v1): This implementation processes records one-by-one. Each {@link + * #asyncPredict(RowData)} call triggers one HTTP request to Triton server. There is no Flink-side + * mini-batch aggregation in the current version. + * + *

Batch Efficiency: Inference throughput benefits from: + * + *

    + *
  • Triton Dynamic Batching: Configure {@code dynamic_batching} in model's {@code + * config.pbtxt} to aggregate concurrent requests server-side + *
  • Flink Parallelism: High parallelism naturally creates concurrent requests that + * Triton can batch together + *
  • AsyncDataStream Capacity: Buffer size controls concurrent in-flight requests, + * increasing opportunities for server-side batching + *
+ * + *

Future Roadmap (v2+): Flink-side mini-batch aggregation will be added to reduce HTTP + * overhead (configurable via {@code batch-size} and {@code batch-timeout} options). + * + * @see Triton + * Dynamic Batching Documentation + */ +public class TritonInferenceModelFunction extends AbstractTritonModelFunction { + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(TritonInferenceModelFunction.class); + + private static final MediaType JSON_MEDIA_TYPE = + MediaType.get("application/json; charset=utf-8"); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private final LogicalType inputType; + private final LogicalType outputType; + private final String inputName; + private final String outputName; + + public TritonInferenceModelFunction( + ModelProviderFactory.Context factoryContext, ReadableConfig config) { + super(factoryContext, config); + + // Validate and store input/output types + validateSingleColumnSchema( + factoryContext.getCatalogModel().getResolvedOutputSchema(), + null, // Allow any supported type + "output"); + + // Get input and output column information + Column inputColumn = + factoryContext.getCatalogModel().getResolvedInputSchema().getColumns().get(0); + Column outputColumn = + factoryContext.getCatalogModel().getResolvedOutputSchema().getColumns().get(0); + + this.inputType = inputColumn.getDataType().getLogicalType(); + this.outputType = outputColumn.getDataType().getLogicalType(); + this.inputName = inputColumn.getName(); + this.outputName = outputColumn.getName(); + } + + @Override + public CompletableFuture> asyncPredict(RowData rowData) { + CompletableFuture> future = new CompletableFuture<>(); + + try { + String requestBody = buildInferenceRequest(rowData); + String url = + TritonUtils.buildInferenceUrl(getEndpoint(), getModelName(), getModelVersion()); + + Request.Builder requestBuilder = + new Request.Builder() + .url(url) + .post(RequestBody.create(requestBody, JSON_MEDIA_TYPE)); + + // Add authentication header if provided + if (getAuthToken() != null) { + requestBuilder.addHeader("Authorization", "Bearer " + getAuthToken()); + } + + // Add custom headers if provided + if (getCustomHeaders() != null) { + try { + JsonNode headersNode = objectMapper.readTree(getCustomHeaders()); + headersNode + .fields() + .forEachRemaining( + entry -> + requestBuilder.addHeader( + entry.getKey(), entry.getValue().asText())); + } catch (JsonProcessingException e) { + LOG.warn("Failed to parse custom headers: {}", getCustomHeaders(), e); + } + } + + // Add compression header if specified + if (getCompression() != null) { + requestBuilder.addHeader("Content-Encoding", getCompression()); + } + + Request request = requestBuilder.build(); + + httpClient + .newCall(request) + .enqueue( + new Callback() { + @Override + public void onFailure(Call call, IOException e) { + LOG.error( + "Triton inference request failed due to network error", + e); + + // Wrap IOException in TritonNetworkException + TritonNetworkException networkException = + new TritonNetworkException( + String.format( + "Failed to connect to Triton server at %s: %s. " + + "This may indicate network connectivity issues, DNS resolution failure, or server unavailability.", + url, e.getMessage()), + e); + + future.completeExceptionally(networkException); + } + + @Override + public void onResponse(Call call, Response response) + throws IOException { + try { + if (!response.isSuccessful()) { + handleErrorResponse(response, future); + return; + } + + String responseBody = response.body().string(); + LOG.info("Triton inference response: {}", responseBody); + Collection result = + parseInferenceResponse(responseBody); + future.complete(result); + } catch (JsonProcessingException e) { + LOG.error("Failed to parse Triton inference response", e); + future.completeExceptionally( + new TritonClientException( + "Failed to parse Triton response JSON: " + + e.getMessage() + + ". This may indicate an incompatible response format.", + 400)); + } catch (Exception e) { + LOG.error("Failed to process Triton inference response", e); + future.completeExceptionally(e); + } finally { + response.close(); + } + } + }); + + } catch (Exception e) { + LOG.error("Failed to build Triton inference request", e); + future.completeExceptionally(e); + } + + return future; + } + + /** + * Handles HTTP error responses and creates appropriate typed exceptions. + * + * @param response The HTTP response with error status + * @param future The future to complete exceptionally + * @throws IOException If reading response body fails + */ + private void handleErrorResponse( + Response response, CompletableFuture> future) throws IOException { + + String errorBody = + response.body() != null ? response.body().string() : "No error details provided"; + int statusCode = response.code(); + + // Build detailed error message with context + StringBuilder errorMsg = new StringBuilder(); + errorMsg.append( + String.format("Triton inference failed with HTTP %d: %s\n", statusCode, errorBody)); + errorMsg.append("\n=== Request Configuration ===\n"); + errorMsg.append( + String.format(" Model: %s (version: %s)\n", getModelName(), getModelVersion())); + errorMsg.append(String.format(" Endpoint: %s\n", getEndpoint())); + errorMsg.append(String.format(" Input column: %s\n", inputName)); + errorMsg.append(String.format(" Input Flink type: %s\n", inputType)); + errorMsg.append( + String.format( + " Input Triton dtype: %s\n", + TritonTypeMapper.toTritonDataType(inputType).getTritonName())); + + // Check if this is a shape mismatch error + boolean isShapeMismatch = + errorBody.toLowerCase().contains("shape") + || errorBody.toLowerCase().contains("dimension"); + + if (statusCode >= 400 && statusCode < 500) { + // Client error - user configuration issue + errorMsg.append("\n=== Troubleshooting (Client Error) ===\n"); + + if (statusCode == 400) { + errorMsg.append(" • Verify input shape matches model's config.pbtxt\n"); + errorMsg.append(" • For scalar: use INT/FLOAT/DOUBLE/STRING\n"); + errorMsg.append(" • For 1-D tensor: use ARRAY\n"); + errorMsg.append( + " • Try flatten-batch-dim=true if model expects [N] but gets [1,N]\n"); + + if (isShapeMismatch) { + // Create schema exception for shape mismatches + future.completeExceptionally( + new TritonSchemaException( + errorMsg.toString(), + "See Triton model config.pbtxt", + String.format("Flink type: %s", inputType))); + return; + } + } else if (statusCode == 404) { + errorMsg.append(" • Verify model-name: ").append(getModelName()).append("\n"); + errorMsg.append(" • Verify model-version: ") + .append(getModelVersion()) + .append("\n"); + errorMsg.append(" • Check model is loaded: GET ") + .append(getEndpoint()) + .append("\n"); + } else if (statusCode == 401 || statusCode == 403) { + errorMsg.append(" • Check auth-token configuration\n"); + errorMsg.append(" • Verify server authentication requirements\n"); + } + + future.completeExceptionally( + new TritonClientException(errorMsg.toString(), statusCode)); + + } else if (statusCode >= 500 && statusCode < 600) { + // Server error - Triton service issue + errorMsg.append("\n=== Troubleshooting (Server Error) ===\n"); + + if (statusCode == 500) { + errorMsg.append(" • Check Triton server logs for inference crash details\n"); + errorMsg.append(" • Model may have run out of memory\n"); + errorMsg.append(" • Input data may trigger model bug\n"); + } else if (statusCode == 503) { + errorMsg.append(" • Server is overloaded or unavailable\n"); + errorMsg.append(" • This error is retryable with backoff\n"); + errorMsg.append(" • Consider scaling Triton server resources\n"); + } else if (statusCode == 504) { + errorMsg.append(" • Inference exceeded gateway timeout\n"); + errorMsg.append(" • This error is retryable\n"); + errorMsg.append(" • Consider increasing timeout configuration\n"); + } + + future.completeExceptionally( + new TritonServerException(errorMsg.toString(), statusCode)); + + } else { + // Unexpected status code + errorMsg.append("\n=== Unexpected Status Code ===\n"); + errorMsg.append(" • This status code is not standard for Triton\n"); + errorMsg.append(" • Check if proxy/load balancer is involved\n"); + + future.completeExceptionally( + new TritonClientException(errorMsg.toString(), statusCode)); + } + } + + private String buildInferenceRequest(RowData rowData) throws JsonProcessingException { + ObjectNode requestNode = objectMapper.createObjectNode(); + + // Add request ID if sequence ID is provided + if (getSequenceId() != null) { + requestNode.put("id", getSequenceId()); + } + + // Add parameters + ObjectNode parametersNode = objectMapper.createObjectNode(); + if (getPriority() != null) { + parametersNode.put("priority", getPriority()); + } + if (isSequenceStart()) { + parametersNode.put("sequence_start", true); + } + if (isSequenceEnd()) { + parametersNode.put("sequence_end", true); + } + if (parametersNode.size() > 0) { + requestNode.set("parameters", parametersNode); + } + + // Add inputs + ArrayNode inputsArray = objectMapper.createArrayNode(); + ObjectNode inputNode = objectMapper.createObjectNode(); + inputNode.put("name", inputName.toUpperCase()); + + // Map Flink type to Triton type + TritonDataType tritonType = TritonTypeMapper.toTritonDataType(inputType); + inputNode.put("datatype", tritonType.getTritonName()); + + // Serialize input data first to get actual size + ArrayNode dataArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray(rowData, 0, inputType, dataArray); + + // Calculate and add shape based on actual data + int[] shape = TritonTypeMapper.calculateShape(inputType, 1, rowData, 0); + + // Apply flatten-batch-dim if configured + if (isFlattenBatchDim() && shape.length > 1 && shape[0] == 1) { + // Remove the batch dimension: [1, N] -> [N] + int[] flattenedShape = new int[shape.length - 1]; + System.arraycopy(shape, 1, flattenedShape, 0, flattenedShape.length); + shape = flattenedShape; + } + + ArrayNode shapeArray = objectMapper.createArrayNode(); + for (int dim : shape) { + shapeArray.add(dim); + } + inputNode.set("shape", shapeArray); + inputNode.set("data", dataArray); + + inputsArray.add(inputNode); + requestNode.set("inputs", inputsArray); + + // Add outputs (request all outputs) + ArrayNode outputsArray = objectMapper.createArrayNode(); + ObjectNode outputNode = objectMapper.createObjectNode(); + outputNode.put("name", outputName.toUpperCase()); + outputsArray.add(outputNode); + requestNode.set("outputs", outputsArray); + + String requestJson = objectMapper.writeValueAsString(requestNode); + + // Log the request for debugging + if (LOG.isDebugEnabled()) { + LOG.debug( + "Triton inference request - Model: {}, Version: {}, Input: {}, Shape: {}", + getModelName(), + getModelVersion(), + inputName, + java.util.Arrays.toString(shape)); + LOG.debug("Request body: {}", requestJson); + } + + return requestJson; + } + + private Collection parseInferenceResponse(String responseBody) + throws JsonProcessingException { + JsonNode responseNode = objectMapper.readTree(responseBody); + List results = new ArrayList<>(); + + if (LOG.isDebugEnabled()) { + LOG.debug("Triton response body: {}", responseBody); + } + + JsonNode outputsNode = responseNode.get("outputs"); + if (outputsNode != null && outputsNode.isArray()) { + for (JsonNode outputNode : outputsNode) { + JsonNode dataNode = outputNode.get("data"); + + if (dataNode != null && dataNode.isArray()) { + if (dataNode.size() > 0) { + // Check if output is array type or scalar + // If outputType is scalar but dataNode is array, extract first element + JsonNode nodeToDeserialize = dataNode; + if (!(outputType instanceof ArrayType) && dataNode.isArray()) { + // Scalar type - extract first element from array + nodeToDeserialize = dataNode.get(0); + if (LOG.isDebugEnabled()) { + LOG.debug("Extracting scalar value from array[0]"); + } + } + + Object deserializedData = + TritonTypeMapper.deserializeFromJson(nodeToDeserialize, outputType); + + results.add(GenericRowData.of(deserializedData)); + } + } + } + } else { + LOG.warn("No outputs found in Triton response"); + } + + // If no outputs found, return default value based on type + if (results.isEmpty()) { + Object defaultValue; + if (outputType instanceof VarCharType) { + defaultValue = BinaryStringData.fromString(""); + } else { + defaultValue = null; + } + results.add(GenericRowData.of(defaultValue)); + } + + return results; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java new file mode 100644 index 0000000000000..b3c51d8df8b69 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.table.factories.FactoryUtil; +import org.apache.flink.table.factories.ModelProviderFactory; +import org.apache.flink.table.functions.AsyncPredictFunction; +import org.apache.flink.table.ml.AsyncPredictRuntimeProvider; +import org.apache.flink.table.ml.ModelProvider; + +import java.util.HashSet; +import java.util.Set; + +/** {@link ModelProviderFactory} for Triton Inference Server model functions. */ +public class TritonModelProviderFactory implements ModelProviderFactory { + public static final String IDENTIFIER = "triton"; + + @Override + public ModelProvider createModelProvider(ModelProviderFactory.Context context) { + FactoryUtil.ModelProviderFactoryHelper helper = + FactoryUtil.createModelProviderFactoryHelper(this, context); + helper.validate(); + + // For now, we create a generic inference function + // In the future, this could be extended to support different model types + AsyncPredictFunction function = + new TritonInferenceModelFunction(context, helper.getOptions()); + return new Provider(function); + } + + @Override + public String factoryIdentifier() { + return IDENTIFIER; + } + + @Override + public Set> requiredOptions() { + Set> set = new HashSet<>(); + set.add(AbstractTritonModelFunction.ENDPOINT); + set.add(AbstractTritonModelFunction.MODEL_NAME); + set.add(AbstractTritonModelFunction.MODEL_VERSION); + return set; + } + + @Override + public Set> optionalOptions() { + Set> set = new HashSet<>(); + set.add(AbstractTritonModelFunction.TIMEOUT); + set.add(AbstractTritonModelFunction.MAX_RETRIES); + set.add(AbstractTritonModelFunction.BATCH_SIZE); + set.add(AbstractTritonModelFunction.FLATTEN_BATCH_DIM); + set.add(AbstractTritonModelFunction.PRIORITY); + set.add(AbstractTritonModelFunction.SEQUENCE_ID); + set.add(AbstractTritonModelFunction.SEQUENCE_START); + set.add(AbstractTritonModelFunction.SEQUENCE_END); + set.add(AbstractTritonModelFunction.BINARY_DATA); + set.add(AbstractTritonModelFunction.COMPRESSION); + set.add(AbstractTritonModelFunction.AUTH_TOKEN); + set.add(AbstractTritonModelFunction.CUSTOM_HEADERS); + return set; + } + + /** {@link ModelProvider} for Triton model functions. */ + public static class Provider implements AsyncPredictRuntimeProvider { + private final AsyncPredictFunction function; + + public Provider(AsyncPredictFunction function) { + this.function = function; + } + + @Override + public AsyncPredictFunction createAsyncPredictFunction(Context context) { + return function; + } + + @Override + public ModelProvider copy() { + return new Provider(function); + } + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java new file mode 100644 index 0000000000000..49eab681f942e --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.binary.BinaryStringData; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.SmallIntType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; + +/** Utility class for mapping between Flink logical types and Triton data types. */ +public class TritonTypeMapper { + + /** + * Maps a Flink LogicalType to the corresponding Triton data type. + * + * @param logicalType The Flink logical type + * @return The corresponding Triton data type + * @throws IllegalArgumentException if the type is not supported + */ + public static TritonDataType toTritonDataType(LogicalType logicalType) { + if (logicalType instanceof BooleanType) { + return TritonDataType.BOOL; + } else if (logicalType instanceof TinyIntType) { + return TritonDataType.INT8; + } else if (logicalType instanceof SmallIntType) { + return TritonDataType.INT16; + } else if (logicalType instanceof IntType) { + return TritonDataType.INT32; + } else if (logicalType instanceof BigIntType) { + return TritonDataType.INT64; + } else if (logicalType instanceof FloatType) { + return TritonDataType.FP32; + } else if (logicalType instanceof DoubleType) { + return TritonDataType.FP64; + } else if (logicalType instanceof VarCharType) { + return TritonDataType.BYTES; + } else if (logicalType instanceof ArrayType) { + // For arrays, we map the element type + ArrayType arrayType = (ArrayType) logicalType; + return toTritonDataType(arrayType.getElementType()); + } else { + throw new IllegalArgumentException("Unsupported Flink type for Triton: " + logicalType); + } + } + + /** + * Serializes Flink RowData field value to JSON array for Triton request. + * + * @param rowData The row data + * @param fieldIndex The field index + * @param logicalType The logical type of the field + * @param dataArray The JSON array to add data to + */ + public static void serializeToJsonArray( + RowData rowData, int fieldIndex, LogicalType logicalType, ArrayNode dataArray) { + if (rowData.isNullAt(fieldIndex)) { + dataArray.addNull(); + return; + } + + if (logicalType instanceof BooleanType) { + dataArray.add(rowData.getBoolean(fieldIndex)); + } else if (logicalType instanceof TinyIntType) { + dataArray.add(rowData.getByte(fieldIndex)); + } else if (logicalType instanceof SmallIntType) { + dataArray.add(rowData.getShort(fieldIndex)); + } else if (logicalType instanceof IntType) { + dataArray.add(rowData.getInt(fieldIndex)); + } else if (logicalType instanceof BigIntType) { + dataArray.add(rowData.getLong(fieldIndex)); + } else if (logicalType instanceof FloatType) { + dataArray.add(rowData.getFloat(fieldIndex)); + } else if (logicalType instanceof DoubleType) { + dataArray.add(rowData.getDouble(fieldIndex)); + } else if (logicalType instanceof VarCharType) { + dataArray.add(rowData.getString(fieldIndex).toString()); + } else if (logicalType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) logicalType; + ArrayData arrayData = rowData.getArray(fieldIndex); + serializeArrayToJsonArray(arrayData, arrayType.getElementType(), dataArray); + } else { + throw new IllegalArgumentException( + "Unsupported Flink type for serialization: " + logicalType); + } + } + + /** + * Serializes Flink ArrayData to JSON array (flattened). + * + * @param arrayData The array data + * @param elementType The element type + * @param targetArray The JSON array to add data to + */ + private static void serializeArrayToJsonArray( + ArrayData arrayData, LogicalType elementType, ArrayNode targetArray) { + int size = arrayData.size(); + for (int i = 0; i < size; i++) { + if (arrayData.isNullAt(i)) { + targetArray.addNull(); + continue; + } + + if (elementType instanceof BooleanType) { + targetArray.add(arrayData.getBoolean(i)); + } else if (elementType instanceof TinyIntType) { + targetArray.add(arrayData.getByte(i)); + } else if (elementType instanceof SmallIntType) { + targetArray.add(arrayData.getShort(i)); + } else if (elementType instanceof IntType) { + targetArray.add(arrayData.getInt(i)); + } else if (elementType instanceof BigIntType) { + targetArray.add(arrayData.getLong(i)); + } else if (elementType instanceof FloatType) { + targetArray.add(arrayData.getFloat(i)); + } else if (elementType instanceof DoubleType) { + targetArray.add(arrayData.getDouble(i)); + } else if (elementType instanceof VarCharType) { + targetArray.add(arrayData.getString(i).toString()); + } else { + throw new IllegalArgumentException( + "Unsupported array element type: " + elementType); + } + } + } + + /** + * Deserializes JSON data to Flink object based on logical type. + * + * @param dataNode The JSON node containing the data + * @param logicalType The target logical type + * @return The deserialized object + */ + public static Object deserializeFromJson(JsonNode dataNode, LogicalType logicalType) { + if (dataNode == null || dataNode.isNull()) { + return null; + } + + if (logicalType instanceof BooleanType) { + return dataNode.asBoolean(); + } else if (logicalType instanceof TinyIntType) { + return (byte) dataNode.asInt(); + } else if (logicalType instanceof SmallIntType) { + return (short) dataNode.asInt(); + } else if (logicalType instanceof IntType) { + return dataNode.asInt(); + } else if (logicalType instanceof BigIntType) { + return dataNode.asLong(); + } else if (logicalType instanceof FloatType) { + // Use floatValue() to properly handle the conversion + if (dataNode.isNumber()) { + return dataNode.floatValue(); + } else { + return (float) dataNode.asDouble(); + } + } else if (logicalType instanceof DoubleType) { + return dataNode.asDouble(); + } else if (logicalType instanceof VarCharType) { + return BinaryStringData.fromString(dataNode.asText()); + } else if (logicalType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) logicalType; + return deserializeArrayFromJson(dataNode, arrayType.getElementType()); + } else { + throw new IllegalArgumentException( + "Unsupported Flink type for deserialization: " + logicalType); + } + } + + /** + * Deserializes JSON array to Flink ArrayData. + * + * @param dataNode The JSON array node + * @param elementType The element type + * @return The deserialized ArrayData + */ + private static ArrayData deserializeArrayFromJson(JsonNode dataNode, LogicalType elementType) { + if (!dataNode.isArray()) { + throw new IllegalArgumentException( + "Expected JSON array but got: " + dataNode.getNodeType()); + } + + int size = dataNode.size(); + + // Handle different element types with appropriate array types + if (elementType instanceof BooleanType) { + boolean[] array = new boolean[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = element.asBoolean(); + } + return new GenericArrayData(array); + } else if (elementType instanceof TinyIntType) { + byte[] array = new byte[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = (byte) element.asInt(); + } + return new GenericArrayData(array); + } else if (elementType instanceof SmallIntType) { + short[] array = new short[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = (short) element.asInt(); + } + return new GenericArrayData(array); + } else if (elementType instanceof IntType) { + int[] array = new int[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = element.asInt(); + } + return new GenericArrayData(array); + } else if (elementType instanceof BigIntType) { + long[] array = new long[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = element.asLong(); + } + return new GenericArrayData(array); + } else if (elementType instanceof FloatType) { + float[] array = new float[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = element.isNumber() ? element.floatValue() : (float) element.asDouble(); + } + return new GenericArrayData(array); + } else if (elementType instanceof DoubleType) { + double[] array = new double[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = element.asDouble(); + } + return new GenericArrayData(array); + } else if (elementType instanceof VarCharType) { + BinaryStringData[] array = new BinaryStringData[size]; + int i = 0; + for (JsonNode element : dataNode) { + array[i++] = BinaryStringData.fromString(element.asText()); + } + return new GenericArrayData(array); + } else { + throw new IllegalArgumentException("Unsupported array element type: " + elementType); + } + } + + /** + * Calculates the shape dimensions for the input data. + * + * @param logicalType The logical type + * @param batchSize The batch size + * @return Array of dimensions + */ + public static int[] calculateShape(LogicalType logicalType, int batchSize) { + if (logicalType instanceof ArrayType) { + // For arrays, we need to know the array size at runtime + // Return shape with batch size and -1 for dynamic dimension + return new int[] {batchSize, -1}; + } else { + // For scalar types, shape is just the batch size + return new int[] {batchSize}; + } + } + + /** + * Calculates the shape dimensions for the input data based on actual row data. + * + * @param logicalType The logical type + * @param batchSize The batch size + * @param rowData The actual row data + * @param fieldIndex The field index in the row + * @return Array of dimensions + */ + public static int[] calculateShape( + LogicalType logicalType, int batchSize, RowData rowData, int fieldIndex) { + if (logicalType instanceof ArrayType) { + // For arrays, calculate actual size from the data + if (rowData.isNullAt(fieldIndex)) { + // Null array - return shape [batchSize, 0] + return new int[] {batchSize, 0}; + } + ArrayData arrayData = rowData.getArray(fieldIndex); + int arraySize = arrayData.size(); + return new int[] {batchSize, arraySize}; + } else { + // For scalar types, shape is just the batch size + return new int[] {batchSize}; + } + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java new file mode 100644 index 0000000000000..1ba81dd25fcf4 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.annotation.VisibleForTesting; + +import okhttp3.OkHttpClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Utility class for Triton Inference Server HTTP client management. + * + *

This class implements a reference-counted singleton pattern for OkHttpClient instances. + * Multiple function instances sharing the same timeout and retry configuration will reuse the same + * client, reducing resource consumption in high-parallelism scenarios. + * + *

Resource Management: + * + *

    + *
  • Clients are cached by (timeout, maxRetries) key + *
  • Reference count tracks active users + *
  • Client is closed when reference count reaches zero + *
  • Thread-safe via synchronized blocks + *
+ * + *

URL Construction: The {@link #buildInferenceUrl} method normalizes endpoint URLs to + * conform to Triton's REST API specification: {@code /v2/models/{name}/versions/{version}/infer} + */ +public class TritonUtils { + private static final Logger LOG = LoggerFactory.getLogger(TritonUtils.class); + + private static final Object LOCK = new Object(); + + private static final Map cache = new HashMap<>(); + + /** + * Creates or retrieves a cached HTTP client with the specified configuration. + * + *

This method implements reference-counted client pooling. Clients with identical timeout + * and retry settings are shared across multiple callers. + * + * @param timeoutMs Timeout in milliseconds for connect, read, and write operations + * @param maxRetries Maximum retry attempts (note: OkHttp retries are automatic for connection + * failures only, not for HTTP errors) + * @return A shared or new OkHttpClient instance + */ + public static OkHttpClient createHttpClient(long timeoutMs, int maxRetries) { + synchronized (LOCK) { + ClientKey key = new ClientKey(timeoutMs, maxRetries); + ClientValue value = cache.get(key); + if (value != null) { + LOG.debug("Returning an existing Triton HTTP client."); + value.referenceCount.incrementAndGet(); + return value.client; + } + + LOG.debug("Building a new Triton HTTP client."); + OkHttpClient client = + new OkHttpClient.Builder() + .connectTimeout(timeoutMs, TimeUnit.MILLISECONDS) + .readTimeout(timeoutMs, TimeUnit.MILLISECONDS) + .writeTimeout(timeoutMs, TimeUnit.MILLISECONDS) + .retryOnConnectionFailure(true) + .build(); + + cache.put(key, new ClientValue(client)); + return client; + } + } + + /** + * Releases a reference to an HTTP client. When the reference count reaches zero, the client is + * closed and removed from the cache. + * + * @param client The client to release + */ + public static void releaseHttpClient(OkHttpClient client) { + synchronized (LOCK) { + ClientKey keyToRemove = null; + ClientValue valueToRemove = null; + + for (Map.Entry entry : cache.entrySet()) { + if (entry.getValue().client == client) { + keyToRemove = entry.getKey(); + valueToRemove = entry.getValue(); + break; + } + } + + if (valueToRemove != null) { + int count = valueToRemove.referenceCount.decrementAndGet(); + if (count == 0) { + LOG.debug("Closing the Triton HTTP client."); + cache.remove(keyToRemove); + // OkHttpClient doesn't need explicit closing, but we can clean up resources + client.dispatcher().executorService().shutdown(); + client.connectionPool().evictAll(); + } + } + } + } + + /** + * Builds the inference URL for a specific model and version. + * + *

This method normalizes various endpoint formats to the standard Triton REST API path: + * + *

+     * Input: http://localhost:8000          → http://localhost:8000/v2/models/mymodel/versions/1/infer
+     * Input: http://localhost:8000/v2       → http://localhost:8000/v2/models/mymodel/versions/1/infer
+     * Input: http://localhost:8000/v2/models → http://localhost:8000/v2/models/mymodel/versions/1/infer
+     * 
+ * + * @param endpoint The base URL or partial URL of the Triton server + * @param modelName The name of the model + * @param modelVersion The version of the model (e.g., "1", "latest") + * @return The complete inference endpoint URL + */ + public static String buildInferenceUrl(String endpoint, String modelName, String modelVersion) { + String baseUrl = endpoint.replaceAll("/*$", ""); + if (!baseUrl.endsWith("/v2/models")) { + if (baseUrl.endsWith("/v2")) { + baseUrl += "/models"; + } else { + baseUrl += "/v2/models"; + } + } + return String.format("%s/%s/versions/%s/infer", baseUrl, modelName, modelVersion); + } + + /** Builds the model metadata URL for a specific model and version. */ + public static String buildModelMetadataUrl( + String endpoint, String modelName, String modelVersion) { + String baseUrl = endpoint.replaceAll("/*$", ""); + if (!baseUrl.endsWith("/v2/models")) { + if (baseUrl.endsWith("/v2")) { + baseUrl += "/models"; + } else { + baseUrl += "/v2/models"; + } + } + return String.format("%s/%s/versions/%s", baseUrl, modelName, modelVersion); + } + + private static class ClientValue { + private final OkHttpClient client; + private final AtomicInteger referenceCount; + + private ClientValue(OkHttpClient client) { + this.client = client; + this.referenceCount = new AtomicInteger(1); + } + } + + private static class ClientKey { + private final long timeoutMs; + private final int maxRetries; + + private ClientKey(long timeoutMs, int maxRetries) { + this.timeoutMs = timeoutMs; + this.maxRetries = maxRetries; + } + + @Override + public int hashCode() { + return Objects.hash(timeoutMs, maxRetries); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof ClientKey + && timeoutMs == ((ClientKey) obj).timeoutMs + && maxRetries == ((ClientKey) obj).maxRetries; + } + } + + @VisibleForTesting + static Map getCache() { + return cache; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java new file mode 100644 index 0000000000000..702c4b1e3d96a --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton.exception; + +/** + * Exception for client-side errors (HTTP 4xx status codes). + * + *

Indicates user configuration or input data issues that should be fixed by the user. These + * errors are NOT retryable as they require configuration changes. + * + *

Common Scenarios: + * + *

    + *
  • 400 Bad Request: Invalid input shape or data format + *
  • 404 Not Found: Model name or version doesn't exist + *
  • 401 Unauthorized: Invalid authentication token + *
+ */ +public class TritonClientException extends TritonException { + private static final long serialVersionUID = 1L; + + private final int httpStatus; + + /** + * Creates a new client exception. + * + * @param message The detailed error message + * @param httpStatus The HTTP status code (4xx) + */ + public TritonClientException(String message, int httpStatus) { + super(String.format("[HTTP %d] %s", httpStatus, message)); + this.httpStatus = httpStatus; + } + + /** + * Returns the HTTP status code. + * + * @return The HTTP status code (4xx) + */ + public int getHttpStatus() { + return httpStatus; + } + + @Override + public boolean isRetryable() { + // Client errors require configuration fixes, not retries + return false; + } + + @Override + public ErrorCategory getCategory() { + return ErrorCategory.CLIENT_ERROR; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java new file mode 100644 index 0000000000000..73870113b0170 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton.exception; + +/** + * Base exception for all Triton Inference Server integration errors. + * + *

This exception hierarchy provides typed error handling for different failure scenarios: + * + *

    + *
  • {@link TritonClientException}: HTTP 4xx errors (user configuration issues) + *
  • {@link TritonServerException}: HTTP 5xx errors (server-side issues) + *
  • {@link TritonNetworkException}: Network/connection failures + *
  • {@link TritonSchemaException}: Shape/type mismatch errors + *
+ */ +public class TritonException extends RuntimeException { + private static final long serialVersionUID = 1L; + + /** Error category for classification and monitoring. */ + public enum ErrorCategory { + /** Client-side errors (4xx): Bad configuration, invalid input, etc. */ + CLIENT_ERROR, + + /** Server-side errors (5xx): Inference failure, service unavailable, etc. */ + SERVER_ERROR, + + /** Network errors: Connection timeout, DNS failure, etc. */ + NETWORK_ERROR, + + /** Schema/type errors: Shape mismatch, incompatible types, etc. */ + SCHEMA_ERROR, + + /** Unknown or unclassified errors. */ + UNKNOWN + } + + /** + * Creates a new Triton exception with the specified message. + * + * @param message The detailed error message + */ + public TritonException(String message) { + super(message); + } + + /** + * Creates a new Triton exception with the specified message and cause. + * + * @param message The detailed error message + * @param cause The underlying cause + */ + public TritonException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Returns true if this error is retryable with exponential backoff. + * + *

Default implementation returns false. Subclasses should override if the error condition is + * transient (e.g., 503 Service Unavailable). + * + * @return true if the operation can be retried + */ + public boolean isRetryable() { + return false; + } + + /** + * Returns the error category for logging, monitoring, and alerting purposes. + * + *

This can be used to: + * + *

    + *
  • Route errors to appropriate handling logic + *
  • Aggregate metrics by error type + *
  • Configure different retry strategies + *
+ * + * @return The error category + */ + public ErrorCategory getCategory() { + return ErrorCategory.UNKNOWN; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java new file mode 100644 index 0000000000000..af2747afbdac0 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton.exception; + +/** + * Exception for network-level errors (connection failures, timeouts). + * + *

Indicates transient network issues such as DNS resolution failures, connection timeouts, or + * socket errors. These errors are typically retryable with exponential backoff. + * + *

Common Scenarios: + * + *

    + *
  • Connection refused: Server not reachable + *
  • Connection timeout: Network latency or firewall issues + *
  • DNS resolution failure: Hostname cannot be resolved + *
  • Socket timeout: Long-running request exceeded timeout + *
+ */ +public class TritonNetworkException extends TritonException { + private static final long serialVersionUID = 1L; + + /** + * Creates a new network exception. + * + * @param message The detailed error message + * @param cause The underlying IOException or network error + */ + public TritonNetworkException(String message, Throwable cause) { + super(message, cause); + } + + @Override + public boolean isRetryable() { + // Network errors are typically transient and retryable + return true; + } + + @Override + public ErrorCategory getCategory() { + return ErrorCategory.NETWORK_ERROR; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java new file mode 100644 index 0000000000000..dba00ec09adae --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton.exception; + +/** + * Exception for schema/shape/type mismatch errors. + * + *

Indicates that the Flink data type or tensor shape does not match what the Triton model + * expects. These errors are NOT retryable as they require schema or configuration fixes. + * + *

Common Scenarios: + * + *

    + *
  • Shape mismatch: Sent [1,224,224,3] but model expects [1,3,224,224] + *
  • Type mismatch: Sent FP32 but model expects INT8 + *
  • Dimension error: Sent scalar but model expects array + *
+ * + *

This exception includes detailed information about both expected and actual schemas to help + * users diagnose and fix the issue. + */ +public class TritonSchemaException extends TritonException { + private static final long serialVersionUID = 1L; + + private final String expectedSchema; + private final String actualSchema; + + /** + * Creates a new schema exception. + * + * @param message The detailed error message + * @param expectedSchema The schema/shape expected by Triton model + * @param actualSchema The schema/shape actually sent + */ + public TritonSchemaException(String message, String expectedSchema, String actualSchema) { + super( + String.format( + "%s\n=== Expected Schema ===\n%s\n=== Actual Schema ===\n%s", + message, expectedSchema, actualSchema)); + this.expectedSchema = expectedSchema; + this.actualSchema = actualSchema; + } + + /** + * Returns the schema/shape expected by the Triton model. + * + * @return The expected schema description + */ + public String getExpectedSchema() { + return expectedSchema; + } + + /** + * Returns the schema/shape that was actually sent. + * + * @return The actual schema description + */ + public String getActualSchema() { + return actualSchema; + } + + @Override + public boolean isRetryable() { + // Schema errors require configuration fixes, not retries + return false; + } + + @Override + public ErrorCategory getCategory() { + return ErrorCategory.SCHEMA_ERROR; + } +} diff --git a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java new file mode 100644 index 0000000000000..a2724b9cfbe1b --- /dev/null +++ b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton.exception; + +/** + * Exception for server-side errors (HTTP 5xx status codes). + * + *

Indicates Triton Inference Server issues such as inference crashes, out of memory, or service + * overload. Some server errors are retryable (e.g., 503 Service Unavailable, 504 Gateway Timeout). + * + *

Common Scenarios: + * + *

    + *
  • 500 Internal Server Error: Model inference crash (NOT retryable) + *
  • 503 Service Unavailable: Server overloaded (retryable) + *
  • 504 Gateway Timeout: Inference took too long (retryable) + *
+ */ +public class TritonServerException extends TritonException { + private static final long serialVersionUID = 1L; + + private final int httpStatus; + + /** + * Creates a new server exception. + * + * @param message The detailed error message + * @param httpStatus The HTTP status code (5xx) + */ + public TritonServerException(String message, int httpStatus) { + super(String.format("[HTTP %d] %s", httpStatus, message)); + this.httpStatus = httpStatus; + } + + /** + * Returns the HTTP status code. + * + * @return The HTTP status code (5xx) + */ + public int getHttpStatus() { + return httpStatus; + } + + @Override + public boolean isRetryable() { + // 503 Service Unavailable and 504 Gateway Timeout are retryable + // 500 Internal Server Error typically requires investigation + return httpStatus == 503 || httpStatus == 504; + } + + @Override + public ErrorCategory getCategory() { + return ErrorCategory.SERVER_ERROR; + } +} diff --git a/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE b/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE new file mode 100644 index 0000000000000..b5e5272bc9708 --- /dev/null +++ b/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE @@ -0,0 +1,28 @@ +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +================================================================================ + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +This project bundles the following dependencies under the Apache Software License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt) + +- com.squareup.okhttp3:okhttp +- com.fasterxml.jackson.core:jackson-core +- com.fasterxml.jackson.core:jackson-databind +- com.fasterxml.jackson.core:jackson-annotations diff --git a/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory b/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory new file mode 100644 index 0000000000000..abf2ba6d518bf --- /dev/null +++ b/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.model.triton.TritonModelProviderFactory diff --git a/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java new file mode 100644 index 0000000000000..a27e4effbaa6a --- /dev/null +++ b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** Test for {@link TritonModelProviderFactory}. */ +public class TritonModelProviderFactoryTest { + + @Test + public void testFactoryIdentifier() { + TritonModelProviderFactory factory = new TritonModelProviderFactory(); + assertEquals("triton", factory.factoryIdentifier()); + } + + @Test + public void testRequiredOptions() { + TritonModelProviderFactory factory = new TritonModelProviderFactory(); + assertEquals(3, factory.requiredOptions().size()); + assertTrue(factory.requiredOptions().contains(AbstractTritonModelFunction.ENDPOINT)); + assertTrue(factory.requiredOptions().contains(AbstractTritonModelFunction.MODEL_NAME)); + assertTrue(factory.requiredOptions().contains(AbstractTritonModelFunction.MODEL_VERSION)); + } + + @Test + public void testOptionalOptions() { + TritonModelProviderFactory factory = new TritonModelProviderFactory(); + assertEquals(11, factory.optionalOptions().size()); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.TIMEOUT)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.MAX_RETRIES)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.BATCH_SIZE)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.PRIORITY)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.SEQUENCE_ID)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.SEQUENCE_START)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.SEQUENCE_END)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.BINARY_DATA)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.COMPRESSION)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.AUTH_TOKEN)); + assertTrue(factory.optionalOptions().contains(AbstractTritonModelFunction.CUSTOM_HEADERS)); + } +} diff --git a/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java new file mode 100644 index 0000000000000..df0c033a4cbb6 --- /dev/null +++ b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.triton; + +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.binary.BinaryStringData; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.SmallIntType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** Test for {@link TritonTypeMapper}. */ +public class TritonTypeMapperTest { + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + @Test + public void testToTritonDataType() { + assertEquals(TritonDataType.BOOL, TritonTypeMapper.toTritonDataType(new BooleanType())); + assertEquals(TritonDataType.INT8, TritonTypeMapper.toTritonDataType(new TinyIntType())); + assertEquals(TritonDataType.INT16, TritonTypeMapper.toTritonDataType(new SmallIntType())); + assertEquals(TritonDataType.INT32, TritonTypeMapper.toTritonDataType(new IntType())); + assertEquals(TritonDataType.INT64, TritonTypeMapper.toTritonDataType(new BigIntType())); + assertEquals(TritonDataType.FP32, TritonTypeMapper.toTritonDataType(new FloatType())); + assertEquals(TritonDataType.FP64, TritonTypeMapper.toTritonDataType(new DoubleType())); + assertEquals( + TritonDataType.BYTES, + TritonTypeMapper.toTritonDataType(new VarCharType(VarCharType.MAX_LENGTH))); + } + + @Test + public void testToTritonDataTypeForArray() { + // For arrays, returns the element type's Triton type + assertEquals( + TritonDataType.FP32, + TritonTypeMapper.toTritonDataType(new ArrayType(new FloatType()))); + assertEquals( + TritonDataType.INT32, + TritonTypeMapper.toTritonDataType(new ArrayType(new IntType()))); + } + + @Test + public void testSerializeScalarTypes() { + // Test boolean + RowData boolRow = GenericRowData.of(true); + ArrayNode boolArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray(boolRow, 0, new BooleanType(), boolArray); + assertEquals(1, boolArray.size()); + assertEquals(true, boolArray.get(0).asBoolean()); + + // Test int + RowData intRow = GenericRowData.of(42); + ArrayNode intArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray(intRow, 0, new IntType(), intArray); + assertEquals(1, intArray.size()); + assertEquals(42, intArray.get(0).asInt()); + + // Test float + RowData floatRow = GenericRowData.of(3.14f); + ArrayNode floatArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray(floatRow, 0, new FloatType(), floatArray); + assertEquals(1, floatArray.size()); + assertEquals(3.14f, floatArray.get(0).floatValue(), 0.001f); + + // Test string + RowData stringRow = GenericRowData.of(BinaryStringData.fromString("hello")); + ArrayNode stringArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray( + stringRow, 0, new VarCharType(VarCharType.MAX_LENGTH), stringArray); + assertEquals(1, stringArray.size()); + assertEquals("hello", stringArray.get(0).asText()); + } + + @Test + public void testSerializeArrayType() { + Float[] floatArray = new Float[] {1.0f, 2.0f, 3.0f}; + ArrayData arrayData = new GenericArrayData(floatArray); + RowData rowData = GenericRowData.of(arrayData); + + ArrayNode jsonArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray( + rowData, 0, new ArrayType(new FloatType()), jsonArray); + + // Array should be flattened + assertEquals(3, jsonArray.size()); + assertEquals(1.0f, jsonArray.get(0).floatValue(), 0.001f); + assertEquals(2.0f, jsonArray.get(1).floatValue(), 0.001f); + assertEquals(3.0f, jsonArray.get(2).floatValue(), 0.001f); + } + + @Test + public void testCalculateShape() { + // Scalar type + int[] scalarShape = TritonTypeMapper.calculateShape(new IntType(), 1); + assertArrayEquals(new int[] {1}, scalarShape); + + // Array type + int[] arrayShape = TritonTypeMapper.calculateShape(new ArrayType(new FloatType()), 1); + assertArrayEquals(new int[] {1, -1}, arrayShape); + + // Batch size > 1 + int[] batchShape = TritonTypeMapper.calculateShape(new IntType(), 4); + assertArrayEquals(new int[] {4}, batchShape); + } + + @Test + public void testDeserializeScalarTypes() { + // Test int + assertEquals( + 42, + TritonTypeMapper.deserializeFromJson(objectMapper.valueToTree(42), new IntType())); + + // Test float + Object floatResult = + TritonTypeMapper.deserializeFromJson( + objectMapper.valueToTree(3.14f), new FloatType()); + assertEquals(3.14f, (Float) floatResult, 0.001f); + + // Test string + Object stringResult = + TritonTypeMapper.deserializeFromJson( + objectMapper.valueToTree("hello"), new VarCharType(VarCharType.MAX_LENGTH)); + assertEquals("hello", stringResult.toString()); + } + + @Test + public void testDeserializeArrayType() { + float[] floatArray = new float[] {1.0f, 2.0f, 3.0f}; + Object result = + TritonTypeMapper.deserializeFromJson( + objectMapper.valueToTree(floatArray), new ArrayType(new FloatType())); + + ArrayData arrayData = (ArrayData) result; + assertEquals(3, arrayData.size()); + assertEquals(1.0f, arrayData.getFloat(0), 0.001f); + assertEquals(2.0f, arrayData.getFloat(1), 0.001f); + assertEquals(3.0f, arrayData.getFloat(2), 0.001f); + } + + @Test + public void testSerializeNull() { + GenericRowData nullRow = new GenericRowData(1); + nullRow.setField(0, null); + + ArrayNode jsonArray = objectMapper.createArrayNode(); + TritonTypeMapper.serializeToJsonArray(nullRow, 0, new IntType(), jsonArray); + + assertEquals(1, jsonArray.size()); + assertEquals(true, jsonArray.get(0).isNull()); + } +} diff --git a/flink-models/pom.xml b/flink-models/pom.xml index e7ae8ce866bfd..e8b5c5316b9d3 100644 --- a/flink-models/pom.xml +++ b/flink-models/pom.xml @@ -35,6 +35,7 @@ under the License. flink-model-openai + flink-model-triton diff --git a/tools/ci/stage.sh b/tools/ci/stage.sh index d473056fb7f0b..9c1129bc0299b 100755 --- a/tools/ci/stage.sh +++ b/tools/ci/stage.sh @@ -121,6 +121,7 @@ flink-metrics/flink-metrics-slf4j,\ flink-metrics/flink-metrics-otel,\ flink-connectors/flink-connector-base,\ flink-models/flink-model-openai,\ +flink-models/flink-model-triton,\ " MODULES_TESTS="\