diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java index 7254e37f..7b70af0c 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.chat.model; import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; @@ -56,4 +57,22 @@ public ResourceType getResourceType() { */ public abstract ChatMessage chat( List messages, List tools, Map arguments); + + /** + * Record token usage metrics for the given model. + * + * @param modelName the name of the model used + * @param promptTokens the number of prompt tokens + * @param completionTokens the number of completion tokens + */ + protected void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { + FlinkAgentsMetricGroup metricGroup = getMetricGroup(); + if (metricGroup == null) { + return; + } + + FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName); + modelGroup.getCounter("promptTokens").inc(promptTokens); + modelGroup.getCounter("completionTokens").inc(completionTokens); + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index 9e16c071..15aa953e 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -60,6 +60,9 @@ public ChatMessage chat(List messages, Map paramete (BaseChatModelConnection) this.getResource.apply(this.connection, ResourceType.CHAT_MODEL_CONNECTION); + // Pass metric group to connection for token usage tracking + connection.setMetricGroup(getMetricGroup()); + // Format input messages if set prompt. if (this.prompt != null) { if (this.prompt instanceof String) { diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java index 0f22a5fd..78b15ad0 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java @@ -18,6 +18,8 @@ package org.apache.flink.agents.api.resource; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; + import java.util.function.BiFunction; /** @@ -28,6 +30,9 @@ public abstract class Resource { protected BiFunction getResource; + /** The metric group bound to this resource, injected by RunnerContext.getResource(). */ + private transient FlinkAgentsMetricGroup metricGroup; + protected Resource( ResourceDescriptor descriptor, BiFunction getResource) { this.getResource = getResource; @@ -41,4 +46,22 @@ protected Resource() {} * @return the resource type */ public abstract ResourceType getResourceType(); + + /** + * Set the metric group for this resource. + * + * @param metricGroup the metric group to bind + */ + public void setMetricGroup(FlinkAgentsMetricGroup metricGroup) { + this.metricGroup = metricGroup; + } + + /** + * Get the bound metric group. + * + * @return the bound metric group, or null if not set + */ + protected FlinkAgentsMetricGroup getMetricGroup() { + return metricGroup; + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java new file mode 100644 index 00000000..53c9bc6c --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.metrics.Counter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** Test cases for BaseChatModelConnection token metrics functionality. */ +class BaseChatModelConnectionTokenMetricsTest { + + private TestChatModelConnection connection; + private FlinkAgentsMetricGroup mockMetricGroup; + private FlinkAgentsMetricGroup mockModelGroup; + private Counter mockPromptTokensCounter; + private Counter mockCompletionTokensCounter; + + /** Test implementation of BaseChatModelConnection for testing purposes. */ + private static class TestChatModelConnection extends BaseChatModelConnection { + + public TestChatModelConnection( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + // Simple test implementation + return new ChatMessage(MessageRole.ASSISTANT, "Test response"); + } + + // Expose protected method for testing + public void testRecordTokenMetrics( + String modelName, long promptTokens, long completionTokens) { + recordTokenMetrics(modelName, promptTokens, completionTokens); + } + } + + @BeforeEach + void setUp() { + connection = + new TestChatModelConnection( + new ResourceDescriptor( + TestChatModelConnection.class.getName(), Collections.emptyMap()), + null); + + // Create mock objects + mockMetricGroup = mock(FlinkAgentsMetricGroup.class); + mockModelGroup = mock(FlinkAgentsMetricGroup.class); + mockPromptTokensCounter = mock(Counter.class); + mockCompletionTokensCounter = mock(Counter.class); + + // Set up mock behavior + when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup); + when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter); + when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter); + } + + @Test + @DisplayName("Test token metrics are recorded when metric group is set") + void testRecordTokenMetricsWithMetricGroup() { + // Set the metric group + connection.setMetricGroup(mockMetricGroup); + + // Record token metrics + connection.testRecordTokenMetrics("gpt-4", 100, 50); + + // Verify the metrics were recorded + verify(mockMetricGroup).getSubGroup("gpt-4"); + verify(mockModelGroup).getCounter("promptTokens"); + verify(mockModelGroup).getCounter("completionTokens"); + verify(mockPromptTokensCounter).inc(100); + verify(mockCompletionTokensCounter).inc(50); + } + + @Test + @DisplayName("Test token metrics are not recorded when metric group is null") + void testRecordTokenMetricsWithoutMetricGroup() { + // Do not set metric group (should be null by default) + + // Record token metrics - should not throw + assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 100, 50)); + + // No metrics should be recorded + verifyNoInteractions(mockMetricGroup); + } + + @Test + @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName -> counters") + void testTokenMetricsHierarchy() { + // Set the metric group + connection.setMetricGroup(mockMetricGroup); + + // Record token metrics for different models + FlinkAgentsMetricGroup mockGpt35Group = mock(FlinkAgentsMetricGroup.class); + Counter mockGpt35PromptCounter = mock(Counter.class); + Counter mockGpt35CompletionCounter = mock(Counter.class); + + when(mockMetricGroup.getSubGroup("gpt-3.5-turbo")).thenReturn(mockGpt35Group); + when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter); + when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter); + + // Record for gpt-4 + connection.testRecordTokenMetrics("gpt-4", 100, 50); + + // Record for gpt-3.5-turbo + connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100); + + // Verify each model has its own counters + verify(mockMetricGroup).getSubGroup("gpt-4"); + verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo"); + verify(mockPromptTokensCounter).inc(100); + verify(mockCompletionTokensCounter).inc(50); + verify(mockGpt35PromptCounter).inc(200); + verify(mockGpt35CompletionCounter).inc(100); + } + + @Test + @DisplayName("Test resource type is CHAT_MODEL_CONNECTION") + void testResourceType() { + assertEquals(ResourceType.CHAT_MODEL_CONNECTION, connection.getResourceType()); + } +} diff --git a/docs/content/docs/operations/monitoring.md b/docs/content/docs/operations/monitoring.md index db1c44a5..462a7ba1 100644 --- a/docs/content/docs/operations/monitoring.md +++ b/docs/content/docs/operations/monitoring.md @@ -26,7 +26,9 @@ under the License. ### Built-in Metrics -We offer data monitoring for built-in metrics, which includes events and actions. +We offer data monitoring for built-in metrics, which includes events, actions, and token usage. + +#### Event and Action Metrics | Scope | Metrics | Description | Type | |-------------|--------------------------------------------------|----------------------------------------------------------------------------------|-------| @@ -37,7 +39,14 @@ We offer data monitoring for built-in metrics, which includes events and actions | **Action** | .numOfActionsExecuted | The total number of actions this operator has executed for a specific action name. | Count | | **Action** | .numOfActionsExecutedPerSec | The number of actions this operator has executed per second for a specific action name. | Meter | -#### +#### Token Usage Metrics + +Token usage metrics are automatically recorded when chat models are invoked through `ChatModelConnection`. These metrics help track LLM API usage and costs. + +| Scope | Metrics | Description | Type | +|-----------|---------------------------------------------|--------------------------------------------------------------------------------|-------| +| **Model** | ..promptTokens | The total number of prompt tokens consumed by the model within an action. | Count | +| **Model** | ..completionTokens | The total number of completion tokens generated by the model within an action. | Count | ### How to add custom metrics diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java index ab30c97e..49fbef3e 100644 --- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java +++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java @@ -128,7 +128,22 @@ public ChatMessage chat( MessageCreateParams params = buildRequest(messages, tools, arguments); Message response = client.messages().create(params); - return convertResponse(response, jsonPrefillApplied); + ChatMessage result = convertResponse(response, jsonPrefillApplied); + + // Record token metrics + String modelName = null; + if (arguments != null && arguments.get("model") != null) { + modelName = arguments.get("model").toString(); + } + if (modelName == null || modelName.isBlank()) { + modelName = this.defaultModel; + } + if (modelName != null && !modelName.isBlank()) { + recordTokenMetrics( + modelName, response.usage().inputTokens(), response.usage().outputTokens()); + } + + return result; } catch (Exception e) { throw new RuntimeException("Failed to call Anthropic messages API.", e); } diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java index 182c4c12..f223ee55 100644 --- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java +++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java @@ -171,9 +171,10 @@ public ChatMessage chat( .map(this::convertToChatRequestMessage) .collect(Collectors.toList()); + final String modelName = (String) arguments.get("model"); ChatCompletionsOptions options = new ChatCompletionsOptions(chatMessages) - .setModel((String) arguments.get("model")) + .setModel(modelName) .setTools(azureTools); ChatCompletions completions = client.complete(options); @@ -188,6 +189,15 @@ public ChatMessage chat( chatMessage.setToolCalls(convertedToolCalls); } + // Record token metrics if model name is available + if (modelName != null && !modelName.isBlank()) { + CompletionsUsage usage = completions.getUsage(); + if (usage != null) { + recordTokenMetrics( + modelName, usage.getPromptTokens(), usage.getCompletionTokens()); + } + } + return chatMessage; } catch (Exception e) { throw new RuntimeException(e); diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 5faa3ed4..b591e0ea 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -189,10 +189,11 @@ public ChatMessage chat( .map(this::convertToOllamaChatMessages) .collect(Collectors.toList()); + final String modelName = (String) arguments.get("model"); final OllamaChatRequest chatRequest = OllamaChatRequest.builder() .withMessages(ollamaChatMessages) - .withModel((String) arguments.get("model")) + .withModel(modelName) .withThinking(extractReasoning ? ThinkMode.ENABLED : ThinkMode.DISABLED) .withUseTools(false) .build(); @@ -216,6 +217,16 @@ public ChatMessage chat( chatMessage.setToolCalls(toolCalls); } + // Record token metrics if model name is available + if (modelName != null && !modelName.isBlank()) { + Integer promptTokens = ollamaChatResponse.getPromptEvalCount(); + Integer completionTokens = ollamaChatResponse.getEvalCount(); + if (promptTokens != null && completionTokens != null) { + recordTokenMetrics( + modelName, promptTokens.longValue(), completionTokens.longValue()); + } + } + return chatMessage; } catch (Exception e) { throw new RuntimeException(e); diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java index ff15cf67..2675d424 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java @@ -142,7 +142,23 @@ public ChatMessage chat( try { ChatCompletionCreateParams params = buildRequest(messages, tools, arguments); ChatCompletion completion = client.chat().completions().create(params); - return convertResponse(completion); + ChatMessage response = convertResponse(completion); + + // Record token metrics + if (completion.usage().isPresent()) { + String modelName = arguments != null ? (String) arguments.get("model") : null; + if (modelName == null || modelName.isBlank()) { + modelName = this.defaultModel; + } + if (modelName != null && !modelName.isBlank()) { + recordTokenMetrics( + modelName, + completion.usage().get().promptTokens(), + completion.usage().get().completionTokens()); + } + } + + return response; } catch (Exception e) { throw new RuntimeException("Failed to call OpenAI chat completions API.", e); } diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index 9614edd3..fbde4786 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -92,6 +92,28 @@ def _extract_reasoning( cleaned = cleaned.strip() return cleaned, reasoning + def _record_token_metrics( + self, model_name: str, prompt_tokens: int, completion_tokens: int + ) -> None: + """Record token usage metrics for the given model. + + Parameters + ---------- + model_name : str + The name of the model used + prompt_tokens : int + The number of prompt tokens + completion_tokens : int + The number of completion tokens + """ + metric_group = self.metric_group + if metric_group is None: + return + + model_group = metric_group.get_sub_group(model_name) + model_group.get_counter("promptTokens").inc(prompt_tokens) + model_group.get_counter("completionTokens").inc(completion_tokens) + @abstractmethod def chat( self, @@ -173,6 +195,10 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: self.connection, ResourceType.CHAT_MODEL_CONNECTION ) + # Pass metric group to connection for token usage tracking + if self.metric_group is not None: + connection.set_metric_group(self.metric_group) + # Apply prompt template if self.prompt is not None: if isinstance(self.prompt, str): diff --git a/python/flink_agents/api/chat_models/tests/__init__.py b/python/flink_agents/api/chat_models/tests/__init__.py new file mode 100644 index 00000000..e154fadd --- /dev/null +++ b/python/flink_agents/api/chat_models/tests/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py b/python/flink_agents/api/chat_models/tests/test_token_metrics.py new file mode 100644 index 00000000..c51c1e5c --- /dev/null +++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py @@ -0,0 +1,182 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Test cases for BaseChatModelConnection token metrics functionality.""" +from typing import Any, List, Sequence +from unittest.mock import MagicMock + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import BaseChatModelConnection +from flink_agents.api.metric_group import Counter, MetricGroup +from flink_agents.api.resource import ResourceType +from flink_agents.api.tools.tool import Tool + + +class TestChatModelConnection(BaseChatModelConnection): + """Test implementation of BaseChatModelConnection for testing purposes.""" + + @classmethod + def resource_type(cls) -> ResourceType: + """Return resource type of class.""" + return ResourceType.CHAT_MODEL_CONNECTION + + def chat( + self, + messages: Sequence[ChatMessage], + tools: List[Tool] | None = None, + **kwargs: Any, + ) -> ChatMessage: + """Simple test implementation.""" + return ChatMessage(role=MessageRole.ASSISTANT, content="Test response") + + def test_record_token_metrics( + self, model_name: str, prompt_tokens: int, completion_tokens: int + ) -> None: + """Expose protected method for testing.""" + self._record_token_metrics(model_name, prompt_tokens, completion_tokens) + + +class _MockCounter(Counter): + """Mock implementation of Counter for testing.""" + + def __init__(self) -> None: + self._count = 0 + + def inc(self, n: int = 1) -> None: + self._count += n + + def dec(self, n: int = 1) -> None: + self._count -= n + + def get_count(self) -> int: + return self._count + + +class _MockMetricGroup(MetricGroup): + """Mock implementation of MetricGroup for testing.""" + + def __init__(self) -> None: + self._sub_groups: dict[str, _MockMetricGroup] = {} + self._counters: dict[str, _MockCounter] = {} + + def get_sub_group(self, name: str) -> "_MockMetricGroup": + if name not in self._sub_groups: + self._sub_groups[name] = _MockMetricGroup() + return self._sub_groups[name] + + def get_counter(self, name: str) -> _MockCounter: + if name not in self._counters: + self._counters[name] = _MockCounter() + return self._counters[name] + + def get_meter(self, name: str) -> Any: + return MagicMock() + + def get_gauge(self, name: str) -> Any: + return MagicMock() + + def get_histogram(self, name: str, window_size: int = 100) -> Any: + return MagicMock() + + +class TestBaseChatModelConnectionTokenMetrics: + """Test cases for BaseChatModelConnection token metrics functionality.""" + + def test_record_token_metrics_with_metric_group(self) -> None: + """Test token metrics are recorded when metric group is set.""" + connection = TestChatModelConnection() + mock_metric_group = _MockMetricGroup() + + # Set the metric group + connection.set_metric_group(mock_metric_group) + + # Record token metrics + connection.test_record_token_metrics("gpt-4", 100, 50) + + # Verify the metrics were recorded + model_group = mock_metric_group.get_sub_group("gpt-4") + assert model_group.get_counter("promptTokens").get_count() == 100 + assert model_group.get_counter("completionTokens").get_count() == 50 + + def test_record_token_metrics_without_metric_group(self) -> None: + """Test token metrics are not recorded when metric group is null.""" + connection = TestChatModelConnection() + + # Do not set metric group (should be None by default) + # Record token metrics - should not throw + connection.test_record_token_metrics("gpt-4", 100, 50) + # No exception should be raised + + def test_token_metrics_hierarchy(self) -> None: + """Test token metrics hierarchy: actionMetricGroup -> modelName -> counters.""" + connection = TestChatModelConnection() + mock_metric_group = _MockMetricGroup() + + # Set the metric group + connection.set_metric_group(mock_metric_group) + + # Record for gpt-4 + connection.test_record_token_metrics("gpt-4", 100, 50) + + # Record for gpt-3.5-turbo + connection.test_record_token_metrics("gpt-3.5-turbo", 200, 100) + + # Verify each model has its own counters + gpt4_group = mock_metric_group.get_sub_group("gpt-4") + gpt35_group = mock_metric_group.get_sub_group("gpt-3.5-turbo") + + assert gpt4_group.get_counter("promptTokens").get_count() == 100 + assert gpt4_group.get_counter("completionTokens").get_count() == 50 + assert gpt35_group.get_counter("promptTokens").get_count() == 200 + assert gpt35_group.get_counter("completionTokens").get_count() == 100 + + def test_token_metrics_accumulation(self) -> None: + """Test that token metrics accumulate across multiple calls.""" + connection = TestChatModelConnection() + mock_metric_group = _MockMetricGroup() + + # Set the metric group + connection.set_metric_group(mock_metric_group) + + # Record multiple times for the same model + connection.test_record_token_metrics("gpt-4", 100, 50) + connection.test_record_token_metrics("gpt-4", 150, 75) + + # Verify the metrics accumulated + model_group = mock_metric_group.get_sub_group("gpt-4") + assert model_group.get_counter("promptTokens").get_count() == 250 + assert model_group.get_counter("completionTokens").get_count() == 125 + + def test_resource_type(self) -> None: + """Test resource type is CHAT_MODEL_CONNECTION.""" + connection = TestChatModelConnection() + assert connection.resource_type() == ResourceType.CHAT_MODEL_CONNECTION + + def test_bound_metric_group_property(self) -> None: + """Test bound_metric_group property.""" + connection = TestChatModelConnection() + + # Initially should be None + assert connection.metric_group is None + + # Set metric group + mock_metric_group = _MockMetricGroup() + connection.set_metric_group(mock_metric_group) + + # Now should return the set metric group + assert connection.metric_group is mock_metric_group + diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index 056a850a..090690e0 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -18,9 +18,12 @@ import importlib from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, Dict, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, Type -from pydantic import BaseModel, Field, model_serializer, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_serializer, model_validator + +if TYPE_CHECKING: + from flink_agents.api.metric_group import MetricGroup class ResourceType(Enum): @@ -58,11 +61,35 @@ class Resource(BaseModel, ABC): exclude=True, default=None ) + # The metric group bound to this resource, injected in RunnerContext#get_resource + _metric_group: "MetricGroup | None" = PrivateAttr(default=None) + @classmethod @abstractmethod def resource_type(cls) -> ResourceType: """Return resource type of class.""" + def set_metric_group(self, metric_group: "MetricGroup") -> None: + """Set the metric group for this resource. + + Parameters + ---------- + metric_group : MetricGroup + The metric group to bind. + """ + self._metric_group = metric_group + + @property + def metric_group(self) -> "MetricGroup | None": + """Get the bound metric group. + + Returns: + ------- + MetricGroup | None + The bound metric group, or None if not set. + """ + return self._metric_group + class SerializableResource(Resource, ABC): """Resource which is serializable.""" diff --git a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py index bef1043b..297fb5e7 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py @@ -167,6 +167,15 @@ def chat(self, messages: Sequence[ChatMessage], tools: List[Tool] | None = None, **kwargs, ) + # Record token metrics if model name and usage are available + model_name = kwargs.get("model") + if model_name and message.usage: + self._record_token_metrics( + model_name, + message.usage.input_tokens, + message.usage.output_tokens, + ) + if message.stop_reason == "tool_use": tool_calls = [ { diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py index 8549ff38..3eeff8ee 100644 --- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py @@ -95,8 +95,9 @@ def chat( if tools is not None: ollama_tools = [to_openai_tool(metadata=tool.metadata) for tool in tools] + model_name = kwargs.pop("model") response = self.client.chat( - model=kwargs.pop("model"), + model=model_name, messages=ollama_messages, stream=False, tools=ollama_tools, @@ -128,6 +129,16 @@ def chat( if reasoning: extra_args["reasoning"] = reasoning + # Record token metrics if model name and usage are available + if ( + model_name + and response.prompt_eval_count is not None + and response.eval_count is not None + ): + self._record_token_metrics( + model_name, response.prompt_eval_count, response.eval_count + ) + return ChatMessage( role=MessageRole(response.message.role), content=content, diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py index f12dd39e..a00bbaea 100644 --- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py @@ -171,9 +171,18 @@ def chat( **kwargs, ) - response = response.choices[0].message + # Record token metrics if model name and usage are available + model_name = kwargs.get("model") + if model_name and response.usage: + self._record_token_metrics( + model_name, + response.usage.prompt_tokens, + response.usage.completion_tokens, + ) - return convert_from_openai_message(response) + message = response.choices[0].message + + return convert_from_openai_message(message) DEFAULT_TEMPERATURE = 0.1 diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py index 3c3f13d6..3278f1ed 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py @@ -125,6 +125,7 @@ def test_tongyi_chat_with_extract_reasoning(monkeypatch: pytest.MonkeyPatch) -> } ] }, + usage=SimpleNamespace(input_tokens=100, output_tokens=50), ) mock_call = MagicMock(return_value=mocked_response) diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py index 0148603c..164eb8ad 100644 --- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py @@ -113,8 +113,9 @@ def chat( req_api_key = kwargs.pop("api_key", self.api_key) + model_name = kwargs.pop("model", DEFAULT_MODEL) response = Generation.call( - model=kwargs.pop("model", DEFAULT_MODEL), + model=model_name, messages=tongyi_messages, tools=tongyi_tools, result_format="message", @@ -123,10 +124,16 @@ def chat( **kwargs, ) - if getattr(response, "status_code", 200) != 200: - msg = f"DashScope call failed: {getattr(response, 'message', 'unknown error')}" + if response.status_code != 200: + msg = f"DashScope call failed: {response.message}" raise RuntimeError(msg) + # Record token metrics if model name and usage are available + if model_name and response.usage: + self._record_token_metrics( + model_name, response.usage.input_tokens, response.usage.output_tokens + ) + choice = response.output["choices"][0] response_message: Dict[str, Any] = choice["message"] diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 43190743..44425aee 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -98,7 +98,10 @@ def send_event(self, event: Event) -> None: @override def get_resource(self, name: str, type: ResourceType) -> Resource: - return self.__agent_plan.get_resource(name, type) + resource = self.__agent_plan.get_resource(name, type) + # Bind current action's metric group to the resource + resource.set_metric_group(self.action_metric_group) + return resource @property @override diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 6321d985..7e49ff41 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -154,7 +154,10 @@ public Resource getResource(String name, ResourceType type) throws Exception { if (agentPlan == null) { throw new IllegalStateException("AgentPlan is not available in this context"); } - return agentPlan.getResource(name, type); + Resource resource = agentPlan.getResource(name, type); + // Set current action's metric group to the resource + resource.setMetricGroup(getActionMetricGroup()); + return resource; } @Override