.Net: [WIP] Gemini - new Function Calling #9597

Draft · wants to merge 4 commits into base: main
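This draft moves the Gemini connector onto the shared FunctionCallsProcessor, so that function calling is driven by FunctionChoiceBehavior rather than the connector-specific ToolCallBehavior auto-invoke loop. A minimal sketch of the intended caller-facing usage, assuming the standard Semantic Kernel abstractions (the model id and the Weather plugin below are illustrative, not part of this PR):

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Google;

var kernel = Kernel.CreateBuilder()
    .AddGoogleAIGeminiChatCompletion(modelId: "gemini-1.5-pro", apiKey: "<api-key>")
    .Build();

// Hypothetical plugin whose function the model may choose to call.
kernel.ImportPluginFromFunctions("Weather",
[
    KernelFunctionFactory.CreateFromMethod(
        (string city) => $"20 C and sunny in {city}",
        functionName: "GetWeather",
        description: "Gets the current weather for a city.")
]);

// Auto() advertises the kernel's functions to the model and lets the connector
// auto-invoke the function calls it returns, feeding results back into the chat.
var settings = new GeminiPromptExecutionSettings
{
    FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
};

var result = await kernel.InvokePromptAsync(
    "What is the weather in Paris?",
    new KernelArguments(settings));

Console.WriteLine(result);

FunctionChoiceBehavior.Required() and FunctionChoiceBehavior.None() map to Gemini's Any and None calling modes respectively; see ConfigureFunctionCalling in the diff below.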
@@ -12,6 +12,7 @@
<!-- IMPORT NUGET PACKAGE SHARED PROPERTIES -->
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<Import Project="$(RepoRoot)/dotnet/src/InternalUtilities/src/InternalUtilities.props" />
<Import Project="$(RepoRoot)/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallingUtilities.props"/>

<PropertyGroup>
<!-- NuGet Package Settings -->
@@ -12,6 +12,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.FunctionCalling;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.Text;
@@ -28,30 +29,10 @@ internal sealed class GeminiChatCompletionClient : ClientBase
private readonly string _modelId;
private readonly Uri _chatGenerationEndpoint;
private readonly Uri _chatStreamingEndpoint;
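/// <summary>
/// Shared processor (from Microsoft.SemanticKernel.Connectors.FunctionCalling) that resolves the
/// <see cref="FunctionChoiceBehavior"/> configuration and auto-invokes returned function calls.
/// </summary>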
private readonly FunctionCallsProcessor _functionCallsProcessor;

private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;

/// <summary>
/// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
/// asynchronous chain of execution.
/// </summary>
/// <remarks>
/// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that
/// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself,
/// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close
/// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution.
/// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in
/// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that
/// prompt function could advertise itself as a candidate for auto-invocation. We don't want to outright block that,
/// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent
/// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made
/// configurable should need arise.
/// </remarks>
private const int MaxInflightAutoInvokes = 128;

/// <summary>Tracking <see cref="AsyncLocal{Int32}"/> for <see cref="MaxInflightAutoInvokes"/>.</summary>
private static readonly AsyncLocal<int> s_inflightAutoInvokes = new();

/// <summary>
/// Instance of <see cref="Meter"/> for metrics.
/// </summary>
@@ -84,6 +65,12 @@ internal sealed class GeminiChatCompletionClient : ClientBase
unit: "{token}",
description: "Number of tokens used");

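/// <summary>
/// Per-request tool-calling configuration: the Gemini function declarations to advertise,
/// the function-calling mode, whether returned calls are auto-invoked, and any behavior options.
/// </summary>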
private sealed record ToolCallingConfig(
IList<GeminiTool.FunctionDeclaration>? Tools,
GeminiFunctionCallingMode? Mode,
bool AutoInvoke,
FunctionChoiceBehaviorOptions? Options);

/// <summary>
/// Represents a client for interacting with the chat completion Gemini model via GoogleAI.
/// </summary>
@@ -108,6 +95,7 @@ public GeminiChatCompletionClient(
string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
this._chatGenerationEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:generateContent?key={apiKey}");
this._chatStreamingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:streamGenerateContent?key={apiKey}&alt=sse");
}
@@ -142,6 +130,7 @@ public GeminiChatCompletionClient(
string versionSubLink = GetApiVersionSubLink(apiVersion);

this._modelId = modelId;
this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
this._chatGenerationEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:generateContent");
this._chatStreamingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:streamGenerateContent?alt=sse");
}
@@ -162,11 +151,16 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
{
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

- for (state.Iteration = 1; ; state.Iteration++)
+ for (state.RequestIndex = 0;; state.RequestIndex++)
{
// TODO: use this configuration when building the request below
var functionCallingConfig = this.GetFunctionCallingConfiguration(state);

// TODO: the request should be created here, inside the loop, rather than above it

List<GeminiChatMessageContent> chatResponses;
using (var activity = ModelDiagnostics.StartCompletionActivity(
this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
GeminiResponse geminiResponse;
try
@@ -190,22 +184,38 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(

// If we don't want to attempt to invoke any functions, just return the result.
// Or if the response contains no candidates to process, similarly bail.
- if (!state.AutoInvoke || chatResponses.Count != 1)
+ if (!state.AutoInvoke || chatResponses.Count == 0)
{
return chatResponses;
}

state.LastMessage = chatResponses[0];
// TODO: should the ToolCalls property be removed from GeminiChatMessageContent?
if (state.LastMessage.ToolCalls is null)
{
return chatResponses;
}

// ToolCallBehavior is not null because we are in auto-invoke mode, but we check again in case it was changed in the meantime
Verify.NotNull(state.ExecutionSettings.ToolCallBehavior);

state.AddLastMessageToChatHistoryAndRequest();
await this.ProcessFunctionsAsync(state, cancellationToken).ConfigureAwait(false);
// TODO: remove, or should the last message still be added to the chat history and request?
// state.AddLastMessageToChatHistoryAndRequest();

// Process function calls by invoking the functions and adding the results to the chat history.
// Each function call will trigger auto-function-invocation filters, which can terminate the process.
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this._functionCallsProcessor.ProcessFunctionCallsAsync(
state.LastMessage,
chatHistory,
state.RequestIndex,
content => IsRequestableTool(state.LastMessage.ToolCalls, content),
functionCallingConfig.Options ?? new FunctionChoiceBehaviorOptions(),
kernel,
isStreaming: false,
cancellationToken).ConfigureAwait(false);

if (lastMessage != null)
{
return [lastMessage];
}
}
}

@@ -225,10 +235,10 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
{
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

- for (state.Iteration = 1; ; state.Iteration++)
+ for (state.RequestIndex = 1;; state.RequestIndex++)
{
using (var activity = ModelDiagnostics.StartCompletionActivity(
this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
@@ -292,6 +302,79 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
}
}

private ToolCallingConfig GetFunctionCallingConfiguration(ChatCompletionState state)
{
// If no function choice behavior is specified, return the default configuration: no tools, no mode, no auto-invoke.
if (state.ExecutionSettings.FunctionChoiceBehavior is null)
{
return new ToolCallingConfig(Tools: null, Mode: null, AutoInvoke: false, Options: null);
}

return this.ConfigureFunctionCalling(state);
}

private ToolCallingConfig ConfigureFunctionCalling(ChatCompletionState state)
{
var config =
this._functionCallsProcessor.GetConfiguration(state.ExecutionSettings.FunctionChoiceBehavior, state.ChatHistory, state.RequestIndex, state.Kernel);

IList<GeminiTool.FunctionDeclaration>? tools = null;
GeminiFunctionCallingMode? toolMode = null;
bool autoInvoke = config?.AutoInvoke ?? false;

if (config?.Functions is { Count: > 0 } functions)
{
if (config.Choice == FunctionChoice.Auto)
{
toolMode = GeminiFunctionCallingMode.Default;
}
else if (config.Choice == FunctionChoice.Required)
{
toolMode = GeminiFunctionCallingMode.Any;
}
else if (config.Choice == FunctionChoice.None)
{
toolMode = GeminiFunctionCallingMode.None;
}
else
{
throw new NotSupportedException($"Unsupported function choice '{config.Choice}'.");
}

tools = [];

foreach (var function in functions)
{
tools.Add(function.Metadata.ToGeminiFunction().ToFunctionDeclaration());
}
}

return new ToolCallingConfig(
Tools: tools,
Mode: toolMode ?? GeminiFunctionCallingMode.None,
AutoInvoke: autoInvoke,
Options: config?.Options);
}
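
// Sketch of the expected wire format (an assumption about the Gemini API, not part of this diff):
// the mode chosen above would surface in the request's tool configuration, e.g. for FunctionChoice.Required:
//   "toolConfig": { "functionCallingConfig": { "mode": "ANY" } }
// with Default/Any/None corresponding to the API's AUTO/ANY/NONE modes.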

/// <summary>Checks if a tool call is for a function that was defined.</summary>
private static bool IsRequestableTool(IReadOnlyList<GeminiFunctionToolCall> tools, FunctionCallContent functionCallContent)
{
foreach (var tool in tools)
{
if (string.Equals(tool.FunctionName,
FunctionName.ToFullyQualifiedName(
functionCallContent.FunctionName,
functionCallContent.PluginName,
GeminiFunction.NameSeparator),
StringComparison.OrdinalIgnoreCase))
{
return true;
}
}

return false;
}
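
// For example, a call to function "GetWeather" in plugin "Weather" (illustrative names) is matched
// against the fully qualified name built from the plugin name, GeminiFunction.NameSeparator, and the function name.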

private ChatCompletionState ValidateInputAndCreateChatCompletionState(
ChatHistory chatHistory,
Kernel? kernel,
@@ -391,7 +474,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
// Clear the tools. If we end up wanting to use tools, we'll reset it to the desired value.
state.GeminiRequest.Tools = null;

- if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
+ if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
{
// Don't add any tools as we've reached the maximum attempts limit.
if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -408,7 +491,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
}

// Disable auto invocation if we've exceeded the allowed limit.
- if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
+ if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
{
state.AutoInvoke = false;
if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -481,11 +564,6 @@ private async Task<GeminiResponse> SendRequestAndReturnValidGeminiResponseAsync(
return geminiResponse;
}

/// <summary>Checks if a tool call is for a function that was defined.</summary>
private static bool IsRequestableTool(IEnumerable<GeminiTool.FunctionDeclaration> functions, GeminiFunctionToolCall ftc)
=> functions.Any(geminiFunction =>
string.Equals(geminiFunction.Name, ftc.FullyQualifiedName, StringComparison.OrdinalIgnoreCase));

private void AddToolResponseMessage(
ChatHistory chat,
GeminiRequest request,
@@ -591,8 +669,8 @@ private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
}

private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
- => geminiResponse.Candidates == null ?
- [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
+ => geminiResponse.Candidates == null
+ ? [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
: geminiResponse.Candidates.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();

private GeminiChatMessageContent GetChatMessageContentFromCandidate(GeminiResponse geminiResponse, GeminiResponseCandidate candidate)
@@ -663,17 +741,17 @@ private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
private static GeminiMetadata GetResponseMetadata(
GeminiResponse geminiResponse,
GeminiResponseCandidate candidate) => new()
{
FinishReason = candidate.FinishReason,
Index = candidate.Index,
PromptTokenCount = geminiResponse.UsageMetadata?.PromptTokenCount ?? 0,
CurrentCandidateTokenCount = candidate.TokenCount,
CandidatesTokenCount = geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0,
TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
};

private sealed class ChatCompletionState
{
Expand All @@ -682,7 +760,7 @@ private sealed class ChatCompletionState
internal Kernel Kernel { get; set; } = null!;
internal GeminiPromptExecutionSettings ExecutionSettings { get; set; } = null!;
internal GeminiChatMessageContent? LastMessage { get; set; }
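/// <summary>Index of the current request within the automatic function-invocation loop.</summary>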
- internal int Iteration { get; set; }
+ internal int RequestIndex { get; set; }
internal bool AutoInvoke { get; set; }

internal void AddLastMessageToChatHistoryAndRequest()