Skip to content

Commit

Permalink
Fix GetStructuredResponse GPT-4 responses. Add InvalidJsonException i…
Browse files Browse the repository at this point in the history
…n case of deserialization errors.
  • Loading branch information
rodion-m committed Jul 20, 2023
1 parent 11658b7 commit 11a3368
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
<Version>2.7.1</Version>
<Version>2.8.0</Version>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>

Expand Down
17 changes: 17 additions & 0 deletions OpenAI.ChatGpt.Modules.StructuredResponse/InvalidJsonException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Text.Json;

namespace OpenAI.ChatGpt.Modules.StructuredResponse;

public class InvalidJsonException : Exception
{
public InvalidJsonException(string message, string response, JsonException? jsonException = null)
: base(message, jsonException)
{
Response = response ?? throw new ArgumentNullException(nameof(response));
JsonException = jsonException;
}

public JsonException? JsonException { get; }

public string Response { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public static Task<TObject> GetStructuredResponse<TObject>(
requestModifier: requestModifier,
rawResponseGetter: rawResponseGetter,
jsonDeserializerOptions: jsonDeserializerOptions,
cancellationToken: cancellationToken);
cancellationToken: cancellationToken
);
}

internal static async Task<TObject> GetStructuredResponse<TObject>(
Expand All @@ -80,7 +81,7 @@ internal static async Task<TObject> GetStructuredResponse<TObject>(
ArgumentNullException.ThrowIfNull(dialog);

var editMsg = dialog.GetMessages().FirstOrDefault(it => it is SystemMessage)
?? dialog.GetMessages().First();
?? dialog.GetMessages()[0];
var originalContent = editMsg.Content;
try
{
Expand All @@ -97,25 +98,10 @@ internal static async Task<TObject> GetStructuredResponse<TObject>(
user,
requestModifier,
rawResponseGetter,
cancellationToken);

response = response.Trim();
if(response.StartsWith("```") && response.EndsWith("```"))
{
response = response[3..^3];
}
jsonDeserializerOptions ??= new JsonSerializerOptions
{
PropertyNameCaseInsensitive = true,
Converters = { new JsonStringEnumConverter() }
};
var deserialized = JsonSerializer.Deserialize<TObject>(response, jsonDeserializerOptions);
if (deserialized is null)
{
throw new InvalidOperationException(
$"Failed to deserialize response to {typeof(TObject)}. Response: {response}");
}
cancellationToken
);

var deserialized = DeserializeOrThrow<TObject>(jsonDeserializerOptions, response);
return deserialized;
}
finally
Expand All @@ -124,10 +110,43 @@ internal static async Task<TObject> GetStructuredResponse<TObject>(
}
}

private static TObject DeserializeOrThrow<TObject>(JsonSerializerOptions? jsonDeserializerOptions, string response)
{
ArgumentNullException.ThrowIfNull(response);
response = response.Trim();
if (response.StartsWith("```") && response.EndsWith("```"))
{
response = response[3..^3];
}

jsonDeserializerOptions ??= new JsonSerializerOptions
{
PropertyNameCaseInsensitive = true,
Converters = { new JsonStringEnumConverter() }
};
TObject? deserialized;
try
{
deserialized = JsonSerializer.Deserialize<TObject>(response, jsonDeserializerOptions);
if (deserialized is null)
{
throw new InvalidJsonException(
$"Failed to deserialize response to {typeof(TObject)}. Response: {response}.", response);
}
}
catch (JsonException jsonException)
{
throw new InvalidJsonException(
$"Failed to deserialize response to {typeof(TObject)}. Response: {response}.", response, jsonException);
}

return deserialized;
}

private static string GetAdditionalJsonResponsePrompt(string responseFormat)
{
return$"\n\nWrite your response in compact JSON format with escaped strings. " +
$"Here is the response structure, it is enclosed within double backticks (JSON Schema) ``{responseFormat}``";
$"Here is the response structure (JSON Schema): {responseFormat}";
}

internal static string CreateResponseFormatJson<TObject>()
Expand Down
5 changes: 2 additions & 3 deletions OpenAI.ChatGpt/AsyncEnumerableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ internal static async IAsyncEnumerable<T> ConfigureExceptions<T>(
bool throwOnCancellation) where T: class
{
ArgumentNullException.ThrowIfNull(stream);
var enumerator = stream.GetAsyncEnumerator();
await using var enumerator = stream.GetAsyncEnumerator();
T? result = null;
var hasResult = true;
while (hasResult)
Expand All @@ -31,8 +31,6 @@ internal static async IAsyncEnumerable<T> ConfigureExceptions<T>(
yield return result;
}
}

await enumerator.DisposeAsync();
}

internal static async IAsyncEnumerable<T> ConfigureExceptions<T>(
Expand Down Expand Up @@ -76,6 +74,7 @@ internal static async IAsyncEnumerable<T> ConfigureExceptions<T>(
}

await DisposeAsyncSafe();
yield break;

async Task DisposeAsyncSafe()
{
Expand Down
31 changes: 16 additions & 15 deletions OpenAI.ChatGpt/Models/ChatGPTConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class ChatGPTConfig
/// </example>
/// <remarks>
/// If <see cref="InitialSystemMessage"/> and <see cref="InitialUserMessage"/> are provided,
/// then both messages will be send to OpenAI API.
/// then both messages will be sent to OpenAI API.
/// More info about initial message: https://github.com/openai/openai-python/blob/main/chatml.md
/// </remarks>
public string? InitialSystemMessage { get; set; }
Expand All @@ -50,7 +50,7 @@ public class ChatGPTConfig
/// </summary>
/// <remarks>
/// If <see cref="InitialSystemMessage"/> and <see cref="InitialUserMessage"/> are provided,
/// then both messages will be send to OpenAI API.
/// then both messages will be sent to OpenAI API.
/// More info about initial message: https://github.com/openai/openai-python/blob/main/chatml.md
/// </remarks>
public string? InitialUserMessage { get; set; }
Expand Down Expand Up @@ -154,21 +154,22 @@ internal void ModifyRequest(ChatCompletionRequest request)
ChatGPTConfig? baseConfig,
ChatGPTConfig? config)
{
if (baseConfig is null && config is null) return null;
if (baseConfig is null) return config;
if (config is null) return baseConfig;

var result = new ChatGPTConfig()
return (baseConfig, config) switch
{
_model = config._model ?? baseConfig._model,
_maxTokens = config._maxTokens ?? baseConfig._maxTokens,
_temperature = config._temperature ?? baseConfig._temperature,
PassUserIdToOpenAiRequests = config.PassUserIdToOpenAiRequests ??
baseConfig.PassUserIdToOpenAiRequests,
InitialSystemMessage = config.InitialSystemMessage ?? baseConfig.InitialSystemMessage,
InitialUserMessage = config.InitialUserMessage ?? baseConfig.InitialUserMessage
(null, null) => null,
(null, not null) => config,
(not null, null) => baseConfig,
_ => new ChatGPTConfig()
{
_model = config._model ?? baseConfig._model,
_maxTokens = config._maxTokens ?? baseConfig._maxTokens,
_temperature = config._temperature ?? baseConfig._temperature,
PassUserIdToOpenAiRequests = config.PassUserIdToOpenAiRequests ??
baseConfig.PassUserIdToOpenAiRequests,
InitialSystemMessage = config.InitialSystemMessage ?? baseConfig.InitialSystemMessage,
InitialUserMessage = config.InitialUserMessage ?? baseConfig.InitialUserMessage
}
};
return result;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ public async void Get_structured_response_with_extra_data_from_ChatGPT()
//response.Country.Should().Be("Kazakhstan");
}

[Fact]
public async void Get_structured_response_for_tic_tak_toe_from_ChatGPT_GPT4()
{
var message =
Dialog.StartAsSystem("This is a game of tic tac toe. X goes first. Your turn is O. What is your next move? Board: [{\"Row\":0,\"Column\":0},{\"Row\":0,\"Column\":1},{\"Row\":0,\"Column\":2},{\"Row\":1,\"Column\":0},{\"Row\":1,\"Column\":1},{\"Row\":1,\"Column\":2},{\"Row\":2,\"Column\":0},{\"Row\":2,\"Column\":1},{\"Row\":2,\"Column\":2}]");
var response = await _client.GetStructuredResponse<CellPosition>(message, model: ChatCompletionModels.Gpt4);
response.Should().NotBeNull();
}

private record CellPosition(int Row, int Column);

private class Order
{
public UserInfo? UserInfo { get; set; }
Expand Down

0 comments on commit 11a3368

Please sign in to comment.