Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Total Tokens & Total Cost by Model was not being output correctly… #53

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions llm/src/main/java/com/instana/dc/llm/impl/llm/LLMDc.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void resetMetrics() {

public LLMDc(Map<String, Object> properties, CustomDcConfig cdcConfig) throws Exception {
super(properties, cdcConfig);
logLLMSpecificConfig(properties);
watsonxPricePromptTokens = (Double) properties.getOrDefault(WATSONX_PRICE_PROMPT_TOKES_PER_KILO, 0.0);
watsonxPriceCompleteTokens = (Double) properties.getOrDefault(WATSONX_PRICE_COMPLETE_TOKES_PER_KILO, 0.0);
openaiPricePromptTokens = (Double) properties.getOrDefault(OPENAI_PRICE_PROMPT_TOKES_PER_KILO, 0.0);
Expand All @@ -158,6 +159,19 @@ public LLMDc(Map<String, Object> properties, CustomDcConfig cdcConfig) throws Ex
listenPort = (int) properties.getOrDefault(SERVICE_LISTEN_PORT, 8000);
}

/**
 * Logs the LLM-specific pricing configuration at INFO level.
 *
 * <p>This method is purely diagnostic. Values are read as plain objects and never
 * cast, so an entry of an unexpected type (for example an Integer where a Double
 * is expected) is still logged instead of raising a ClassCastException.
 *
 * @param properties the properties map to read the pricing entries from
 */
private void logLLMSpecificConfig(Map<String, Object> properties) {
    logger.info("LLM Specific Configuration:");
    logger.info(" OPENAI_PRICE_PROMPT_TOKES_PER_KILO: " +
            properties.getOrDefault(OPENAI_PRICE_PROMPT_TOKES_PER_KILO, "Not set"));
    logger.info(" OPENAI_PRICE_COMPLETE_TOKES_PER_KILO: " +
            properties.getOrDefault(OPENAI_PRICE_COMPLETE_TOKES_PER_KILO, "Not set"));
    // Log other relevant configuration values here as well.
    // NOTE(review): this lookup uses a literal key rather than
    // OPENAI_PRICE_PROMPT_TOKES_PER_KILO - confirm whether the two are the same key.
    Object promptPrice = properties.getOrDefault("openai.price.prompt.tokens.per.kilo", 0.0);
    logger.info("Prompt price per kilo tokens: " + promptPrice);
}

@Override
public void initOnce() throws ClassNotFoundException {
var server = Server.builder()
Expand Down Expand Up @@ -194,61 +208,43 @@ public void registerMetrics() {
@Override
public void collectData() {
logger.info("Start to collect metrics");

for(Map.Entry<String,ModelAggregation> entry : modelAggrMap.entrySet()){
ModelAggregation aggr = entry.getValue();
aggr.resetMetrics();
}

List<OtelMetric> otelMetrics = metricsCollector.getMetrics();
logger.info("Number of metrics received: " + otelMetrics.size());
metricsCollector.clearMetrics();

for (OtelMetric metric : otelMetrics) {
try {
double duration = metric.getDuration();
if(duration == 0.0) {
continue;
}
String modelId = metric.getModelId();
String aiSystem = metric.getAiSystem();
long promptTokens = metric.getPromtTokens();
long completeTokens = metric.getCompleteTokens();
double duration = metric.getDuration();
long requestCount = metric.getReqCount();

ModelAggregation modelAggr = modelAggrMap.get(modelId);
if (modelAggr == null) {
modelAggr = new ModelAggregation(modelId, aiSystem);
modelAggrMap.put(modelId, modelAggr);
}
logger.info("Processing metric - Model: " + modelId + ", AI System: " + aiSystem +
", Prompt Tokens: " + promptTokens + ", Complete Tokens: " + completeTokens +
", Duration: " + duration + ", Request Count: " + requestCount);

ModelAggregation modelAggr = modelAggrMap.computeIfAbsent(modelId, k -> new ModelAggregation(modelId, aiSystem));
modelAggr.addDeltaDuration((long)(duration*1000), requestCount);
modelAggr.addDeltaReqCount(requestCount);
modelAggr.addDeltaPromptTokens(promptTokens, requestCount);
modelAggr.addDeltaCompleteTokens(completeTokens, requestCount);

logger.info("After aggregation - Delta Prompt Tokens: " + modelAggr.getDeltaPromptTokens() +
", Delta Complete Tokens: " + modelAggr.getDeltaCompleteTokens());
} catch (Exception e) {
logger.severe("Error processing metric: " + e.getMessage());
e.printStackTrace();
}
}
for (OtelMetric metric : otelMetrics) {
try {
String modelId = metric.getModelId();
String aiSystem = metric.getAiSystem();
long promptTokens = metric.getPromtTokens();
long completeTokens = metric.getCompleteTokens();
if(promptTokens == 0 && completeTokens == 0) {
continue;
}
ModelAggregation modelAggr = modelAggrMap.get(modelId);
if (modelAggr == null) {
modelAggr = new ModelAggregation(modelId, aiSystem);
modelAggrMap.put(modelId, modelAggr);
}
long currentReqCount = modelAggr.getCurrentReqCount();
if(promptTokens > 0) {
modelAggr.addDeltaPromptTokens(promptTokens, currentReqCount);
}
if(completeTokens > 0) {
modelAggr.addDeltaCompleteTokens(completeTokens, currentReqCount);
}
} catch (Exception e) {
e.printStackTrace();
}
}


logger.info("-----------------------------------------");
for(Map.Entry<String,ModelAggregation> entry : modelAggrMap.entrySet()){
ModelAggregation aggr = entry.getValue();
Expand All @@ -259,48 +255,38 @@ public void collectData() {
long deltaPromptTokens = aggr.getDeltaPromptTokens();
long deltaCompleteTokens = aggr.getDeltaCompleteTokens();
long maxDuration = aggr.getMaxDuration();

long avgDuration = deltaDuration/(deltaRequestCount==0?1:deltaRequestCount);
long avgDuration = deltaRequestCount == 0 ? 0 : deltaDuration / deltaRequestCount;
if(avgDuration > maxDuration) {
maxDuration = avgDuration;
aggr.setMaxDuration(maxDuration);
}

int intervalSeconds = LLM_POLL_INTERVAL;
String agentLess = System.getenv("AGENTLESS_MODE_ENABLED");
if (agentLess != null) {
intervalSeconds = 1;
}

double pricePromptTokens = 0.0;
double priceCompleteTokens = 0.0;
if (aiSystem.compareTo("watsonx") == 0) {
pricePromptTokens = watsonxPricePromptTokens;
priceCompleteTokens = watsonxPriceCompleteTokens;
} else if (aiSystem.compareTo("openai") == 0) {
pricePromptTokens = openaiPricePromptTokens;
priceCompleteTokens = openaiPriceCompleteTokens;
} else if (aiSystem.compareTo("anthropic") == 0) {
pricePromptTokens = anthropicPricePromptTokens;
priceCompleteTokens = anthropicPriceCompleteTokens;
}

double pricePromptTokens = getPricePromptTokens(aiSystem);
double priceCompleteTokens = getPriceCompleteTokens(aiSystem);

double intervalReqCount = (double)deltaRequestCount/intervalSeconds;
double intervalPromptTokens = (double)deltaPromptTokens/intervalSeconds;
double intervalCompleteTokens = (double)deltaCompleteTokens/intervalSeconds;
double intervalTotalTokens = intervalPromptTokens + intervalCompleteTokens;
double intervalPromptCost = (intervalPromptTokens/1000) * pricePromptTokens;
double intervalCompleteCost = (intervalCompleteTokens/1000) * priceCompleteTokens;
double intervalTotalCost = intervalPromptCost + intervalCompleteCost;
aggr.resetMetrics();


logger.info("ModelId : " + modelId);
logger.info("AiSystem : " + aiSystem);
logger.info("AvgDuration : " + avgDuration);
logger.info("MaxDuration : " + maxDuration);
logger.info("IntervalTokens : " + intervalTotalTokens);
logger.info("IntervalCost : " + intervalTotalCost);
logger.info("IntervalRequest : " + intervalReqCount);

Map<String, Object> attributes = new HashMap<>();
attributes.put("model_id", modelId);
attributes.put("ai_system", aiSystem);
Expand All @@ -310,7 +296,28 @@ public void collectData() {
getRawMetric(LLM_COST_NAME).getDataPoint(modelId).setValue(intervalTotalCost, attributes);
getRawMetric(LLM_TOKEN_NAME).getDataPoint(modelId).setValue(intervalTotalTokens, attributes);
getRawMetric(LLM_REQ_COUNT_NAME).getDataPoint(modelId).setValue(intervalReqCount, attributes);

aggr.resetMetrics();
}
logger.info("-----------------------------------------");
}


/**
 * Returns the configured prompt-token price for the given AI system.
 *
 * @param aiSystem AI system identifier; the recognised values are
 *                 "watsonx", "openai" and "anthropic"
 * @return the matching prompt-token price field, or 0.0 when
 *         {@code aiSystem} is null or not recognised
 */
private double getPricePromptTokens(String aiSystem) {
    // A switch on a null String throws NullPointerException; treat null
    // the same way as any other unknown system.
    if (aiSystem == null) {
        return 0.0;
    }
    switch (aiSystem) {
        case "watsonx":
            return watsonxPricePromptTokens;
        case "openai":
            return openaiPricePromptTokens;
        case "anthropic":
            return anthropicPricePromptTokens;
        default:
            return 0.0;
    }
}

/**
 * Returns the configured completion-token price for the given AI system.
 *
 * @param aiSystem AI system identifier; the recognised values are
 *                 "watsonx", "openai" and "anthropic"
 * @return the matching completion-token price field, or 0.0 when
 *         {@code aiSystem} is null or not recognised
 */
private double getPriceCompleteTokens(String aiSystem) {
    // A switch on a null String throws NullPointerException; treat null
    // the same way as any other unknown system.
    if (aiSystem == null) {
        return 0.0;
    }
    switch (aiSystem) {
        case "watsonx":
            return watsonxPriceCompleteTokens;
        case "openai":
            return openaiPriceCompleteTokens;
        case "anthropic":
            return anthropicPriceCompleteTokens;
        default:
            return 0.0;
    }
}
}
Loading