Skip to content

Commit ae724c1

Browse files
eiriktsarpalisstephentoub
authored andcommitted
Disable STJ default reflection and fix a number of failing tests. (dotnet#6241)
1 parent f09ec8c commit ae724c1

File tree

12 files changed

+144
-72
lines changed

12 files changed

+144
-72
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,18 @@ public static partial class AIJsonUtilities
4242
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
4343
private static JsonSerializerOptions CreateDefaultOptions()
4444
{
45-
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
46-
// and we want to be flexible in terms of what can be put into the various collections in the object model.
47-
// Otherwise, use the source-generated options to enable trimming and Native AOT.
48-
JsonSerializerOptions options;
45+
// Copy configuration from the source generated context.
46+
JsonSerializerOptions options = new(JsonContext.Default.Options)
47+
{
48+
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
49+
};
4950

5051
if (JsonSerializer.IsReflectionEnabledByDefault)
5152
{
52-
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below.
53-
options = new(JsonSerializerDefaults.Web)
54-
{
55-
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
56-
Converters = { new JsonStringEnumConverter() },
57-
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
58-
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
59-
WriteIndented = true,
60-
};
61-
}
62-
else
63-
{
64-
options = new(JsonContext.Default.Options)
65-
{
66-
// Compile-time encoder setting not yet available
67-
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
68-
};
53+
// If reflection-based serialization is enabled by default, use it as a fallback for all other types.
54+
// Also turn on string-based enum serialization for all unknown enums.
55+
options.TypeInfoResolverChain.Add(new DefaultJsonTypeInfoResolver());
56+
options.Converters.Add(new JsonStringEnumConverter());
6957
}
7058

7159
options.MakeReadOnly();
@@ -83,6 +71,8 @@ private static JsonSerializerOptions CreateDefaultOptions()
8371
[JsonSerializable(typeof(SpeechToTextResponseUpdate))]
8472
[JsonSerializable(typeof(IReadOnlyList<SpeechToTextResponseUpdate>))]
8573
[JsonSerializable(typeof(IList<ChatMessage>))]
74+
[JsonSerializable(typeof(IEnumerable<ChatMessage>))]
75+
[JsonSerializable(typeof(ChatMessage[]))]
8676
[JsonSerializable(typeof(ChatOptions))]
8777
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
8878
[JsonSerializable(typeof(ChatClientMetadata))]
@@ -95,14 +85,24 @@ private static JsonSerializerOptions CreateDefaultOptions()
9585
[JsonSerializable(typeof(JsonDocument))]
9686
[JsonSerializable(typeof(JsonElement))]
9787
[JsonSerializable(typeof(JsonNode))]
88+
[JsonSerializable(typeof(JsonObject))]
89+
[JsonSerializable(typeof(JsonValue))]
90+
[JsonSerializable(typeof(JsonArray))]
9891
[JsonSerializable(typeof(IEnumerable<string>))]
92+
[JsonSerializable(typeof(char))]
9993
[JsonSerializable(typeof(string))]
10094
[JsonSerializable(typeof(int))]
95+
[JsonSerializable(typeof(short))]
10196
[JsonSerializable(typeof(long))]
97+
[JsonSerializable(typeof(uint))]
98+
[JsonSerializable(typeof(ushort))]
99+
[JsonSerializable(typeof(ulong))]
102100
[JsonSerializable(typeof(float))]
103101
[JsonSerializable(typeof(double))]
102+
[JsonSerializable(typeof(decimal))]
104103
[JsonSerializable(typeof(bool))]
105104
[JsonSerializable(typeof(TimeSpan))]
105+
[JsonSerializable(typeof(DateTime))]
106106
[JsonSerializable(typeof(DateTimeOffset))]
107107
[JsonSerializable(typeof(Embedding))]
108108
[JsonSerializable(typeof(Embedding<byte>))]

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,7 @@ static bool IsAsyncMethod(MethodInfo method)
487487
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
488488
}
489489

490-
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
491490
Type parameterType = parameter.ParameterType;
492-
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
493491

494492
// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
495493
if (parameterType == typeof(CancellationToken))
@@ -530,6 +528,8 @@ static bool IsAsyncMethod(MethodInfo method)
530528
}
531529

532530
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
531+
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
532+
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
533533
return (arguments, _) =>
534534
{
535535
// If the parameter has an argument specified in the dictionary, return that argument.
@@ -636,14 +636,22 @@ static bool IsAsyncMethod(MethodInfo method)
636636
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
637637
{
638638
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
639+
if (marshalResult is not null)
640+
{
641+
return async (taskObj, cancellationToken) =>
642+
{
643+
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false);
644+
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
645+
return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false);
646+
};
647+
}
648+
639649
returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
640650
return async (taskObj, cancellationToken) =>
641651
{
642652
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false);
643653
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
644-
return marshalResult is not null ?
645-
await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) :
646-
await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
654+
return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
647655
};
648656
}
649657

@@ -652,24 +660,37 @@ await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwa
652660
{
653661
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
654662
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
663+
664+
if (marshalResult is not null)
665+
{
666+
return async (taskObj, cancellationToken) =>
667+
{
668+
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
669+
await task.ConfigureAwait(false);
670+
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
671+
return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false);
672+
};
673+
}
674+
655675
returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
656676
return async (taskObj, cancellationToken) =>
657677
{
658678
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
659679
await task.ConfigureAwait(false);
660680
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
661-
return marshalResult is not null ?
662-
await marshalResult(result, returnTypeInfo.Type, cancellationToken).ConfigureAwait(false) :
663-
await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
681+
return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
664682
};
665683
}
666684
}
667685

668686
// For everything else, just serialize the result as-is.
687+
if (marshalResult is not null)
688+
{
689+
return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken);
690+
}
691+
669692
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
670-
return marshalResult is not null ?
671-
(result, cancellationToken) => marshalResult(result, returnTypeInfo.Type, cancellationToken) :
672-
(result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);
693+
return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);
673694

674695
static async ValueTask<object?> SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken)
675696
{

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ public static void EqualFunctionCallResults(object? expected, object? actual, Js
5555

5656
private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
5757
{
58-
options ??= JsonSerializerOptions.Default;
58+
options ??= AIJsonUtilities.DefaultOptions;
5959
JsonElement expectedElement = NormalizeToElement(expected, options);
6060
JsonElement actualElement = NormalizeToElement(actual, options);
6161
if (!JsonNode.DeepEquals(
62-
JsonSerializer.SerializeToNode(expectedElement),
63-
JsonSerializer.SerializeToNode(actualElement)))
62+
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
63+
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
6464
{
6565
string message = propertyName is null
6666
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public void Serialization_JsonRoundtrips()
7171
public void Serialization_ForJsonSchemaRoundtrips()
7272
{
7373
string json = JsonSerializer.Serialize(
74-
ChatResponseFormat.ForJsonSchema(JsonSerializer.Deserialize<JsonElement>("[1,2,3]"), "name", "description"),
74+
ChatResponseFormat.ForJsonSchema(JsonSerializer.Deserialize<JsonElement>("[1,2,3]", AIJsonUtilities.DefaultOptions), "name", "description"),
7575
TestJsonSerializerContext.Default.ChatResponseFormat);
7676
Assert.Equal("""{"$type":"json","schema":[1,2,3],"schemaName":"name","schemaDescription":"description"}""", json);
7777

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ErrorContentTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void JsonSerialization_ShouldSerializeAndDeserializeCorrectly()
5151
ErrorCode = "ERR001",
5252
Details = "Something went wrong"
5353
};
54-
var options = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
54+
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions) { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
5555

5656
// Act
5757
var json = JsonSerializer.Serialize(errorContent, options);

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ public static void CreateFromParsedArguments_ObjectJsonInput_ReturnsElementArgum
262262
"""{"Key1":{}, "Key2":null, "Key3" : [], "Key4" : 42, "Key5" : true }""",
263263
"callId",
264264
"functionName",
265-
argumentParser: static json => JsonSerializer.Deserialize<Dictionary<string, object?>>(json));
265+
argumentParser: static json => JsonSerializer.Deserialize<Dictionary<string, object?>>(json, AIJsonUtilities.DefaultOptions));
266266

267267
Assert.NotNull(content);
268268
Assert.Null(content.Exception);

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
1111
</PropertyGroup>
1212

13+
<PropertyGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
14+
<JsonSerializerIsReflectionEnabledByDefault>false</JsonSerializerIsReflectionEnabledByDefault>
15+
</PropertyGroup>
16+
1317
<PropertyGroup>
1418
<InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy>
1519
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
namespace Microsoft.Extensions.AI;
1919

20-
public static class AIJsonUtilitiesTests
20+
public static partial class AIJsonUtilitiesTests
2121
{
2222
[Fact]
2323
public static void DefaultOptions_HasExpectedConfiguration()
@@ -53,6 +53,18 @@ public static void DefaultOptions_UsesExpectedEscaping(string input, string expe
5353
Assert.Equal($@"""{expectedJsonString}""", json);
5454
}
5555
56+
[Fact]
57+
public static void DefaultOptions_UsesReflectionWhenDefault()
58+
{
59+
// Reflection is only turned off in .NET Core test environments.
60+
bool isDotnetCore = Type.GetType("System.Half") is not null;
61+
var options = AIJsonUtilities.DefaultOptions;
62+
Type anonType = new { Name = 42 }.GetType();
63+
64+
Assert.Equal(!isDotnetCore, JsonSerializer.IsReflectionEnabledByDefault);
65+
Assert.Equal(JsonSerializer.IsReflectionEnabledByDefault, AIJsonUtilities.DefaultOptions.TryGetTypeInfo(anonType, out _));
66+
}
67+
5668
[Theory]
5769
[InlineData(false)]
5870
[InlineData(true)]
@@ -145,7 +157,7 @@ public static void CreateJsonSchema_DefaultParameters_GeneratesExpectedJsonSchem
145157
}
146158
""").RootElement;
147159

148-
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default);
160+
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonContext.Default.Options);
149161

150162
Assert.True(DeepEquals(expected, actual));
151163
}
@@ -189,7 +201,7 @@ public static void CreateJsonSchema_OverriddenParameters_GeneratesExpectedJsonSc
189201
description: "alternative description",
190202
hasDefaultValue: true,
191203
defaultValue: null,
192-
serializerOptions: JsonSerializerOptions.Default,
204+
serializerOptions: JsonContext.Default.Options,
193205
inferenceOptions: inferenceOptions);
194206

195207
Assert.True(DeepEquals(expected, actual));
@@ -235,7 +247,7 @@ public static void CreateJsonSchema_UserDefinedTransformer()
235247
}
236248
};
237249

238-
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonSerializerOptions.Default, inferenceOptions: inferenceOptions);
250+
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(MyPoco), serializerOptions: JsonContext.Default.Options, inferenceOptions: inferenceOptions);
239251

240252
Assert.True(DeepEquals(expected, actual));
241253
}
@@ -263,7 +275,7 @@ public static void CreateJsonSchema_FiltersDisallowedKeywords()
263275
}
264276
""").RootElement;
265277

266-
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords), serializerOptions: JsonSerializerOptions.Default);
278+
JsonElement actual = AIJsonUtilities.CreateJsonSchema(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords), serializerOptions: JsonContext.Default.Options);
267279

268280
Assert.True(DeepEquals(expected, actual));
269281
}
@@ -283,7 +295,7 @@ public class PocoWithTypesWithOpenAIUnsupportedKeywords
283295
[Fact]
284296
public static void CreateFunctionJsonSchema_ReturnsExpectedValue()
285297
{
286-
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
298+
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions);
287299
AIFunction func = AIFunctionFactory.Create((int x, int y) => x + y, serializerOptions: options);
288300

289301
Assert.NotNull(func.UnderlyingMethod);
@@ -295,7 +307,7 @@ public static void CreateFunctionJsonSchema_ReturnsExpectedValue()
295307
[Fact]
296308
public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
297309
{
298-
JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
310+
JsonSerializerOptions options = new(AIJsonUtilities.DefaultOptions) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
299311
AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d, float e, double f, decimal g) => { }, serializerOptions: options);
300312

301313
JsonElement schemaParameters = func.JsonSchema.GetProperty("properties");
@@ -376,7 +388,11 @@ public static void CreateJsonSchema_ValidateWithTestData(ITestData testData)
376388
[Fact]
377389
public static void AddAIContentType_DerivedAIContent()
378390
{
379-
JsonSerializerOptions options = new();
391+
JsonSerializerOptions options = new()
392+
{
393+
TypeInfoResolver = JsonTypeInfoResolver.Combine(AIJsonUtilities.DefaultOptions.TypeInfoResolver, JsonContext.Default),
394+
};
395+
380396
options.AddAIContentType<DerivedAIContent>("derivativeContent");
381397

382398
AIContent c = new DerivedAIContent { DerivedValue = 42 };
@@ -465,7 +481,7 @@ public static void CreateFunctionJsonSchema_InvokesIncludeParameterCallbackForEv
465481
{
466482
names.Add(p.Name);
467483
return p.Name is "first" or "fifth";
468-
}
484+
},
469485
});
470486

471487
Assert.Equal(["first", "second", "third", "fifth"], names);
@@ -483,14 +499,19 @@ private class DerivedAIContent : AIContent
483499
public int DerivedValue { get; set; }
484500
}
485501

502+
[JsonSerializable(typeof(DerivedAIContent))]
503+
[JsonSerializable(typeof(MyPoco))]
504+
[JsonSerializable(typeof(PocoWithTypesWithOpenAIUnsupportedKeywords))]
505+
private partial class JsonContext : JsonSerializerContext;
506+
486507
private static bool DeepEquals(JsonElement element1, JsonElement element2)
487508
{
488509
#if NET9_0_OR_GREATER
489510
return JsonElement.DeepEquals(element1, element2);
490511
#else
491512
return JsonNode.DeepEquals(
492-
JsonSerializer.SerializeToNode(element1),
493-
JsonSerializer.SerializeToNode(element2));
513+
JsonSerializer.SerializeToNode(element1, AIJsonUtilities.DefaultOptions),
514+
JsonSerializer.SerializeToNode(element2, AIJsonUtilities.DefaultOptions));
494515
#endif
495516
}
496517
}

0 commit comments

Comments
 (0)