diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs index de2c2a695b6..3f76fe591bd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs @@ -8,6 +8,9 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; +#if NET +using System.Buffers; +#endif namespace Microsoft.Extensions.AI; @@ -31,7 +34,7 @@ private static JsonSerializerOptions CreateDefaultOptions() JsonSerializerOptions options = new(JsonSerializerDefaults.Web) { TypeInfoResolver = new DefaultJsonTypeInfoResolver(), - Converters = { new JsonStringEnumConverter() }, + Converters = { new JsonStringEnumConverter(), new JsonStringBooleanConverter() }, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, WriteIndented = true, }; @@ -83,4 +86,60 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(Embedding))] [JsonSerializable(typeof(AIContent))] private sealed partial class JsonContext : JsonSerializerContext; + + private sealed class JsonStringBooleanConverter : JsonConverter + { + public override bool Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + bool? boolResult = null; + + if (reader.TokenType == JsonTokenType.String) + { +#if NET + static int GetValueLength(ref Utf8JsonReader reader) + => reader.HasValueSequence + ? checked((int)reader.ValueSequence.Length) + : reader.ValueSpan.Length; + + const int StackallocByteThreshold = 256; + const int StackallocCharThreshold = StackallocByteThreshold / 2; + + int bufferLength = GetValueLength(ref reader); + char[]? rentedBuffer = null; + if (bufferLength <= StackallocCharThreshold) + { + rentedBuffer = ArrayPool.Shared.Rent(bufferLength); + } + + Span charBuffer = rentedBuffer ?? stackalloc char[StackallocCharThreshold]; + + int actualLength = reader.CopyString(charBuffer); + ReadOnlySpan stringSpan = charBuffer.Slice(0, actualLength); + + if (bool.TryParse(stringSpan, out var result)) + { + boolResult = result; + } + + if (rentedBuffer != null) + { + charBuffer.Clear(); + ArrayPool.Shared.Return(rentedBuffer); + } +#else + if (bool.TryParse(reader.GetString(), out var result)) + { + boolResult = result; + } +#endif + } + + return boolResult ?? reader.GetBoolean(); + } + + public override void Write(Utf8JsonWriter writer, bool value, JsonSerializerOptions options) + { + writer.WriteBooleanValue(value); + } + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs index 82b1a518aca..4a98e6b7d52 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs @@ -121,18 +121,23 @@ public async Task AIFunctionFactory_JsonElementValues_ValuesDeserialized() Dictionary arguments = JsonSerializer.Deserialize>(""" { "a": ["Monday", "Tuesday", "Wednesday"], - "b": 123.4, + "b1": 123.4, + "b2": "123.4", "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", "d": { "property1": "42", "property2": "43", "property3": "44" - } + }, + "e1": true, + "e2": "false" } """, TestJsonSerializerContext.Default.Options)!; Assert.All(arguments.Values, v => Assert.IsType(v)); - AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, serializerOptions: TestJsonSerializerContext.Default.Options); + AIFunction function = AIFunctionFactory.Create( + (DayOfWeek[] a, double b1, float b2, Guid c, Dictionary d, bool e1, bool e2) => b1, + serializerOptions: TestJsonSerializerContext.Default.Options); var result = await function.InvokeAsync(arguments); AssertExtensions.EqualFunctionCallResults(123.4, result); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs index 5a3e966c17b..c8eac798968 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +#if NET +using System.Buffers; +#endif using System.Collections.Generic; using System.Collections.ObjectModel; using System.Text.Json; @@ -12,7 +15,9 @@ namespace Microsoft.Extensions.AI; [JsonSourceGenerationOptions( PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + NumberHandling = JsonNumberHandling.AllowReadingFromString, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + Converters = new[] { typeof(JsonStringBooleanConverter) }, UseStringEnumConverter = true)] [JsonSerializable(typeof(ChatCompletion))] [JsonSerializable(typeof(StreamingChatCompletionUpdate))] @@ -28,4 +33,63 @@ namespace Microsoft.Extensions.AI; [JsonSerializable(typeof(DayOfWeek[]))] // Used in Content tests [JsonSerializable(typeof(Guid))] // Used in Content tests [JsonSerializable(typeof(decimal))] // Used in Content tests -internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; +[JsonSerializable(typeof(bool))] // Used in Content tests +[JsonSerializable(typeof(float))] // Used in Content tests +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext +{ + private sealed class JsonStringBooleanConverter : JsonConverter + { + public override bool Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + bool? boolResult = null; + + if (reader.TokenType == JsonTokenType.String) + { +#if NET + static int GetValueLength(ref Utf8JsonReader reader) + => reader.HasValueSequence + ? checked((int)reader.ValueSequence.Length) + : reader.ValueSpan.Length; + + const int StackallocByteThreshold = 256; + const int StackallocCharThreshold = StackallocByteThreshold / 2; + + int bufferLength = GetValueLength(ref reader); + char[]? rentedBuffer = null; + if (bufferLength > StackallocCharThreshold) + { + rentedBuffer = ArrayPool.Shared.Rent(bufferLength); + } + + Span charBuffer = rentedBuffer ?? stackalloc char[StackallocCharThreshold]; + + int actualLength = reader.CopyString(charBuffer); + ReadOnlySpan stringSpan = charBuffer.Slice(0, actualLength); + + if (bool.TryParse(stringSpan, out var result)) + { + boolResult = result; + } + + if (rentedBuffer != null) + { + charBuffer.Clear(); + ArrayPool.Shared.Return(rentedBuffer); + } +#else + if (bool.TryParse(reader.GetString(), out var result)) + { + boolResult = result; + } +#endif + } + + return boolResult ?? reader.GetBoolean(); + } + + public override void Write(Utf8JsonWriter writer, bool value, JsonSerializerOptions options) + { + writer.WriteBooleanValue(value); + } + } +}