Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.agents.api.chat.model;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
Expand Down Expand Up @@ -56,4 +57,22 @@ public ResourceType getResourceType() {
*/
public abstract ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments);

/**
 * Record token usage metrics for the given model under this resource's metric group.
 *
 * <p>Counters live in a per-model sub-group, so repeated calls for the same model
 * accumulate into the same counters. Recording is best-effort: when no metric group
 * has been injected, or the model name is unusable, the call is a silent no-op
 * rather than failing the surrounding chat call.
 *
 * @param modelName the name of the model used; ignored if null or blank
 * @param promptTokens the number of prompt tokens consumed
 * @param completionTokens the number of completion tokens generated
 */
protected void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) {
    FlinkAgentsMetricGroup metricGroup = getMetricGroup();
    // Guard both inputs: getSubGroup(null) would otherwise throw from a metrics path.
    if (metricGroup == null || modelName == null || modelName.isBlank()) {
        return;
    }

    FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
    modelGroup.getCounter("promptTokens").inc(promptTokens);
    modelGroup.getCounter("completionTokens").inc(completionTokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
(BaseChatModelConnection)
this.getResource.apply(this.connection, ResourceType.CHAT_MODEL_CONNECTION);

// Pass metric group to connection for token usage tracking
connection.setMetricGroup(getMetricGroup());

// Format input messages if set prompt.
if (this.prompt != null) {
if (this.prompt instanceof String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.agents.api.resource;

import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;

import java.util.function.BiFunction;

/**
Expand All @@ -28,6 +30,9 @@
public abstract class Resource {
protected BiFunction<String, ResourceType, Resource> getResource;

/** The metric group bound to this resource, injected by RunnerContext.getResource(). */
private transient FlinkAgentsMetricGroup metricGroup;

protected Resource(
ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) {
this.getResource = getResource;
Expand All @@ -41,4 +46,22 @@ protected Resource() {}
* @return the resource type
*/
public abstract ResourceType getResourceType();

/**
 * Bind a metric group to this resource.
 *
 * <p>Intended to be called by the runtime (see {@code RunnerContext.getResource()})
 * so that subclasses can report metrics scoped to the executing action.
 *
 * @param metricGroup the metric group to bind; a later call replaces any previous binding
 */
public void setMetricGroup(FlinkAgentsMetricGroup metricGroup) {
    this.metricGroup = metricGroup;
}

/**
 * Return the metric group currently bound to this resource.
 *
 * @return the bound metric group, or {@code null} when none has been injected yet
 */
protected FlinkAgentsMetricGroup getMetricGroup() {
    return metricGroup;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.agents.api.chat.model;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.metrics.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

/** Unit tests for token usage metrics recording in {@link BaseChatModelConnection}. */
class BaseChatModelConnectionTokenMetricsTest {

    /** Minimal concrete subclass used to exercise the protected metrics hook. */
    private static class TestChatModelConnection extends BaseChatModelConnection {

        public TestChatModelConnection(
                ResourceDescriptor descriptor,
                BiFunction<String, ResourceType, Resource> getResource) {
            super(descriptor, getResource);
        }

        @Override
        public ChatMessage chat(
                List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
            // Canned reply; the chat path itself is not under test here.
            return new ChatMessage(MessageRole.ASSISTANT, "Test response");
        }

        /** Visibility bridge so tests can invoke the protected {@code recordTokenMetrics}. */
        public void testRecordTokenMetrics(
                String modelName, long promptTokens, long completionTokens) {
            recordTokenMetrics(modelName, promptTokens, completionTokens);
        }
    }

    private TestChatModelConnection connection;
    private FlinkAgentsMetricGroup rootGroup;
    private FlinkAgentsMetricGroup gpt4Group;
    private Counter gpt4PromptCounter;
    private Counter gpt4CompletionCounter;

    @BeforeEach
    void setUp() {
        // Mock the metric hierarchy first: root group -> "gpt-4" sub-group -> counters.
        rootGroup = mock(FlinkAgentsMetricGroup.class);
        gpt4Group = mock(FlinkAgentsMetricGroup.class);
        gpt4PromptCounter = mock(Counter.class);
        gpt4CompletionCounter = mock(Counter.class);

        when(rootGroup.getSubGroup("gpt-4")).thenReturn(gpt4Group);
        when(gpt4Group.getCounter("promptTokens")).thenReturn(gpt4PromptCounter);
        when(gpt4Group.getCounter("completionTokens")).thenReturn(gpt4CompletionCounter);

        connection =
                new TestChatModelConnection(
                        new ResourceDescriptor(
                                TestChatModelConnection.class.getName(), Collections.emptyMap()),
                        null);
    }

    @Test
    @DisplayName("Test token metrics are recorded when metric group is set")
    void testRecordTokenMetricsWithMetricGroup() {
        connection.setMetricGroup(rootGroup);

        connection.testRecordTokenMetrics("gpt-4", 100, 50);

        // Counters must be resolved via the per-model sub-group and incremented once each.
        verify(rootGroup).getSubGroup("gpt-4");
        verify(gpt4Group).getCounter("promptTokens");
        verify(gpt4Group).getCounter("completionTokens");
        verify(gpt4PromptCounter).inc(100);
        verify(gpt4CompletionCounter).inc(50);
    }

    @Test
    @DisplayName("Test token metrics are not recorded when metric group is null")
    void testRecordTokenMetricsWithoutMetricGroup() {
        // No metric group was bound, so recording must be a silent no-op.
        assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 100, 50));

        verifyNoInteractions(rootGroup);
    }

    @Test
    @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName -> counters")
    void testTokenMetricsHierarchy() {
        connection.setMetricGroup(rootGroup);

        // Wire up a second model so we can observe per-model isolation.
        FlinkAgentsMetricGroup gpt35Group = mock(FlinkAgentsMetricGroup.class);
        Counter gpt35PromptCounter = mock(Counter.class);
        Counter gpt35CompletionCounter = mock(Counter.class);
        when(rootGroup.getSubGroup("gpt-3.5-turbo")).thenReturn(gpt35Group);
        when(gpt35Group.getCounter("promptTokens")).thenReturn(gpt35PromptCounter);
        when(gpt35Group.getCounter("completionTokens")).thenReturn(gpt35CompletionCounter);

        connection.testRecordTokenMetrics("gpt-4", 100, 50);
        connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);

        // Each model gets its own sub-group and its own counter instances.
        verify(rootGroup).getSubGroup("gpt-4");
        verify(rootGroup).getSubGroup("gpt-3.5-turbo");
        verify(gpt4PromptCounter).inc(100);
        verify(gpt4CompletionCounter).inc(50);
        verify(gpt35PromptCounter).inc(200);
        verify(gpt35CompletionCounter).inc(100);
    }

    @Test
    @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
    void testResourceType() {
        assertEquals(ResourceType.CHAT_MODEL_CONNECTION, connection.getResourceType());
    }
}
13 changes: 11 additions & 2 deletions docs/content/docs/operations/monitoring.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ under the License.

### Built-in Metrics

We offer data monitoring for built-in metrics, which includes events and actions.
We offer data monitoring for built-in metrics, which includes events, actions, and token usage.

#### Event and Action Metrics

| Scope | Metrics | Description | Type |
|-------------|--------------------------------------------------|----------------------------------------------------------------------------------|-------|
Expand All @@ -37,7 +39,14 @@ We offer data monitoring for built-in metrics, which includes events and actions
| **Action** | <action_name>.numOfActionsExecuted | The total number of actions this operator has executed for a specific action name. | Count |
| **Action** | <action_name>.numOfActionsExecutedPerSec | The number of actions this operator has executed per second for a specific action name. | Meter |

####
#### Token Usage Metrics

Token usage metrics are automatically recorded when chat models are invoked through a `BaseChatModelConnection` implementation. These metrics help track LLM API usage and costs.

| Scope | Metrics | Description | Type |
|-----------|---------------------------------------------|--------------------------------------------------------------------------------|-------|
| **Model** | <action_name>.<model_name>.promptTokens | The total number of prompt tokens consumed by the model within an action. | Count |
| **Model** | <action_name>.<model_name>.completionTokens | The total number of completion tokens generated by the model within an action. | Count |

### How to add custom metrics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,22 @@ public ChatMessage chat(

MessageCreateParams params = buildRequest(messages, tools, arguments);
Message response = client.messages().create(params);
return convertResponse(response, jsonPrefillApplied);
ChatMessage result = convertResponse(response, jsonPrefillApplied);

// Record token metrics
String modelName = null;
if (arguments != null && arguments.get("model") != null) {
modelName = arguments.get("model").toString();
}
if (modelName == null || modelName.isBlank()) {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
recordTokenMetrics(
modelName, response.usage().inputTokens(), response.usage().outputTokens());
}

return result;
} catch (Exception e) {
throw new RuntimeException("Failed to call Anthropic messages API.", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,10 @@ public ChatMessage chat(
.map(this::convertToChatRequestMessage)
.collect(Collectors.toList());

final String modelName = (String) arguments.get("model");
ChatCompletionsOptions options =
new ChatCompletionsOptions(chatMessages)
.setModel((String) arguments.get("model"))
.setModel(modelName)
.setTools(azureTools);

ChatCompletions completions = client.complete(options);
Expand All @@ -188,6 +189,15 @@ public ChatMessage chat(
chatMessage.setToolCalls(convertedToolCalls);
}

// Record token metrics if model name is available
if (modelName != null && !modelName.isBlank()) {
CompletionsUsage usage = completions.getUsage();
if (usage != null) {
recordTokenMetrics(
modelName, usage.getPromptTokens(), usage.getCompletionTokens());
}
}

return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,11 @@ public ChatMessage chat(
.map(this::convertToOllamaChatMessages)
.collect(Collectors.toList());

final String modelName = (String) arguments.get("model");
final OllamaChatRequest chatRequest =
OllamaChatRequest.builder()
.withMessages(ollamaChatMessages)
.withModel((String) arguments.get("model"))
.withModel(modelName)
.withThinking(extractReasoning ? ThinkMode.ENABLED : ThinkMode.DISABLED)
.withUseTools(false)
.build();
Expand All @@ -216,6 +217,16 @@ public ChatMessage chat(
chatMessage.setToolCalls(toolCalls);
}

// Record token metrics if model name is available
if (modelName != null && !modelName.isBlank()) {
Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
Integer completionTokens = ollamaChatResponse.getEvalCount();
if (promptTokens != null && completionTokens != null) {
recordTokenMetrics(
modelName, promptTokens.longValue(), completionTokens.longValue());
}
}

return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,23 @@ public ChatMessage chat(
try {
ChatCompletionCreateParams params = buildRequest(messages, tools, arguments);
ChatCompletion completion = client.chat().completions().create(params);
return convertResponse(completion);
ChatMessage response = convertResponse(completion);

// Record token metrics
if (completion.usage().isPresent()) {
String modelName = arguments != null ? (String) arguments.get("model") : null;
if (modelName == null || modelName.isBlank()) {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
recordTokenMetrics(
modelName,
completion.usage().get().promptTokens(),
completion.usage().get().completionTokens());
}
}

return response;
} catch (Exception e) {
throw new RuntimeException("Failed to call OpenAI chat completions API.", e);
}
Expand Down
Loading