Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

M.E.AI - Add support for string JsonElement parsing to primitive #1

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
#if NET
using System.Buffers;
#endif

namespace Microsoft.Extensions.AI;

Expand All @@ -31,7 +34,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
Converters = { new JsonStringEnumConverter(), new JsonStringBooleanConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};
Expand Down Expand Up @@ -83,4 +86,60 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(Embedding<double>))]
[JsonSerializable(typeof(AIContent))]
private sealed partial class JsonContext : JsonSerializerContext;

private sealed class JsonStringBooleanConverter : JsonConverter<bool>
{
public override bool Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
bool? boolResult = null;

if (reader.TokenType == JsonTokenType.String)
{
#if NET
static int GetValueLength(ref Utf8JsonReader reader)
=> reader.HasValueSequence
? checked((int)reader.ValueSequence.Length)
: reader.ValueSpan.Length;

const int StackallocByteThreshold = 256;
const int StackallocCharThreshold = StackallocByteThreshold / 2;

int bufferLength = GetValueLength(ref reader);
char[]? rentedBuffer = null;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid boolean values are at most 5 characters long. You can avoid pooling buffers if you just check the length of the input before stack allocating a buffer of length 5.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if we care, but if there was whitespace padding around the true/false, it could be longer than 5 chars. bool.Parse trims and thus will happily successfully parse e.g. " true ".

if (bufferLength <= StackallocCharThreshold)
{
rentedBuffer = ArrayPool<char>.Shared.Rent(bufferLength);
}

Span<char> charBuffer = rentedBuffer ?? stackalloc char[StackallocCharThreshold];

int actualLength = reader.CopyString(charBuffer);
ReadOnlySpan<char> stringSpan = charBuffer.Slice(0, actualLength);

if (bool.TryParse(stringSpan, out var result))
{
boolResult = result;
}

if (rentedBuffer != null)
{
charBuffer.Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
#else
if (bool.TryParse(reader.GetString(), out var result))
{
boolResult = result;
}
#endif
}

return boolResult ?? reader.GetBoolean();
}

public override void Write(Utf8JsonWriter writer, bool value, JsonSerializerOptions options)
{
writer.WriteBooleanValue(value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,23 @@ public async Task AIFunctionFactory_JsonElementValues_ValuesDeserialized()
Dictionary<string, object?> arguments = JsonSerializer.Deserialize<Dictionary<string, object?>>("""
{
"a": ["Monday", "Tuesday", "Wednesday"],
"b": 123.4,
"b1": 123.4,
"b2": "123.4",
"c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee",
"d": {
"property1": "42",
"property2": "43",
"property3": "44"
}
},
"e1": true,
"e2": "false"
}
""", TestJsonSerializerContext.Default.Options)!;
Assert.All(arguments.Values, v => Assert.IsType<JsonElement>(v));

AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary<string, string> d) => b, serializerOptions: TestJsonSerializerContext.Default.Options);
AIFunction function = AIFunctionFactory.Create(
(DayOfWeek[] a, double b1, float b2, Guid c, Dictionary<string, string> d, bool e1, bool e2) => b1,
serializerOptions: TestJsonSerializerContext.Default.Options);
var result = await function.InvokeAsync(arguments);
AssertExtensions.EqualFunctionCallResults(123.4, result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
#if NET
using System.Buffers;
#endif
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Text.Json;
Expand All @@ -12,7 +15,9 @@ namespace Microsoft.Extensions.AI;

[JsonSourceGenerationOptions(
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
NumberHandling = JsonNumberHandling.AllowReadingFromString,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = new[] { typeof(JsonStringBooleanConverter) },
UseStringEnumConverter = true)]
[JsonSerializable(typeof(ChatCompletion))]
[JsonSerializable(typeof(StreamingChatCompletionUpdate))]
Expand All @@ -28,4 +33,63 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(DayOfWeek[]))] // Used in Content tests
[JsonSerializable(typeof(Guid))] // Used in Content tests
[JsonSerializable(typeof(decimal))] // Used in Content tests
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
[JsonSerializable(typeof(bool))] // Used in Content tests
[JsonSerializable(typeof(float))] // Used in Content tests
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext
{
private sealed class JsonStringBooleanConverter : JsonConverter<bool>
{
public override bool Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
bool? boolResult = null;

if (reader.TokenType == JsonTokenType.String)
{
#if NET
static int GetValueLength(ref Utf8JsonReader reader)
=> reader.HasValueSequence
? checked((int)reader.ValueSequence.Length)
: reader.ValueSpan.Length;

const int StackallocByteThreshold = 256;
const int StackallocCharThreshold = StackallocByteThreshold / 2;

int bufferLength = GetValueLength(ref reader);
char[]? rentedBuffer = null;
if (bufferLength > StackallocCharThreshold)
{
rentedBuffer = ArrayPool<char>.Shared.Rent(bufferLength);
}

Span<char> charBuffer = rentedBuffer ?? stackalloc char[StackallocCharThreshold];

int actualLength = reader.CopyString(charBuffer);
ReadOnlySpan<char> stringSpan = charBuffer.Slice(0, actualLength);

if (bool.TryParse(stringSpan, out var result))
{
boolResult = result;
}

if (rentedBuffer != null)
{
charBuffer.Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
#else
if (bool.TryParse(reader.GetString(), out var result))
{
boolResult = result;
}
#endif
}

return boolResult ?? reader.GetBoolean();
}

public override void Write(Utf8JsonWriter writer, bool value, JsonSerializerOptions options)
{
writer.WriteBooleanValue(value);
}
}
}