diff --git a/src/DataProtection/DataProtection/src/Internal/DefaultTypeNameResolver.cs b/src/DataProtection/DataProtection/src/Internal/DefaultTypeNameResolver.cs new file mode 100644 index 000000000000..40a63826048a --- /dev/null +++ b/src/DataProtection/DataProtection/src/Internal/DefaultTypeNameResolver.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.DataProtection.Internal; + +internal sealed class DefaultTypeNameResolver : ITypeNameResolver +{ + public static readonly DefaultTypeNameResolver Instance = new(); + + private DefaultTypeNameResolver() + { + } + + [UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used to resolve statically known types that are referenced by DataProtection assembly.")] + public bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type) + { + try + { + // Some exceptions are thrown regardless of the value of throwOnError. + // For example, if the type is found but cannot be loaded, + // a System.TypeLoadException is thrown even if throwOnError is false. + type = Type.GetType(typeName, throwOnError: false); + return type != null; + } + catch + { + type = null; + return false; + } + } +} diff --git a/src/DataProtection/DataProtection/src/Internal/ITypeNameResolver.cs b/src/DataProtection/DataProtection/src/Internal/ITypeNameResolver.cs new file mode 100644 index 000000000000..f31a037201b8 --- /dev/null +++ b/src/DataProtection/DataProtection/src/Internal/ITypeNameResolver.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.DataProtection.Internal; + +internal interface ITypeNameResolver +{ + bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type); +} diff --git a/src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs b/src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs index 31b1fc66b4ad..46ef62df0f38 100644 --- a/src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs +++ b/src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs @@ -49,6 +49,7 @@ public sealed class XmlKeyManager : IKeyManager, IInternalXmlKeyManager private const string RevokeAllKeysValue = "*"; private readonly IActivator _activator; + private readonly ITypeNameResolver _typeNameResolver; private readonly AlgorithmConfiguration _authenticatedEncryptorConfiguration; private readonly IKeyEscrowSink? _keyEscrowSink; private readonly IInternalXmlKeyManager _internalKeyManager; @@ -112,6 +113,8 @@ internal XmlKeyManager( var escrowSinks = keyManagementOptions.Value.KeyEscrowSinks; _keyEscrowSink = escrowSinks.Count > 0 ? new AggregateKeyEscrowSink(escrowSinks) : null; _activator = activator; + // Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver. + _typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance; TriggerAndResetCacheExpirationToken(suppressLogging: true); _internalKeyManager = _internalKeyManager ?? this; _encryptorFactories = keyManagementOptions.Value.AuthenticatedEncryptorFactories; @@ -463,27 +466,27 @@ IAuthenticatedEncryptorDescriptor IInternalXmlKeyManager.DeserializeDescriptorFr } } - [UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")] private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string descriptorDeserializerTypeName) { - var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName) + // typeNameToMatch will be used for matching against known types but not passed to the activator. + // The activator will do its own forwarding. + var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName) ? forwardedTypeName : descriptorDeserializerTypeName; - var type = Type.GetType(resolvedTypeName, throwOnError: false); - if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer)) + if (typeof(AuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver)) { return _activator.CreateInstance(descriptorDeserializerTypeName); } - else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver)) { return _activator.CreateInstance(descriptorDeserializerTypeName); } - else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver)) { return _activator.CreateInstance(descriptorDeserializerTypeName); } - else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer)) + else if (typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer).MatchName(typeNameToMatch, _typeNameResolver)) { return _activator.CreateInstance(descriptorDeserializerTypeName); } diff --git a/src/DataProtection/DataProtection/src/TypeExtensions.cs b/src/DataProtection/DataProtection/src/TypeExtensions.cs index 89b69d0b70db..1766ff6a7c2d 100644 --- a/src/DataProtection/DataProtection/src/TypeExtensions.cs +++ b/src/DataProtection/DataProtection/src/TypeExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.DataProtection.Internal; namespace Microsoft.AspNetCore.DataProtection; @@ -39,4 +40,16 @@ public static Type GetTypeWithTrimFriendlyErrorMessage(string typeName) throw new InvalidOperationException($"Unable to load type '{typeName}'. If the app is published with trimming then this type may have been trimmed. Ensure the type's assembly is excluded from trimming.", ex); } } + + public static bool MatchName(this Type matchType, string resolvedTypeName, ITypeNameResolver typeNameResolver) + { + // Before attempting to resolve the name to a type, check if it starts with the full name of the type. + // Use StartsWith to ignore potential assembly version differences. + if (matchType.FullName != null && resolvedTypeName.StartsWith(matchType.FullName, StringComparison.Ordinal)) + { + return typeNameResolver.TryResolveType(resolvedTypeName, out var resolvedType) && resolvedType == matchType; + } + + return false; + } } diff --git a/src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs b/src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs index 1b99664b486c..62d1bdf0b99c 100644 --- a/src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs +++ b/src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs @@ -67,27 +67,30 @@ public static XElement DecryptElement(this XElement element, IActivator activato return doc.Root!; } - [UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")] private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName) { - var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName) + // typeNameToMatch will be used for matching against known types but not passed to the activator. + // The activator will do its own forwarding. + var typeNameToMatch = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName) ? forwardedTypeName : decryptorTypeName; - var type = Type.GetType(resolvedTypeName, throwOnError: false); - if (type == typeof(DpapiNGXmlDecryptor)) + // Note: ITypeNameResolver is only implemented on the activator in tests. In production, it's always DefaultTypeNameResolver. + var typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance; + + if (typeof(DpapiNGXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver)) { return activator.CreateInstance(decryptorTypeName); } - else if (type == typeof(DpapiXmlDecryptor)) + else if (typeof(DpapiXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver)) { return activator.CreateInstance(decryptorTypeName); } - else if (type == typeof(EncryptedXmlDecryptor)) + else if (typeof(EncryptedXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver)) { return activator.CreateInstance(decryptorTypeName); } - else if (type == typeof(NullXmlDecryptor)) + else if (typeof(NullXmlDecryptor).MatchName(typeNameToMatch, typeNameResolver)) { return activator.CreateInstance(decryptorTypeName); } diff --git a/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs b/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs index dffa10477e87..2897f4b46182 100644 --- a/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs +++ b/src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs @@ -49,6 +49,100 @@ public void DecryptElement_RootNodeRequiresDecryption_Success() XmlAssert.Equal("", retVal); } + [Fact] + public void DecryptElement_CustomType_TypeNameResolverNotCalled() + { + // Arrange + var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName; + + var original = XElement.Parse(@$" + + + "); + + var mockActivator = new Mock(); + mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "", ""); + var mockTypeNameResolver = mockActivator.As(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(mockActivator.Object); + var services = serviceCollection.BuildServiceProvider(); + var activator = services.GetActivator(); + + // Act + var retVal = original.DecryptElement(activator); + + // Assert + XmlAssert.Equal("", retVal); + Type resolvedType; + mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny(), out resolvedType), Times.Never()); + } + + [Fact] + public void DecryptElement_KnownType_TypeNameResolverCalled() + { + // Arrange + var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName; + TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName); + + var original = XElement.Parse(@$" + + + + + "); + + var mockActivator = new Mock(); + mockActivator.Setup(o => o.CreateInstance(typeof(NullXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor()); + var mockTypeNameResolver = mockActivator.As(); + var resolvedType = typeof(NullXmlDecryptor); + mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(forwardedTypeName, out resolvedType)).Returns(true); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(mockActivator.Object); + var services = serviceCollection.BuildServiceProvider(); + var activator = services.GetActivator(); + + // Act + var retVal = original.DecryptElement(activator); + + // Assert + XmlAssert.Equal("", retVal); + mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny(), out resolvedType), Times.Once()); + } + + [Fact] + public void DecryptElement_KnownType_UnableToResolveType_Success() + { + // Arrange + var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName; + + var original = XElement.Parse(@$" + + + + + "); + + var mockActivator = new Mock(); + mockActivator.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor()); + var mockTypeNameResolver = mockActivator.As(); + Type resolvedType = null; + mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny(), out resolvedType)).Returns(false); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(mockActivator.Object); + var services = serviceCollection.BuildServiceProvider(); + var activator = services.GetActivator(); + + // Act + var retVal = original.DecryptElement(activator); + + // Assert + XmlAssert.Equal("", retVal); + mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny(), out resolvedType), Times.Once()); + } + [Fact] public void DecryptElement_MultipleNodesRequireDecryption_AvoidsRecursion_Success() {