Add extension method GetStructuredResponse for IOpenAiClient
rodion-m committed Jul 2, 2023
1 parent 16d08e8 commit c546c07
Showing 16 changed files with 448 additions and 26 deletions.
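Note: only a subset of the 16 changed files is rendered below; the GetStructuredResponse extension method named in the commit title is in one of the files that did not load. As a hedged sketch of the idea only (the signature, type parameter, and JSON handling below are assumptions, not code from this commit), such an extension might look like:

using System.Text.Json;
using OpenAI.ChatGpt;
using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

// Hypothetical sketch; the actual GetStructuredResponse added by this commit is not shown on this page.
public static class OpenAiClientExtensions
{
    // Requests a completion and deserializes the (assumed JSON) answer into TResult.
    public static async Task<TResult> GetStructuredResponse<TResult>(
        this IOpenAiClient client,
        UserOrSystemMessage dialog,
        CancellationToken cancellationToken = default)
    {
        var json = await client.GetChatCompletions(dialog, cancellationToken: cancellationToken);
        return JsonSerializer.Deserialize<TResult>(json)
               ?? throw new InvalidOperationException("Model response was not valid JSON.");
    }
}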
4 changes: 2 additions & 2 deletions OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs
@@ -127,9 +127,9 @@ public async Task<ChatGPT> Create(

     public void Dispose()
     {
-        if (!_isHttpClientInjected)
+        if (!_isHttpClientInjected && _client is IDisposable disposableClient)
         {
-            _client.Dispose();
+            disposableClient.Dispose();
         }
     }
 }
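The ownership-aware dispose pattern above recurs in ChatGPTTranslatorService and ChatGPT below: because IOpenAiClient no longer extends IDisposable (see the IOpenAiClient.cs hunk further down), an owned client is disposed only when the concrete instance happens to be disposable. A minimal self-contained sketch of the pattern (OwnerExample is an illustrative name, not a type from this repo):

using OpenAI.ChatGpt;

// Sketch of the dispose pattern this commit applies in three classes.
public sealed class OwnerExample : IDisposable
{
    private readonly IOpenAiClient _client;
    private readonly bool _isClientInjected; // true => the caller owns the client's lifetime

    public OwnerExample(IOpenAiClient client, bool isClientInjected)
    {
        _client = client;
        _isClientInjected = isClientInjected;
    }

    public void Dispose()
    {
        // Dispose only clients created by this class, and only if the concrete
        // implementation is disposable at all (the interface no longer guarantees it).
        if (!_isClientInjected && _client is IDisposable disposableClient)
        {
            disposableClient.Dispose();
        }
    }
}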
14 changes: 11 additions & 3 deletions OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs
@@ -25,7 +25,7 @@ public ChatGPTTranslatorService(

     public ChatGPTTranslatorService(
         string apiKey,
-        string? host,
+        string? host = null,
         string? defaultSourceLanguage = null,
         string? defaultTargetLanguage = null,
         string? extraPrompt = null)
@@ -39,9 +39,9 @@ public ChatGPTTranslatorService(

     public void Dispose()
     {
-        if (!_isHttpClientInjected)
+        if (!_isHttpClientInjected && _client is IDisposable disposableClient)
         {
-            _client.Dispose();
+            disposableClient.Dispose();
         }
     }

@@ -55,6 +55,14 @@ public async Task<string> Translate(
         if (text == null) throw new ArgumentNullException(nameof(text));
         var sourceLanguageOrDefault = sourceLanguage ?? _defaultSourceLanguage;
         var targetLanguageOrDefault = targetLanguage ?? _defaultTargetLanguage;
+        if (sourceLanguageOrDefault is null)
+        {
+            throw new ArgumentNullException(nameof(sourceLanguage), "Source language is not specified");
+        }
+        if (targetLanguageOrDefault is null)
+        {
+            throw new ArgumentNullException(nameof(targetLanguage), "Target language is not specified");
+        }
         var prompt = GetPrompt(sourceLanguageOrDefault, targetLanguageOrDefault);
         var response = await _client.GetChatCompletions(
             Dialog.StartAsSystem(prompt).ThenUser(text),
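With host now optional and the new fail-fast language checks, the minimal construction needs only an API key. A usage sketch (it assumes Translate's language parameters are optional, as the sourceLanguage ?? _defaultSourceLanguage fallback suggests, and assumes the OpenAI.ChatGpt.Modules.Translator namespace matches the project name):

using OpenAI.ChatGpt.Modules.Translator;

// Sketch: host now defaults to null, so only the API key is required.
var translator = new ChatGPTTranslatorService(
    apiKey: Environment.GetEnvironmentVariable("OPENAI_API_KEY")!,
    defaultSourceLanguage: "English",
    defaultTargetLanguage: "Russian");

// Falls back to the default languages; if neither the arguments nor the defaults
// are set, Translate now throws ArgumentNullException up front instead of building
// a prompt with missing languages.
var translated = await translator.Translate("Hello, world!");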
4 changes: 2 additions & 2 deletions OpenAI.ChatGpt/ChatGPT.cs
@@ -92,9 +92,9 @@ public void Dispose()
     {
         Stop();
         _currentChat?.Dispose();
-        if (!_isClientInjected)
+        if (!_isClientInjected && _client is IDisposable disposableClient)
        {
-            _client.Dispose();
+            disposableClient.Dispose();
         }
     }

18 changes: 15 additions & 3 deletions OpenAI.ChatGpt/ChatService.cs
@@ -93,13 +93,16 @@ private async Task<string> GetNextMessageResponse(
         cancellationToken = _cts.Token;

         var history = await LoadHistory(cancellationToken);
-        var messages = history.Append(message);
+        var messages = history.Append(message).ToArray();

         IsWriting = true;
         try
         {
+            var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
             var response = await _client.GetChatCompletionsRaw(
                 messages,
+                maxTokens: maxTokens,
+                model: model,
                 user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
                 requestModifier: Topic.Config.ModifyRequest,
                 cancellationToken: cancellationToken
@@ -117,7 +120,13 @@ await _chatHistoryStorage.SaveMessages(
             IsWriting = false;
         }
     }

+    private (string model, int maxTokens) FindOptimalModelAndMaxToken(ChatCompletionMessage[] messages)
+    {
+        return ChatCompletionMessage.FindOptimalModelAndMaxToken(
+            messages, Topic.Config.Model, Topic.Config.MaxTokens);
+    }
+
     public IAsyncEnumerable<string> StreamNextMessageResponse(
         string message,
         bool throwOnCancellation = true,
@@ -143,11 +152,14 @@ private async IAsyncEnumerable<string> StreamNextMessageResponse(
         cancellationToken = _cts.Token;

         var history = await LoadHistory(cancellationToken);
-        var messages = history.Append(message);
+        var messages = history.Append(message).ToArray();
         var sb = new StringBuilder();
         IsWriting = true;
+        var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
         var stream = _client.StreamChatCompletions(
             messages,
+            maxTokens: maxTokens,
+            model: model,
             user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
             requestModifier: Topic.Config.ModifyRequest,
             cancellationToken: cancellationToken
10 changes: 5 additions & 5 deletions OpenAI.ChatGpt/IOpenAiClient.cs
@@ -1,10 +1,9 @@
-using System.Runtime.CompilerServices;
-using OpenAI.ChatGpt.Models.ChatCompletion;
+using OpenAI.ChatGpt.Models.ChatCompletion;
 using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

 namespace OpenAI.ChatGpt;

-public interface IOpenAiClient : IDisposable
+public interface IOpenAiClient
 {
     Task<string> GetChatCompletions(
         UserOrSystemMessage dialog,
@@ -13,6 +12,7 @@ Task<string> GetChatCompletions(
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
         Action<ChatCompletionRequest>? requestModifier = null,
+        Action<ChatCompletionResponse>? rawResponseGetter = null,
         CancellationToken cancellationToken = default);

     Task<string> GetChatCompletions(
@@ -22,6 +22,7 @@ Task<string> GetChatCompletions(
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
         Action<ChatCompletionRequest>? requestModifier = null,
+        Action<ChatCompletionResponse>? rawResponseGetter = null,
         CancellationToken cancellationToken = default);

     Task<ChatCompletionResponse> GetChatCompletionsRaw(
@@ -81,8 +82,7 @@ IAsyncEnumerable<string> StreamChatCompletions(
         CancellationToken cancellationToken = default);

     IAsyncEnumerable<string> StreamChatCompletions(
-        ChatCompletionRequest request,
-        [EnumeratorCancellation] CancellationToken cancellationToken = default);
+        ChatCompletionRequest request, CancellationToken cancellationToken = default);

     IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
         ChatCompletionRequest request, CancellationToken cancellationToken = default);
4 changes: 2 additions & 2 deletions OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionModels.cs
@@ -199,7 +199,7 @@ public static void EnsureMaxTokensIsSupported(string model, int maxTokens)
         {
             throw new ArgumentOutOfRangeException(
                 nameof(maxTokens),
-                $"Max tokens must be less than or equal to {limit} for model {model}"
+                $"Max tokens must be less than or equal to {limit} for model {model} but was {maxTokens}"
             );
         }
     }
@@ -210,7 +210,7 @@ public static void EnsureMaxTokensIsSupportedByAnyModel(int maxTokens)
         if (maxTokens > limit)
         {
             throw new ArgumentOutOfRangeException(
-                nameof(maxTokens), $"Max tokens must be less than or equal to {limit}");
+                nameof(maxTokens), $"Max tokens must be less than or equal to {limit} but was {maxTokens}");
         }
     }
 }
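For illustration, the sharpened messages now include the rejected value. A hypothetical call with an oversized budget (the concrete limit comes from ChatCompletionModels.GetMaxTokensLimitForModel):

using OpenAI.ChatGpt.Models.ChatCompletion;

// Illustrative only: exercise the improved validation message.
try
{
    ChatCompletionModels.EnsureMaxTokensIsSupported(
        ChatCompletionModels.Gpt3_5_Turbo_16k, maxTokens: 1_000_000);
}
catch (ArgumentOutOfRangeException e)
{
    // The message now ends with "... but was 1000000", naming the offending value.
    Console.WriteLine(e.Message);
}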
OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs
@@ -95,4 +95,34 @@ public override string ToString()
         ? $"{Role}: {Content}"
         : string.Join(Environment.NewLine, _messages.Select(m => $"{m.Role}: {m.Content}"));
     }
+
+    public static (string model, int maxTokens) FindOptimalModelAndMaxToken(
+        IEnumerable<ChatCompletionMessage> messages,
+        string? model,
+        int? maxTokens,
+        string smallModel = ChatCompletionModels.Default,
+        string bigModel = ChatCompletionModels.Gpt3_5_Turbo_16k,
+        bool useMaxPossibleTokens = true)
+    {
+        var tokenCount = CalculateApproxTotalTokenCount(messages);
+        switch (model, maxTokens)
+        {
+            case (null, null):
+            {
+                model = tokenCount > 6000 ? bigModel : smallModel;
+                maxTokens = GetMaxPossibleTokens(model);
+                break;
+            }
+            case (null, _):
+                model = smallModel;
+                break;
+            case (_, null):
+                maxTokens = useMaxPossibleTokens ? GetMaxPossibleTokens(model) : ChatCompletionRequest.MaxTokensDefault;
+                break;
+        }
+
+        return (model, maxTokens.Value);
+
+        int GetMaxPossibleTokens(string s) => ChatCompletionModels.GetMaxTokensLimitForModel(s) - tokenCount - 500;
+    }
 }
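This helper is what ChatService taps through its private FindOptimalModelAndMaxToken above: when neither a model nor a token budget is configured, dialogs whose approximate token count exceeds 6000 are routed to the 16k model, and the completion budget is whatever the model's limit leaves after the prompt plus a 500-token safety margin. A short usage sketch, given any ChatCompletionMessage[] (the empty array below is a stand-in; ChatService passes history.Append(message).ToArray()):

using OpenAI.ChatGpt.Models.ChatCompletion;
using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

// Stand-in input; use a real dialog's messages in practice.
ChatCompletionMessage[] messages = Array.Empty<ChatCompletionMessage>();

// With both hints null, the helper picks model and budget from the prompt size.
var (model, maxTokens) = ChatCompletionMessage.FindOptimalModelAndMaxToken(
    messages, model: null, maxTokens: null);

// Per the switch above:
//  (null, null) -> small model, or the 16k model once the prompt exceeds ~6000 tokens;
//                  budget = model limit - prompt tokens - 500.
//  (null, some) -> small model, caller's maxTokens kept.
//  (some, null) -> full remaining budget, or ChatCompletionRequest.MaxTokensDefault
//                  when useMaxPossibleTokens is false.
Console.WriteLine($"Using {model} with up to {maxTokens} completion tokens.");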
OpenAI.ChatGpt/OpenAiClient.cs
@@ -12,7 +12,7 @@ namespace OpenAI.ChatGpt;

 /// <summary> Thread-safe OpenAI client. </summary>
 [Fody.ConfigureAwait(false)]
-public class OpenAiClient : IDisposable, IOpenAiClient
+public class OpenAiClient : IOpenAiClient, IDisposable
 {
     private const string DefaultHost = "https://api.openai.com/v1/";
     private const string ImagesEndpoint = "images/generations";
@@ -122,6 +122,7 @@ public async Task<string> GetChatCompletions(
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
         Action<ChatCompletionRequest>? requestModifier = null,
+        Action<ChatCompletionResponse>? rawResponseGetter = null,
         CancellationToken cancellationToken = default)
     {
         if (dialog == null) throw new ArgumentNullException(nameof(dialog));
@@ -135,6 +136,7 @@ public async Task<string> GetChatCompletions(
             requestModifier
         );
         var response = await GetChatCompletionsRaw(request, cancellationToken);
+        rawResponseGetter?.Invoke(response);
         return response.Choices[0].Message!.Content;
     }

@@ -145,6 +147,7 @@ public async Task<string> GetChatCompletions(
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
         Action<ChatCompletionRequest>? requestModifier = null,
+        Action<ChatCompletionResponse>? rawResponseGetter = null,
         CancellationToken cancellationToken = default)
     {
         if (messages == null) throw new ArgumentNullException(nameof(messages));
@@ -158,6 +161,7 @@ public async Task<string> GetChatCompletions(
             requestModifier
         );
         var response = await GetChatCompletionsRaw(request, cancellationToken);
+        rawResponseGetter?.Invoke(response);
         return response.GetMessageContent();
     }
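The new rawResponseGetter callback lets callers keep the convenience of the string-returning overloads while still observing the full ChatCompletionResponse. A usage sketch (the single-argument OpenAiClient constructor and the Id property are assumptions, not verified against this commit):

using OpenAI.ChatGpt;
using OpenAI.ChatGpt.Models.ChatCompletion;

// Sketch: capture the raw response DTO alongside the plain-text answer.
ChatCompletionResponse? raw = null;
IOpenAiClient client = new OpenAiClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")!);

var answer = await client.GetChatCompletions(
    Dialog.StartAsSystem("You are a helpful assistant.").ThenUser("Hello!"),
    rawResponseGetter: response => raw = response);

Console.WriteLine(answer);
Console.WriteLine(raw?.Id); // metadata from the raw response is now accessible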