diff --git a/src/coreclr/System.Private.CoreLib/src/System/Reflection/TypeNameResolver.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Reflection/TypeNameResolver.CoreCLR.cs index 9bf62c527ed789..23546a595564d8 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Reflection/TypeNameResolver.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Reflection/TypeNameResolver.CoreCLR.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Reflection.Metadata; using System.Runtime.CompilerServices; @@ -21,9 +22,12 @@ internal partial struct TypeNameResolver private bool _extensibleParser; private bool _requireAssemblyQualifiedName; private bool _suppressContextualReflectionContext; + private IntPtr _unsafeAccessorMethod; private Assembly? _requestingAssembly; private Assembly? _topLevelAssembly; + private bool SupportsUnboundGenerics { get => _unsafeAccessorMethod != IntPtr.Zero; } + [RequiresUnreferencedCode("The type might be removed")] internal static Type? GetType( string typeName, @@ -128,14 +132,14 @@ internal static RuntimeType GetTypeReferencedByCustomAttribute(string typeName, // Used by VM internal static unsafe RuntimeType? GetTypeHelper(char* pTypeName, RuntimeAssembly? requestingAssembly, - bool throwOnError, bool requireAssemblyQualifiedName) + bool throwOnError, bool requireAssemblyQualifiedName, IntPtr unsafeAccessorMethod) { ReadOnlySpan typeName = MemoryMarshal.CreateReadOnlySpanFromNullTerminated(pTypeName); - return GetTypeHelper(typeName, requestingAssembly, throwOnError, requireAssemblyQualifiedName); + return GetTypeHelper(typeName, requestingAssembly, throwOnError, requireAssemblyQualifiedName, unsafeAccessorMethod); } internal static unsafe RuntimeType? GetTypeHelper(ReadOnlySpan typeName, RuntimeAssembly? requestingAssembly, - bool throwOnError, bool requireAssemblyQualifiedName) + bool throwOnError, bool requireAssemblyQualifiedName, IntPtr unsafeAccessorMethod = 0) { // Compat: Empty name throws TypeLoadException instead of // the natural ArgumentException @@ -158,6 +162,7 @@ internal static RuntimeType GetTypeReferencedByCustomAttribute(string typeName, _throwOnError = throwOnError, _suppressContextualReflectionContext = true, _requireAssemblyQualifiedName = requireAssemblyQualifiedName, + _unsafeAccessorMethod = unsafeAccessorMethod, }.Resolve(parsed); if (type != null) @@ -186,6 +191,9 @@ internal static RuntimeType GetTypeReferencedByCustomAttribute(string typeName, return assembly; } + [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "UnsafeAccessors_ResolveGenericParamToTypeHandle")] + private static partial IntPtr ResolveGenericParamToTypeHandle(IntPtr unsafeAccessorMethod, [MarshalAs(UnmanagedType.Bool)] bool isMethodParam, uint paramIndex); + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:RequiresUnreferencedCode", Justification = "TypeNameResolver.GetType is marked as RequiresUnreferencedCode.")] [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2075:UnrecognizedReflectionPattern", @@ -228,6 +236,42 @@ internal static RuntimeType GetTypeReferencedByCustomAttribute(string typeName, { if (assembly is null) { + if (SupportsUnboundGenerics + && !string.IsNullOrEmpty(escapedTypeName) + && escapedTypeName[0] == '!') + { + Debug.Assert(_throwOnError); // Unbound generic support currently always throws. + + // Parse the type as an unbound generic parameter. Following the common VAR/MVAR IL syntax: + // ! - Represents a zero-based index into the type's generic parameters. + // !! - Represents a zero-based index into the method's generic parameters. + + // Confirm we have at least one more character + if (escapedTypeName.Length == 1) + { + throw new TypeLoadException(SR.Format(SR.TypeLoad_ResolveType, escapedTypeName), typeName: escapedTypeName); + } + + // At this point we expect either another '!' and then a number or a number. + bool isMethodParam = escapedTypeName[1] == '!'; + ReadOnlySpan toParse = isMethodParam + ? escapedTypeName.AsSpan(2) // Skip over "!!" + : escapedTypeName.AsSpan(1); // Skip over "!" + if (!uint.TryParse(toParse, NumberStyles.None, null, out uint paramIndex)) + { + throw new TypeLoadException(SR.Format(SR.TypeLoad_ResolveType, escapedTypeName), typeName: escapedTypeName); + } + + Debug.Assert(_unsafeAccessorMethod != IntPtr.Zero); + IntPtr typeHandle = ResolveGenericParamToTypeHandle(_unsafeAccessorMethod, isMethodParam, paramIndex); + if (typeHandle == IntPtr.Zero) + { + throw new TypeLoadException(SR.Format(SR.TypeLoad_ResolveType, escapedTypeName), typeName: escapedTypeName); + } + + return RuntimeTypeHandle.GetRuntimeTypeFromHandle(typeHandle); + } + if (_requireAssemblyQualifiedName) { if (_throwOnError) diff --git a/src/coreclr/dlls/mscorrc/mscorrc.rc b/src/coreclr/dlls/mscorrc/mscorrc.rc index cb371763691e65..aa9c0464a05507 100644 --- a/src/coreclr/dlls/mscorrc/mscorrc.rc +++ b/src/coreclr/dlls/mscorrc/mscorrc.rc @@ -596,6 +596,8 @@ BEGIN BFA_BAD_FIELD_TOKEN "Field token out of range." BFA_INVALID_FIELD_ACC_FLAGS "Invalid Field Access Flags." BFA_INVALID_UNSAFEACCESSOR "Invalid usage of UnsafeAccessorAttribute." + BFA_INVALID_UNSAFEACCESSORTYPE "Invalid usage of UnsafeAccessorTypeAttribute." + BFA_INVALID_UNSAFEACCESSORTYPE_VALUETYPE "ValueTypes are not supported with UnsafeAccessorTypeAttribute." BFA_FIELD_LITERAL_AND_INIT "Field is Literal and InitOnly." BFA_NONSTATIC_GLOBAL_FIELD "Non-Static Global Field." BFA_INSTANCE_FIELD_IN_INT "Instance Field in an Interface." diff --git a/src/coreclr/dlls/mscorrc/resource.h b/src/coreclr/dlls/mscorrc/resource.h index fc103677bcd4b2..c81cad38f84b2b 100644 --- a/src/coreclr/dlls/mscorrc/resource.h +++ b/src/coreclr/dlls/mscorrc/resource.h @@ -415,10 +415,11 @@ #define BFA_BAD_TYPEREF_TOKEN 0x2046 #define BFA_BAD_CLASS_INT_CA_FORMAT 0x2048 #define BFA_BAD_COMPLUS_SIG 0x2049 -#define BFA_BAD_ELEM_IN_SIZEOF 0x204b -#define BFA_IJW_IN_COLLECTIBLE_ALC 0x204c -#define BFA_INVALID_UNSAFEACCESSOR 0x204d - +#define BFA_BAD_ELEM_IN_SIZEOF 0x204a +#define BFA_IJW_IN_COLLECTIBLE_ALC 0x204b +#define BFA_INVALID_UNSAFEACCESSOR 0x204c +#define BFA_INVALID_UNSAFEACCESSORTYPE 0x204d +#define BFA_INVALID_UNSAFEACCESSORTYPE_VALUETYPE 0x204e #define IDS_CLASSLOAD_INTERFACE_NO_ACCESS 0x204f #define BFA_BAD_CA_HEADER 0x2050 diff --git a/src/coreclr/tools/Common/TypeSystem/Common/Utilities/CustomAttributeTypeNameParser.cs b/src/coreclr/tools/Common/TypeSystem/Common/Utilities/CustomAttributeTypeNameParser.cs index 83e74b02dd9f0c..1c2de5d1743ff5 100644 --- a/src/coreclr/tools/Common/TypeSystem/Common/Utilities/CustomAttributeTypeNameParser.cs +++ b/src/coreclr/tools/Common/TypeSystem/Common/Utilities/CustomAttributeTypeNameParser.cs @@ -21,7 +21,7 @@ public static class CustomAttributeTypeNameParser /// This is the inverse of what does. /// public static TypeDesc GetTypeByCustomAttributeTypeName(this ModuleDesc module, string name, bool throwIfNotFound = true, - Func canonResolver = null) + Func canonGenericResolver = null) { if (!TypeName.TryParse(name.AsSpan(), out TypeName parsed, s_typeNameParseOptions)) ThrowHelper.ThrowTypeLoadException(name, module); @@ -31,7 +31,7 @@ public static TypeDesc GetTypeByCustomAttributeTypeName(this ModuleDesc module, _context = module.Context, _module = module, _throwIfNotFound = throwIfNotFound, - _canonResolver = canonResolver + _canonGenericResolver = canonGenericResolver }.Resolve(parsed); } @@ -91,7 +91,7 @@ private struct TypeNameResolver internal TypeSystemContext _context; internal ModuleDesc _module; internal bool _throwIfNotFound; - internal Func _canonResolver; + internal Func _canonGenericResolver; internal List _referencedModules; @@ -136,30 +136,30 @@ private TypeDesc GetSimpleType(TypeName typeName) } ModuleDesc module = _module; - if (topLevelTypeName.AssemblyName != null) + if (topLevelTypeName.AssemblyName is not null) { module = _context.ResolveAssembly(typeName.AssemblyName, throwIfNotFound: _throwIfNotFound); if (module == null) return null; } - if (module != null) + if (module is not null) { TypeDesc type = GetSimpleTypeFromModule(typeName, module); - if (type != null) + if (type is not null) { _referencedModules?.Add(module); return type; } } - // If it didn't resolve and wasn't assembly-qualified, we also try core library if (topLevelTypeName.AssemblyName == null) { + // If it didn't resolve and wasn't assembly-qualified, we also try core library if (module != _context.SystemModule) { TypeDesc type = GetSimpleTypeFromModule(typeName, _context.SystemModule); - if (type != null) + if (type is not null) { _referencedModules?.Add(_context.SystemModule); return type; @@ -184,9 +184,9 @@ private TypeDesc GetSimpleTypeFromModule(TypeName typeName, ModuleDesc module) string fullName = TypeNameHelpers.Unescape(typeName.FullName); - if (_canonResolver != null) + if (_canonGenericResolver != null) { - MetadataType canonType = _canonResolver(module, fullName); + TypeDesc canonType = _canonGenericResolver(module, fullName); if (canonType != null) return canonType; } diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index e8e97f2eb31973..b6be413b9baa14 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -2,10 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; using System.Diagnostics; using System.Globalization; +using System.Reflection; using System.Reflection.Metadata; using System.Runtime.InteropServices; +using System.Threading; using Internal.IL.Stubs; using Internal.TypeSystem; using Internal.TypeSystem.Ecma; @@ -40,15 +43,21 @@ public static MethodIL TryGetIL(EcmaMethod method) Declaration = method }; - MethodSignature sig = method.Signature; - TypeDesc retType = sig.ReturnType; - TypeDesc firstArgType = null; - if (sig.Length > 0) + SetTargetResult result; + + result = TrySetTargetMethodSignature(ref context); + if (result is not SetTargetResult.Success) { - firstArgType = sig[0]; + return GenerateAccessorSpecificFailure(ref context, name, result); } - SetTargetResult result; + TypeDesc retType = context.DeclarationSignature.ReturnType; + + TypeDesc firstArgType = null; + if (context.DeclarationSignature.Length > 0) + { + firstArgType = context.DeclarationSignature[0]; + } // Using the kind type, perform the following: // 1) Validate the basic type information from the signature. @@ -110,7 +119,7 @@ public static MethodIL TryGetIL(EcmaMethod method) case UnsafeAccessorKind.Field: case UnsafeAccessorKind.StaticField: // Field access requires a single argument for target type and a return type. - if (sig.Length != 1 || retType.IsVoid) + if (context.DeclarationSignature.Length != 1 || retType.IsVoid) { return GenerateAccessorBadImageFailure(method); } @@ -209,6 +218,8 @@ private struct GenerationContext { public UnsafeAccessorKind Kind; public EcmaMethod Declaration; + public MethodSignature DeclarationSignature; + public BitArray ReplacedSignatureElements; public TypeDesc TargetType; public bool IsTargetStatic; public MethodDesc TargetMethod; @@ -241,7 +252,7 @@ private static bool ValidateTargetType(TypeDesc targetTypeMaybe, out TypeDesc va private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationContext context, MethodDesc method, bool ignoreCustomModifiers) { - MethodSignature declSig = context.Declaration.Signature; + MethodSignature declSig = context.DeclarationSignature; MethodSignature maybeSig = method.Signature; // Check if we need to also validate custom modifiers. @@ -249,14 +260,14 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationConte if (!ignoreCustomModifiers) { // Compare any unmanaged callconv and custom modifiers on the signatures. - // We treat unmanaged calling conventions at the same level of precedance + // We treat unmanaged calling conventions at the same level of precedence // as custom modifiers, eventhough they are normally bits in a signature. - ReadOnlySpan kinds = new EmbeddedSignatureDataKind[] - { + ReadOnlySpan kinds = + [ EmbeddedSignatureDataKind.UnmanagedCallConv, EmbeddedSignatureDataKind.RequiredCustomModifier, EmbeddedSignatureDataKind.OptionalCustomModifier - }; + ]; var declData = declSig.GetEmbeddedSignatureData(kinds) ?? Array.Empty(); var maybeData = maybeSig.GetEmbeddedSignatureData(kinds) ?? Array.Empty(); @@ -403,8 +414,10 @@ private enum SetTargetResult { Success, Missing, + MissingType, Ambiguous, Invalid, + NotSupported } private static SetTargetResult TrySetTargetMethod(ref GenerationContext context, string name, bool ignoreCustomModifiers = true) @@ -494,20 +507,250 @@ private static SetTargetResult TrySetTargetField(ref GenerationContext context, return SetTargetResult.Missing; } + private static bool IsValidInitialTypeForReplacementType(TypeDesc initialType, TypeDesc replacementType) + { + if (replacementType.IsByRef) + { + if (!initialType.IsByRef) + { + // We can't replace a non-byref with a byref. + return false; + } + + return IsValidInitialTypeForReplacementType(((ByRefType)initialType).ParameterType, ((ByRefType)replacementType).ParameterType); + } + else if (initialType.IsByRef) + { + // We can't replace a byref with a non-byref. + return false; + } + + if (replacementType.IsPointer) + { + return initialType is PointerType { ParameterType.IsVoid: true }; + } + + Debug.Assert(!replacementType.IsValueType); + + return initialType.IsObject; + } + + private static SetTargetResult TrySetTargetMethodSignature(ref GenerationContext context) + { + EcmaMethod method = context.Declaration; + MetadataReader reader = method.MetadataReader; + MethodDefinition methodDef = reader.GetMethodDefinition(method.Handle); + ParameterHandleCollection parameters = methodDef.GetParameters(); + + MethodSignature originalSignature = method.Signature; + + MethodSignatureBuilder updatedSignature = new MethodSignatureBuilder(originalSignature); + + foreach (ParameterHandle parameterHandle in parameters) + { + Parameter parameter = reader.GetParameter(parameterHandle); + + if (parameter.SequenceNumber > originalSignature.Length) + { + // This is invalid metadata (parameter metadata for a parameter that doesn't exist in the signature). + return SetTargetResult.Invalid; + } + + CustomAttributeHandle unsafeAccessorTypeAttributeHandle = FindUnsafeAccessorTypeAttribute(reader, parameter); + + if (unsafeAccessorTypeAttributeHandle.IsNil) + { + continue; + } + + bool isReturnValue = parameter.SequenceNumber == 0; + + TypeDesc initialType = isReturnValue ? originalSignature.ReturnType : originalSignature[parameter.SequenceNumber - 1]; + + if (isReturnValue && initialType.IsByRef) + { + // We can't support UnsafeAccessorTypeAttribute on by-ref returns + // today as it would create a type-safety hole. + return SetTargetResult.NotSupported; + } + + SetTargetResult decodeResult = DecodeUnsafeAccessorType(method, reader.GetCustomAttribute(unsafeAccessorTypeAttributeHandle), out TypeDesc replacementType); + if (decodeResult != SetTargetResult.Success) + { + return decodeResult; + } + + // Future versions of the runtime may support + // UnsafeAccessorTypeAttribute on value types. + if (replacementType.IsValueType) + { + return SetTargetResult.NotSupported; + } + + if (!IsValidInitialTypeForReplacementType(initialType, replacementType)) + { + return SetTargetResult.Invalid; + } + + context.ReplacedSignatureElements ??= new BitArray(originalSignature.Length + 1, false); + context.ReplacedSignatureElements[parameter.SequenceNumber] = true; + + if (isReturnValue) + { + updatedSignature.ReturnType = replacementType; + } + else + { + updatedSignature[parameter.SequenceNumber - 1] = replacementType; + } + } + + context.DeclarationSignature = updatedSignature.ToSignature(); + return SetTargetResult.Success; + } + + private static SetTargetResult DecodeUnsafeAccessorType(EcmaMethod method, CustomAttribute unsafeAccessorTypeAttribute, out TypeDesc replacementType) + { + replacementType = null; + CustomAttributeValue decoded = unsafeAccessorTypeAttribute.DecodeValue( + new CustomAttributeTypeProvider(method.Module)); + + if (decoded.FixedArguments[0].Value is not string replacementTypeName) + { + return SetTargetResult.Invalid; + } + + replacementType = method.Module.GetTypeByCustomAttributeTypeName( + replacementTypeName, + throwIfNotFound: false, + canonGenericResolver: (module, name) => + { + if (!name.StartsWith('!')) + { + return null; + } + + bool isMethodParameter = name.StartsWith("!!", StringComparison.Ordinal); + + if (!int.TryParse(name.AsSpan(isMethodParameter ? 2 : 1), NumberStyles.None, CultureInfo.InvariantCulture, out int index)) + { + return null; + } + + if (isMethodParameter) + { + if (index >= method.Instantiation.Length) + { + return null; + } + } + else + { + if (index >= method.OwningType.Instantiation.Length) + { + return null; + } + } + + return module.Context.GetSignatureVariable(index, isMethodParameter); + }); + + return replacementType is null + ? SetTargetResult.MissingType + : SetTargetResult.Success; + } + + private static CustomAttributeHandle FindUnsafeAccessorTypeAttribute(MetadataReader reader, Parameter parameter) + { + foreach (CustomAttributeHandle customAttributeHandle in parameter.GetCustomAttributes()) + { + reader.GetAttributeNamespaceAndName(customAttributeHandle, out StringHandle namespaceName, out StringHandle name); + if (reader.StringComparer.Equals(namespaceName, "System.Runtime.CompilerServices") + && reader.StringComparer.Equals(name, "UnsafeAccessorTypeAttribute")) + { + return customAttributeHandle; + } + } + + return default; + } + + private static ParameterHandle FindParameterForSequenceNumber(MetadataReader reader, ref ParameterHandleCollection.Enumerator parameterEnumerator, int sequenceNumber) + { + Parameter currentParameter = reader.GetParameter(parameterEnumerator.Current); + if (currentParameter.SequenceNumber == sequenceNumber) + { + return parameterEnumerator.Current; + } + + // Scan until we are either at this parameter or at the first one after it (if there is no Parameter row in the table) + while (parameterEnumerator.MoveNext()) + { + Parameter thisParameterMaybe = reader.GetParameter(parameterEnumerator.Current); + if (thisParameterMaybe.SequenceNumber > sequenceNumber) + { + // We've passed where it should be. + return default; + } + + if (thisParameterMaybe.SequenceNumber == sequenceNumber) + { + // We found it. + return parameterEnumerator.Current; + } + } + + return default; + } + private static MethodIL GenerateAccessor(ref GenerationContext context) { ILEmitter emit = new ILEmitter(); ILCodeStream codeStream = emit.NewCodeStream(); + MetadataReader reader = context.Declaration.MetadataReader; + ParameterHandleCollection.Enumerator parameterEnumerator = reader.GetMethodDefinition(context.Declaration.Handle).GetParameters().GetEnumerator(); + parameterEnumerator.MoveNext(); + // Load stub arguments. // When the target is static, the first argument is only // used to look up the target member to access and ignored // during dispatch. int beginIndex = context.IsTargetStatic ? 1 : 0; - int stubArgCount = context.Declaration.Signature.Length; + int stubArgCount = context.DeclarationSignature.Length; + Stubs.ILLocalVariable?[] localsToRestore = null; + for (int i = beginIndex; i < stubArgCount; ++i) { codeStream.EmitLdArg(i); + if (context.ReplacedSignatureElements?[i + 1] == true) + { + if (context.DeclarationSignature[i] is { Category: TypeFlags.Class } classType) + { + codeStream.Emit(ILOpcode.unbox_any, emit.NewToken(classType)); + } + else if (context.DeclarationSignature[i] is ByRefType { ParameterType.Category: TypeFlags.Class } byrefType) + { + localsToRestore ??= new Stubs.ILLocalVariable?[stubArgCount]; + + TypeDesc targetType = byrefType.ParameterType; + Stubs.ILLocalVariable local = emit.NewLocal(targetType); + codeStream.EmitLdInd(targetType); + codeStream.Emit(ILOpcode.unbox_any, emit.NewToken(targetType)); + codeStream.EmitStLoc(local); + codeStream.EmitLdLoca(local); + + // Only mark the local to be restored after the call + // if the parameter is not marked as "in". + // The "sequence number" for parameters is 1-based, whereas the parameter index is 0-based. + ParameterHandle paramHandle = FindParameterForSequenceNumber(reader, ref parameterEnumerator, i + 1); + if (paramHandle.IsNil + || !reader.GetParameter(paramHandle).Attributes.HasFlag(ParameterAttributes.In)) + { + localsToRestore[i] = local; + } + } + } } // Provide access to the target member @@ -538,6 +781,19 @@ private static MethodIL GenerateAccessor(ref GenerationContext context) break; } + if (localsToRestore is not null) + { + for (int i = beginIndex; i < stubArgCount; ++i) + { + if (localsToRestore[i] != null) + { + codeStream.EmitLdArg(i); + codeStream.EmitLdLoc(localsToRestore[i].Value); + codeStream.EmitStInd(((ParameterizedType)context.Declaration.Signature[i]).ParameterType); + } + } + } + // Return from the generated stub codeStream.Emit(ILOpcode.ret); return emit.Link(context.Declaration); @@ -563,6 +819,14 @@ private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext co codeStream.EmitLdc((int)ExceptionStringID.InvalidProgramDefault); thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowInvalidProgramException"); } + else if (result is SetTargetResult.NotSupported) + { + thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowNotSupportedException"); + } + else if (result is SetTargetResult.MissingType) + { + thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowUnavailableType"); + } else { Debug.Assert(result is SetTargetResult.Missing); diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index ba3c64b7e08375..be54e7e196e0e3 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -384,7 +384,7 @@ DEFINE_FIELD(RT_TYPE_HANDLE, M_TYPE, m_type) DEFINE_METHOD(TYPED_REFERENCE, GETREFANY, GetRefAny, NoSig) DEFINE_CLASS(TYPE_NAME_RESOLVER, Reflection, TypeNameResolver) -DEFINE_METHOD(TYPE_NAME_RESOLVER, GET_TYPE_HELPER, GetTypeHelper, SM_Type_CharPtr_RuntimeAssembly_Bool_Bool_RetRuntimeType) +DEFINE_METHOD(TYPE_NAME_RESOLVER, GET_TYPE_HELPER, GetTypeHelper, SM_Type_CharPtr_RuntimeAssembly_Bool_Bool_IntPtr_RetRuntimeType) DEFINE_CLASS_U(Reflection, RtFieldInfo, NoClass) DEFINE_FIELD_U(m_fieldHandle, ReflectFieldObject, m_pFD) diff --git a/src/coreclr/vm/ilstubresolver.cpp b/src/coreclr/vm/ilstubresolver.cpp index 7e2786baa9d82c..0d877e94e454c5 100644 --- a/src/coreclr/vm/ilstubresolver.cpp +++ b/src/coreclr/vm/ilstubresolver.cpp @@ -213,6 +213,18 @@ void ILStubResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) resolvedToken->TypeHandle = TypeHandle(pMT); } break; + + case mdtTypeSpec: + { + TokenLookupMap::TypeSpecEntry entry = m_pCompileTimeState->m_tokenLookupMap.LookupTypeSpec(token); + _ASSERTE(entry.ClassSignatureToken != mdTokenNil); + _ASSERTE(!entry.Type.IsNull()); + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + SigTypeContext typeContext{ m_pStubMD->GetClassInstantiation(), m_pStubMD->GetMethodInstantiation() }; + resolvedToken->TypeHandle = resolvedToken->TypeSignature.GetTypeHandleThrowing(m_pStubMD->GetModule(), &typeContext); + } + break; #endif // !defined(DACCESS_COMPILE) default: diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 737f431d9102dc..eb3f3ef4f99b83 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -917,8 +917,18 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken } else { - if ((tkType != mdtTypeDef) && (tkType != mdtTypeRef)) + if (tkType == mdtTypeSpec) + { + // We have a TypeSpec, so we need to verify the signature is non-NULL + // and the typehandle has been fully instantiated. + if (pResolvedToken->pTypeSpec == NULL || th.ContainsGenericVariables()) + ThrowBadTokenException(pResolvedToken); + } + else if ((tkType != mdtTypeDef) && (tkType != mdtTypeRef)) + { ThrowBadTokenException(pResolvedToken); + } + if ((tokenType & CORINFO_TOKENKIND_Class) == 0) ThrowBadTokenException(pResolvedToken); if (th.IsNull()) @@ -14552,7 +14562,8 @@ static Signature BuildResumptionStubCalliSignature(MetaSig& msig, MethodTable* m sigBuilder.AppendByte(callConv); sigBuilder.AppendData(numArgs); - auto appendTypeHandle = [&](TypeHandle th) { + auto appendTypeHandle = [](SigBuilder& sigBuilder, TypeHandle th) + { _ASSERTE(!th.IsByRef()); CorElementType ty = th.GetSignatureCorElementType(); if (CorTypeInfo::IsObjRef(ty)) @@ -14569,9 +14580,9 @@ static Signature BuildResumptionStubCalliSignature(MetaSig& msig, MethodTable* m sigBuilder.AppendElementType(ELEMENT_TYPE_INTERNAL); sigBuilder.AppendPointer(th.AsPtr()); } - }; + }; - appendTypeHandle(msig.GetRetTypeHandleThrowing()); // return type + appendTypeHandle(sigBuilder, msig.GetRetTypeHandleThrowing()); // return type #ifndef TARGET_X86 if (msig.HasGenericContextArg()) { @@ -14586,7 +14597,7 @@ static Signature BuildResumptionStubCalliSignature(MetaSig& msig, MethodTable* m while ((ty = msig.NextArg()) != ELEMENT_TYPE_END) { TypeHandle tyHnd = msig.GetLastTypeHandleThrowing(); - appendTypeHandle(tyHnd); + appendTypeHandle(sigBuilder, tyHnd); } #ifdef TARGET_X86 diff --git a/src/coreclr/vm/metasig.h b/src/coreclr/vm/metasig.h index 840cd137c94e96..7ad14c4d1f5c7a 100644 --- a/src/coreclr/vm/metasig.h +++ b/src/coreclr/vm/metasig.h @@ -170,7 +170,7 @@ // static methods: DEFINE_METASIG_T(SM(Int_IntPtr_Obj_RetException, i I j, C(EXCEPTION))) -DEFINE_METASIG_T(SM(Type_CharPtr_RuntimeAssembly_Bool_Bool_RetRuntimeType, P(u) C(ASSEMBLY) F F, C(CLASS))) +DEFINE_METASIG_T(SM(Type_CharPtr_RuntimeAssembly_Bool_Bool_IntPtr_RetRuntimeType, P(u) C(ASSEMBLY) F F I, C(CLASS))) DEFINE_METASIG_T(SM(Type_RetIntPtr, C(TYPE), I)) DEFINE_METASIG(SM(RefIntPtr_IntPtr_IntPtr_Int_RetObj, r(I) I I i, j)) DEFINE_METASIG(SM(IntPtr_UInt_VoidPtr_RetObj, I K P(v), j)) diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index e066f857450bf7..e5430ffd9992d5 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -2085,6 +2085,10 @@ class MethodDesc friend struct ::cdac_data; }; +#ifndef DACCESS_COMPILE +extern "C" void* QCALLTYPE UnsafeAccessors_ResolveGenericParamToTypeHandle(MethodDesc* unsafeAccessorMethod, BOOL isMethodParam, DWORD paramIndex); +#endif // DACCESS_COMPILE + template<> struct cdac_data { static constexpr size_t ChunkIndex = offsetof(MethodDesc, m_chunkIndex); diff --git a/src/coreclr/vm/qcallentrypoints.cpp b/src/coreclr/vm/qcallentrypoints.cpp index 19ce7a5ffa05ab..d0f625b895d504 100644 --- a/src/coreclr/vm/qcallentrypoints.cpp +++ b/src/coreclr/vm/qcallentrypoints.cpp @@ -177,6 +177,7 @@ static const Entry s_QCall[] = DllImportEntry(RuntimeFieldHandle_GetEnCFieldAddr) DllImportEntry(RuntimeFieldHandle_GetRVAFieldInfo) DllImportEntry(RuntimeFieldHandle_GetFieldDataReference) + DllImportEntry(UnsafeAccessors_ResolveGenericParamToTypeHandle) DllImportEntry(StackTrace_GetStackFramesInternal) DllImportEntry(StackFrame_GetMethodDescFromNativeIP) DllImportEntry(ModuleBuilder_GetStringConstant) diff --git a/src/coreclr/vm/siginfo.cpp b/src/coreclr/vm/siginfo.cpp index 12c4e9eff79d76..ee39fbf5cbe676 100644 --- a/src/coreclr/vm/siginfo.cpp +++ b/src/coreclr/vm/siginfo.cpp @@ -387,6 +387,50 @@ void SigPointer::ConvertToInternalSignature(Module* pSigModule, const SigTypeCon } } +void SigPointer::CopyModOptsReqs(Module* pSigModule, SigBuilder* pSigBuilder) +{ + CONTRACTL + { + INSTANCE_CHECK; + STANDARD_VM_CHECK; + } + CONTRACTL_END + + CorElementType typ; + IfFailThrowBF(PeekElemType(&typ), BFA_BAD_COMPLUS_SIG, pSigModule); + while (typ == ELEMENT_TYPE_CMOD_REQD || typ == ELEMENT_TYPE_CMOD_OPT) + { + // Skip the custom modifier + IfFailThrowBF(GetByte(NULL), BFA_BAD_COMPLUS_SIG, pSigModule); + + // Get the encoded token. + uint32_t token; + IfFailThrowBF(GetToken(&token), BFA_BAD_COMPLUS_SIG, pSigModule); + + // Append the custom modifier and encoded token to the signature. + pSigBuilder->AppendElementType(typ); + pSigBuilder->AppendToken(token); + + typ = ELEMENT_TYPE_END; + IfFailThrowBF(PeekElemType(&typ), BFA_BAD_COMPLUS_SIG, pSigModule); + } +} + +void SigPointer::CopyExactlyOne(Module* pSigModule, SigBuilder* pSigBuilder) +{ + CONTRACTL + { + INSTANCE_CHECK; + STANDARD_VM_CHECK; + } + CONTRACTL_END + + intptr_t beginExactlyOne = (intptr_t)m_ptr; + IfFailThrowBF(SkipExactlyOne(), BFA_BAD_COMPLUS_SIG, pSigModule); + intptr_t endExactlyOne = (intptr_t)m_ptr; + pSigBuilder->AppendBlob((const PVOID)beginExactlyOne, endExactlyOne - beginExactlyOne); +} + void SigPointer::CopySignature(Module* pSigModule, SigBuilder* pSigBuilder, BYTE additionalCallConv) { CONTRACTL @@ -396,10 +440,10 @@ void SigPointer::CopySignature(Module* pSigModule, SigBuilder* pSigBuilder, BYTE } CONTRACTL_END - SigPointer spEnd(*this); - IfFailThrowBF(spEnd.SkipSignature(), BFA_BAD_COMPLUS_SIG, pSigModule); - pSigBuilder->AppendByte(*m_ptr | additionalCallConv); - pSigBuilder->AppendBlob((const PVOID)(m_ptr + 1), spEnd.m_ptr - (m_ptr + 1)); + PCCOR_SIGNATURE beginSignature = m_ptr; + IfFailThrowBF(SkipSignature(), BFA_BAD_COMPLUS_SIG, pSigModule); + pSigBuilder->AppendByte(*beginSignature | additionalCallConv); + pSigBuilder->AppendBlob((const PVOID)(beginSignature + 1), m_ptr - (beginSignature + 1)); } #endif // DACCESS_COMPILE @@ -3880,7 +3924,6 @@ MetaSig::CompareElementType( pOtherModule = pModule1; } - // Internal types can only correspond to types or value types. switch (eOtherType) { case ELEMENT_TYPE_OBJECT: @@ -3908,7 +3951,7 @@ MetaSig::CompareElementType( pOtherModule, tkOther, ClassLoader::ReturnNullIfNotFound, - ClassLoader::FailIfUninstDefOrRef); + ClassLoader::PermitUninstDefOrRef); return (hInternal == hOtherType); } @@ -5708,6 +5751,12 @@ TokenPairList TokenPairList::AdjustForTypeSpec(TokenPairList *pTemplate, ModuleB result.m_bInTypeEquivalenceForbiddenScope = !IsTdInterface(dwAttrType); } } + else if (elemType == ELEMENT_TYPE_INTERNAL) + { + TypeHandle typeHandle; + IfFailThrow(sig.GetPointer((void**)&typeHandle)); + result.m_bInTypeEquivalenceForbiddenScope = !typeHandle.IsInterface(); + } else { _ASSERTE(elemType == ELEMENT_TYPE_VALUETYPE); diff --git a/src/coreclr/vm/siginfo.hpp b/src/coreclr/vm/siginfo.hpp index 4610b66b6a9b4e..b0c7b3820230a2 100644 --- a/src/coreclr/vm/siginfo.hpp +++ b/src/coreclr/vm/siginfo.hpp @@ -129,7 +129,12 @@ class SigPointer : public SigParser void ConvertToInternalExactlyOne(Module* pSigModule, const SigTypeContext *pTypeContext, SigBuilder * pSigBuilder, BOOL bSkipCustomModifier = TRUE); void ConvertToInternalSignature(Module* pSigModule, const SigTypeContext *pTypeContext, SigBuilder * pSigBuilder, BOOL bSkipCustomModifier = TRUE); - void CopySignature(Module* pSigModule, SigBuilder * pSigBuilder, BYTE additionalCallConv); + + // Copy the current part of the signature to the SigBuilder. + // All copy methods advance internal state as if a Get was called. + void CopyModOptsReqs(Module* pSigModule, SigBuilder* pSigBuilder); + void CopyExactlyOne(Module* pSigModule, SigBuilder* pSigBuilder); + void CopySignature(Module* pSigModule, SigBuilder* pSigBuilder, BYTE additionalCallConv); //========================================================================= // The CLOSED interface for reading signatures. With the following diff --git a/src/coreclr/vm/stubgen.cpp b/src/coreclr/vm/stubgen.cpp index 6f44bae72be4ef..7b0a97a5c92ab7 100644 --- a/src/coreclr/vm/stubgen.cpp +++ b/src/coreclr/vm/stubgen.cpp @@ -1049,6 +1049,8 @@ LPCSTR ILCodeStream::GetStreamDescription(ILStubLinker::CodeStreamType streamTyp "ExceptionCleanup", "Cleanup", "ExceptionHandler", + "TypeCheckAndCallMethod", + "UpdateByRefsAndReturn" }; #ifdef _DEBUG @@ -1995,7 +1997,7 @@ DWORD StubSigBuilder::Append(LocalDesc* pLoc) m_pbSigCursor += sizeof(TypeHandle); m_cbSig += sizeof(TypeHandle); break; - + case ELEMENT_TYPE_CMOD_INTERNAL: { // Nove later elements in the signature to make room for the CMOD_INTERNAL payload @@ -3232,16 +3234,16 @@ int ILStubLinker::GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken metho return m_tokenMap.GetToken(pMD, typeSignature, methodSignature); } -int ILStubLinker::GetToken(MethodTable* pMT) +int ILStubLinker::GetToken(TypeHandle th) { STANDARD_VM_CONTRACT; - return m_tokenMap.GetToken(TypeHandle(pMT)); + return m_tokenMap.GetToken(th); } -int ILStubLinker::GetToken(TypeHandle th) +int ILStubLinker::GetToken(TypeHandle th, mdToken typeSignature) { STANDARD_VM_CONTRACT; - return m_tokenMap.GetToken(th); + return m_tokenMap.GetToken(th, typeSignature); } int ILStubLinker::GetToken(FieldDesc* pFD) @@ -3352,6 +3354,11 @@ int ILCodeStream::GetToken(TypeHandle th) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(th); } +int ILCodeStream::GetToken(TypeHandle th, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(th, typeSignature); +} int ILCodeStream::GetToken(FieldDesc* pFD) { STANDARD_VM_CONTRACT; diff --git a/src/coreclr/vm/stubgen.h b/src/coreclr/vm/stubgen.h index c848b0665008fa..2c4e843f6b9787 100644 --- a/src/coreclr/vm/stubgen.h +++ b/src/coreclr/vm/stubgen.h @@ -40,7 +40,7 @@ struct LocalDesc TypeHandle InternalToken; // only valid with ELEMENT_TYPE_INTERNAL // only valid with ELEMENT_TYPE_CMOD_INTERNAL - bool InternalModifierRequired; + bool InternalModifierRequired; TypeHandle InternalModifierToken; // used only for E_T_FNPTR and E_T_ARRAY @@ -313,12 +313,33 @@ class TokenLookupMap m_memberRefs.Set(pSrc->m_memberRefs); m_methodSpecs.Set(pSrc->m_methodSpecs); + m_typeSpecs.Set(pSrc->m_typeSpecs); } TypeHandle LookupTypeDef(mdToken token) { WRAPPER_NO_CONTRACT; - return LookupTokenWorker(token); + return LookupTokenWorker(token); + } + struct TypeSpecEntry final + { + mdToken ClassSignatureToken; + TypeHandle Type; + }; + TypeSpecEntry LookupTypeSpec(mdToken token) + { + CONTRACTL + { + NOTHROW; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(RidFromToken(token) - 1 < m_typeSpecs.GetCount()); + PRECONDITION(RidFromToken(token) != 0); + PRECONDITION(TypeFromToken(token) == mdtTypeSpec); + } + CONTRACTL_END; + + return m_typeSpecs[static_cast(RidFromToken(token) - 1)]; } MethodDesc* LookupMethodDef(mdToken token) { @@ -398,10 +419,30 @@ class TokenLookupMap return SigPointer(pSig, cbSig); } - mdToken GetToken(TypeHandle pMT) + mdToken GetToken(TypeHandle th) { WRAPPER_NO_CONTRACT; - return GetTokenWorker(pMT); + return GetTokenWorker(th); + } + mdToken GetToken(TypeHandle th, mdToken typeSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(!th.IsNull()); + PRECONDITION(!th.IsTypeDesc() && th.GetMethodTable()->ContainsGenericVariables()); + PRECONDITION(typeSignature != mdTokenNil); + } + CONTRACTL_END; + + TypeSpecEntry* entry; + mdToken token = GetTypeSpecWorker(&entry); + entry->ClassSignatureToken = typeSignature; + entry->Type = th; + return token; + } mdToken GetToken(MethodDesc* pMD) { @@ -488,6 +529,22 @@ class TokenLookupMap } protected: + mdToken GetTypeSpecWorker(TypeSpecEntry** entry) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(entry != NULL); + } + CONTRACTL_END; + + mdToken token = TokenFromRid(m_typeSpecs.GetCount(), mdtTypeSpec) + 1; + *entry = &*m_typeSpecs.Append(); // Dereference the iterator and then take the address + return token; + } + mdToken GetMemberRefWorker(MemberRefEntry** entry) { CONTRACTL @@ -566,6 +623,7 @@ class TokenLookupMap SArray, FALSE> m_signatures; SArray m_memberRefs; SArray m_methodSpecs; + SArray m_typeSpecs; }; class ILCodeLabel; @@ -688,6 +746,8 @@ class ILStubLinker kExceptionCleanup, kCleanup, kExceptionHandler, + kTypeCheckDispatch, + kUpdateByRefsReturn }; ILCodeStream* NewCodeStream(CodeStreamType codeStreamType); @@ -734,8 +794,8 @@ class ILStubLinker int GetToken(MethodDesc* pMD); int GetToken(MethodDesc* pMD, mdToken typeSignature); int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); - int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); + int GetToken(TypeHandle th, mdToken typeSignature); int GetToken(FieldDesc* pFD); int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); @@ -971,6 +1031,7 @@ class ILCodeStream int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); + int GetToken(TypeHandle th, mdToken typeSignature); int GetToken(FieldDesc* pFD); int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); diff --git a/src/coreclr/vm/typehandle.h b/src/coreclr/vm/typehandle.h index 862e2fa137aeec..dcf1852bacab3b 100644 --- a/src/coreclr/vm/typehandle.h +++ b/src/coreclr/vm/typehandle.h @@ -84,12 +84,6 @@ class ComCallWrapperTemplate; class TypeHandle { public: - TypeHandle() { - LIMITED_METHOD_DAC_CONTRACT; - - m_asTAddr = 0; - } - static TypeHandle FromPtr(PTR_VOID aPtr) { LIMITED_METHOD_DAC_CONTRACT; @@ -104,29 +98,34 @@ class TypeHandle return TypeHandle(data); } + TypeHandle() + : m_asTAddr{ 0 } + { + LIMITED_METHOD_DAC_CONTRACT; + } + // When you ask for a class in JitInterface when all you have // is a methodDesc of an array method... // Convert from a JitInterface handle to an internal EE TypeHandle explicit TypeHandle(struct CORINFO_CLASS_STRUCT_*aPtr) + : m_asTAddr{ dac_cast(aPtr) } { LIMITED_METHOD_DAC_CONTRACT; - - m_asTAddr = dac_cast(aPtr); INDEBUGIMPL(Verify()); } - TypeHandle(MethodTable const * aMT) { + TypeHandle(MethodTable const * aMT) + : m_asTAddr{ dac_cast(aMT) } + { LIMITED_METHOD_DAC_CONTRACT; - - m_asTAddr = dac_cast(aMT); INDEBUGIMPL(Verify()); } - explicit TypeHandle(TypeDesc *aType) { + explicit TypeHandle(TypeDesc *aType) + : m_asTAddr{ dac_cast(aType) | 2 } + { LIMITED_METHOD_DAC_CONTRACT; _ASSERTE(aType); - - m_asTAddr = (dac_cast(aType) | 2); INDEBUGIMPL(Verify()); } @@ -138,9 +137,9 @@ class TypeHandle // TypeHandle::FromPtr and TypeHandle::TAddr instead of these constructors. // Allowing a public constructor that takes a "void *" or a "TADDR" is error-prone. explicit TypeHandle(TADDR aTAddr) + : m_asTAddr{ aTAddr } { LIMITED_METHOD_DAC_CONTRACT; - m_asTAddr = aTAddr; INDEBUGIMPL(Verify()); } @@ -719,7 +718,7 @@ class Instantiation bool ContainsAllOneType(TypeHandle th) { - for (auto i = GetNumArgs(); i > 0;) + for (DWORD i = GetNumArgs(); i > 0;) { if ((*this)[--i] != th) return false; diff --git a/src/coreclr/vm/typeparse.cpp b/src/coreclr/vm/typeparse.cpp index e66d3baef9d240..ae3fdf778f76a2 100644 --- a/src/coreclr/vm/typeparse.cpp +++ b/src/coreclr/vm/typeparse.cpp @@ -7,7 +7,7 @@ #include "common.h" #include "typeparse.h" -static TypeHandle GetTypeHelper(LPCWSTR szTypeName, Assembly* pRequestingAssembly, BOOL bThrowIfNotFound, BOOL bRequireAssemblyQualifiedName) +static TypeHandle GetTypeHelper(LPCWSTR szTypeName, Assembly* pRequestingAssembly, BOOL bThrowIfNotFound, BOOL bRequireAssemblyQualifiedName, MethodDesc* unsafeAccessorMethod) { CONTRACTL { @@ -32,11 +32,12 @@ static TypeHandle GetTypeHelper(LPCWSTR szTypeName, Assembly* pRequestingAssembl OVERRIDE_TYPE_LOAD_LEVEL_LIMIT(CLASS_LOADED); PREPARE_NONVIRTUAL_CALLSITE(METHOD__TYPE_NAME_RESOLVER__GET_TYPE_HELPER); - DECLARE_ARGHOLDER_ARRAY(args, 4); + DECLARE_ARGHOLDER_ARRAY(args, 5); args[ARGNUM_0] = PTR_TO_ARGHOLDER(szTypeName); args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(objRequestingAssembly); args[ARGNUM_2] = BOOL_TO_ARGHOLDER(bThrowIfNotFound); args[ARGNUM_3] = BOOL_TO_ARGHOLDER(bRequireAssemblyQualifiedName); + args[ARGNUM_4] = PTR_TO_ARGHOLDER(unsafeAccessorMethod); REFLECTCLASSBASEREF objType = NULL; CALL_MANAGED_METHOD_RETREF(objType, REFLECTCLASSBASEREF, args); @@ -55,17 +56,17 @@ TypeHandle TypeName::GetTypeReferencedByCustomAttribute(LPCUTF8 szTypeName, Asse { WRAPPER_NO_CONTRACT; StackSString sszAssemblyQualifiedName(SString::Utf8, szTypeName); - return GetTypeHelper(sszAssemblyQualifiedName.GetUnicode(), pRequestingAssembly, TRUE /* bThrowIfNotFound */, FALSE /* bRequireAssemblyQualifiedName */); + return GetTypeHelper(sszAssemblyQualifiedName.GetUnicode(), pRequestingAssembly, TRUE /* bThrowIfNotFound */, FALSE /* bRequireAssemblyQualifiedName */, NULL /* unsafeAccessorMethod */); } -TypeHandle TypeName::GetTypeReferencedByCustomAttribute(LPCWSTR szTypeName, Assembly* pRequestingAssembly) +TypeHandle TypeName::GetTypeReferencedByCustomAttribute(LPCWSTR szTypeName, Assembly* pRequestingAssembly, MethodDesc* unsafeAccessorMethod) { WRAPPER_NO_CONTRACT; - return GetTypeHelper(szTypeName, pRequestingAssembly, TRUE /* bThrowIfNotFound */, FALSE /* bRequireAssemblyQualifiedName */); + return GetTypeHelper(szTypeName, pRequestingAssembly, TRUE /* bThrowIfNotFound */, FALSE /* bRequireAssemblyQualifiedName */, unsafeAccessorMethod); } TypeHandle TypeName::GetTypeFromAsmQualifiedName(LPCWSTR szFullyQualifiedName, BOOL bThrowIfNotFound) { WRAPPER_NO_CONTRACT; - return GetTypeHelper(szFullyQualifiedName, NULL, bThrowIfNotFound, TRUE /* bRequireAssemblyQualifiedName */); + return GetTypeHelper(szFullyQualifiedName, NULL, bThrowIfNotFound, TRUE /* bRequireAssemblyQualifiedName */, NULL /* unsafeAccessorMethod */); } diff --git a/src/coreclr/vm/typeparse.h b/src/coreclr/vm/typeparse.h index cbf96f40389a3f..2031bba94421ca 100644 --- a/src/coreclr/vm/typeparse.h +++ b/src/coreclr/vm/typeparse.h @@ -57,7 +57,7 @@ class TypeName // //-------------------------------------------------------------------------------------------- static TypeHandle GetTypeReferencedByCustomAttribute(LPCUTF8 szTypeName, Assembly *pRequestingAssembly); - static TypeHandle GetTypeReferencedByCustomAttribute(LPCWSTR szTypeName, Assembly *pRequestingAssembly); + static TypeHandle GetTypeReferencedByCustomAttribute(LPCWSTR szTypeName, Assembly *pRequestingAssembly, MethodDesc* unsafeAccessorMethod = NULL); }; #endif diff --git a/src/coreclr/vm/unsafeaccessors.cpp b/src/coreclr/vm/unsafeaccessors.cpp index b740e8b642fb20..0badb1570bee2e 100644 --- a/src/coreclr/vm/unsafeaccessors.cpp +++ b/src/coreclr/vm/unsafeaccessors.cpp @@ -3,6 +3,7 @@ #include "common.h" #include "customattribute.h" +#include "typeparse.h" namespace { @@ -60,12 +61,29 @@ namespace return true; } + struct ParamDetails final + { + ParamDetails() = default; + + ParamDetails(TypeHandle type, CorParamAttr attrs) + : Type{ type } + , Attrs{ attrs } + { } + + ParamDetails(const ParamDetails&) = default; + + TypeHandle Type; + CorParamAttr Attrs; + }; + struct GenerationContext final { GenerationContext(UnsafeAccessorKind kind, MethodDesc* pMD) : Kind{ kind } , Declaration{ pMD } - , DeclarationSig{ pMD } + , DeclarationSig{ pMD->GetSigPointer() } + , DeclarationMetaSig{ pMD } + , TranslatedParams{} , TargetTypeSig{} , TargetType{} , IsTargetStatic{ false } @@ -75,7 +93,11 @@ namespace UnsafeAccessorKind Kind; MethodDesc* Declaration; - MetaSig DeclarationSig; + SigPointer DeclarationSig; // This is the official declaration signature. It may be modified + // to include the UnsafeAccessorTypeAttribute types. + MetaSig DeclarationMetaSig; + NewArrayHolder TranslatedParams; // Redefined types from UnsafeAccessorTypeAttribute usage. + // Return type is at 0 index. Function arguments are 1 to N. SigPointer TargetTypeSig; TypeHandle TargetType; bool IsTargetStatic; @@ -83,8 +105,268 @@ namespace FieldDesc* TargetField; }; + void AppendTypeToSignature( + SigBuilder& sig, + TypeHandle th) + { + STANDARD_VM_CONTRACT; + _ASSERTE(!th.IsNull()); + + // + // Building the signature follows details defined in ECMA-335 - II.23.2.12 + // + + CorElementType elemType = th.GetSignatureCorElementType(); + if (CorIsPrimitiveType(elemType)) + { + sig.AppendElementType(elemType); + return; + } + + if (th.IsGenericVariable()) + { + TypeVarTypeDesc* typeVar = th.AsGenericVariable(); + sig.AppendElementType(typeVar->GetInternalCorElementType()); + sig.AppendData(typeVar->GetIndex()); + return; + } + + if (th.HasTypeParam()) + { + // Append the element type. + sig.AppendElementType(elemType); + TypeHandle typeParam = th.GetTypeParam(); + AppendTypeToSignature(sig, typeParam); + + // Append ArrayShape for MD arrays + // See II.23.2.13 ArrayShape + if (elemType == ELEMENT_TYPE_ARRAY) + { + DWORD rank = th.GetRank(); + sig.AppendData(rank); + + // Roslyn always emits size and lower bounds of 0 in C# signatures. + // In order to match the signature, we also need to append the number + // of lower bounds for each dimension. + // We can emit 0 for each lower bound, since UnsafeAccessors is only + // supported in C# and C# doesn't support lower bounds. + sig.AppendData(0); // Append the number of sizes. + sig.AppendData(rank); + for (DWORD i = 0; i < rank; i++) + sig.AppendData(0); + } + return; + } + + MethodTable* pMT = th.GetMethodTable(); + Instantiation inst = pMT->GetInstantiation(); + + // If we have any generic variables, mark as ELEMENT_TYPE_GENERICINST. + BOOL hasGenericVariables = !inst.IsEmpty(); + if (hasGenericVariables) + { + sig.AppendElementType(ELEMENT_TYPE_GENERICINST); + + // Embed the generic type definition in the signature. + th = TypeHandle{ pMT->GetTypicalMethodTable() }; + _ASSERTE(th.IsGenericTypeDefinition()); + } + + // Append the new type to the signature. + sig.AppendElementType(ELEMENT_TYPE_INTERNAL); + sig.AppendPointer(th.AsPtr()); + + // Append the instantiation types to the signature. + if (hasGenericVariables) + { + _ASSERTE(inst.GetNumArgs() > 0); + sig.AppendData(inst.GetNumArgs()); + for (DWORD i = 0; i < inst.GetNumArgs(); i++) + { + AppendTypeToSignature(sig, inst[i]); + } + } + } + + void UpdateDeclarationSigWithTypes(GenerationContext& cxt) + { + STANDARD_VM_CONTRACT; + _ASSERTE(cxt.Declaration != NULL); + _ASSERTE(cxt.TranslatedParams != NULL); + + // + // Parsing and building the signature follows details defined in ECMA-335 - II.23.2.1 + // + + Module* pSigModule = cxt.Declaration->GetModule(); + + // Read the current signature and copy it, updating the + // types for the parameters that had UnsafeAccessorTypeAttribute. + SigPointer origSig = cxt.Declaration->GetSigPointer(); + + // We're going to be modifying the signature to include the translated types. + SigBuilder newSig; + + uint32_t callConvDecl; + IfFailThrow(origSig.GetCallingConvInfo(&callConvDecl)); + newSig.AppendByte((BYTE)callConvDecl); + + if (callConvDecl & IMAGE_CEE_CS_CALLCONV_GENERIC) + { + uint32_t declGenericCount; + IfFailThrow(origSig.GetData(&declGenericCount)); + newSig.AppendData(declGenericCount); + } + + uint32_t declArgCount; + IfFailThrow(origSig.GetData(&declArgCount)); + newSig.AppendData(declArgCount); + + // Now we can copy over the return type and arguments. + // The format for the return type is the same as the arguments, + // except return parameters can be ELEMENT_TYPE_VOID. + const uint32_t totalParamCount = declArgCount + 1; + for (uint32_t i = 0; i < totalParamCount; ++i) + { + // Copy over any modopts or modreqs. + origSig.CopyModOptsReqs(pSigModule, &newSig); + + TypeHandle newTypeMaybe = cxt.TranslatedParams[i].Type; + if (newTypeMaybe.IsNull()) + { + // Copy the original parameter and continue. + origSig.CopyExactlyOne(pSigModule, &newSig); + continue; + } + + // We have a new type to insert and need to update this + // parameter in the signature. + TypeHandle currHandle = origSig.GetTypeHandleThrowing(pSigModule, NULL); + + // Since byrefs don't support variance, we can't allow returning a + // fully typed byref as a byref to an "opaque" type (for example, "ref string" -> "ref object"). + // This is blocked for type safety reasons. + if (i == 0 // Return type + && currHandle.IsByRef()) + { + ThrowHR(COR_E_NOTSUPPORTED, BFA_INVALID_UNSAFEACCESSORTYPE); + } + + // SigPointer::GetTypeHandleThrowing() is non-consuming, so we need to + // consume the signature to move past the current type. + IfFailThrow(origSig.SkipExactlyOne()); + + bool isValid; + if (newTypeMaybe.IsByRef()) + { + isValid = currHandle.IsByRef() + && currHandle.GetTypeParam() == TypeHandle{ g_pObjectClass }; + } + else if (newTypeMaybe.IsPointer()) + { + isValid = currHandle.IsPointer() + && currHandle.GetTypeParam() == TypeHandle{ CoreLibBinder::GetClass(CLASS__VOID) }; + } + else + { + _ASSERTE(!newTypeMaybe.IsValueType()); + isValid = currHandle == TypeHandle{ g_pObjectClass }; + } + + if (!isValid) + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSORTYPE); + + // Append the new type to the signature. + AppendTypeToSignature(newSig, newTypeMaybe); + } + + // Create a copy of the new signature and store it on the context. + DWORD newSigLen; + void* newSigRaw = newSig.GetSignature(&newSigLen); + + // Allocate the signature memory on the loader allocator associated + // with the declaration method. + void* newSigAlloc = cxt.Declaration->GetLoaderAllocator()->GetLowFrequencyHeap()->AllocMem(S_SIZE_T(newSigLen)); + memcpy(newSigAlloc, newSigRaw, newSigLen); + + // Update the declaration signature with the new signature. + cxt.DeclarationSig = SigPointer{ (PCCOR_SIGNATURE)newSigAlloc, newSigLen }; + SigTypeContext tmpContext{ cxt.Declaration }; + cxt.DeclarationMetaSig = MetaSig{ (PCCOR_SIGNATURE)newSigAlloc, newSigLen, cxt.Declaration->GetModule(), &tmpContext }; + } + + void ProcessUnsafeAccessorTypeAttributes(GenerationContext& cxt) + { + STANDARD_VM_CONTRACT; + + // Acquire attribute name to search for. + const char* typeAttrName = GetWellKnownAttributeName(WellKnownAttribute::UnsafeAccessorTypeAttribute); + + // Determine the max parameter count. +1 for the return value, which is always index 0. + const uint32_t totalParamCount = cxt.DeclarationMetaSig.NumFixedArgs() + 1; + + // Inspect all parameters on the declaration method for UnsafeAccessorTypeAttribute. + uint32_t attrCount = 0; + IMDInternalImport *pInternalImport = cxt.Declaration->GetModule()->GetMDImport(); + HENUMInternalHolder hEnumParams(pInternalImport); + hEnumParams.EnumInit(mdtParamDef, cxt.Declaration->GetMemberDef()); + mdParamDef currParamDef = mdParamDefNil; + while (hEnumParams.EnumNext(&currParamDef)) + { + const void *pData; + ULONG cbData; + HRESULT hr = IfFailThrow(pInternalImport->GetCustomAttributeByName(currParamDef, typeAttrName, &pData, &cbData)); + if (hr != S_OK) + continue; + + // The first time we find an attribute, we allocate the translations array. + if (attrCount == 0) + cxt.TranslatedParams = new ParamDetails[totalParamCount]; + + // Parse the attribute data. + CustomAttributeParser cap(pData, cbData); + IfFailThrow(cap.ValidateProlog()); + LPCUTF8 typeString; + ULONG typeStringLen; + IfFailThrow(cap.GetNonNullString(&typeString, &typeStringLen)); + + StackSString typeStringUtf8{ SString::Utf8, typeString, typeStringLen }; + + // Pass the string in the attribute to similar logic as Type.GetType(String). + // The below API will handle any dependency between the returned type and the + // requesting assembly for the purposes of lifetime tracking of collectible types. + TypeHandle targetType = TypeName::GetTypeReferencedByCustomAttribute( + typeStringUtf8.GetUnicode(), + cxt.Declaration->GetAssembly(), + cxt.Declaration /* unsafeAccessorMethod */); + _ASSERTE(!targetType.IsNull()); + + // Future versions of the runtime may support + // UnsafeAccessorTypeAttribute on value types. + if (targetType.IsValueType()) + ThrowHR(COR_E_NOTSUPPORTED, BFA_INVALID_UNSAFEACCESSORTYPE_VALUETYPE); + + USHORT seq; + DWORD attr; + LPCSTR paramName; + IfFailThrow(pInternalImport->GetParamDefProps(currParamDef, &seq, &attr, ¶mName)); + + if (seq >= totalParamCount) + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSORTYPE); + + // Store the TypeHandle for the loaded type at the sequence number for the parameter. + cxt.TranslatedParams[seq] = { targetType, (CorParamAttr)attr }; + attrCount++; + } + + // Update the declaration signatures if any instances of UnsafeAccessorTypeAttribute were found. + if (attrCount != 0) + UpdateDeclarationSigWithTypes(cxt); + } + TypeHandle ValidateTargetType(TypeHandle targetTypeMaybe, CorElementType targetFromSig) { + STANDARD_VM_CONTRACT; TypeHandle targetType = targetTypeMaybe.IsByRef() ? targetTypeMaybe.GetTypeParam() : targetTypeMaybe; @@ -112,17 +394,17 @@ namespace _ASSERTE(method != NULL); PCCOR_SIGNATURE pSig1; - DWORD cSig1; - cxt.Declaration->GetSig(&pSig1, &cSig1); + uint32_t cSig1; + cxt.DeclarationSig.GetSignature(&pSig1, &cSig1); PCCOR_SIGNATURE pEndSig1 = pSig1 + cSig1; - ModuleBase* pModule1 = cxt.Declaration->GetModule(); + Module* pModule1 = cxt.Declaration->GetModule(); const Substitution* pSubst1 = NULL; PCCOR_SIGNATURE pSig2; DWORD cSig2; method->GetSig(&pSig2, &cSig2); PCCOR_SIGNATURE pEndSig2 = pSig2 + cSig2; - ModuleBase* pModule2 = method->GetModule(); + Module* pModule2 = method->GetModule(); const Substitution* pSubst2 = NULL; // @@ -225,14 +507,10 @@ namespace void VerifyDeclarationSatisfiesTargetConstraints(MethodDesc* declaration, MethodTable* targetType, MethodDesc* targetMethod) { - CONTRACTL - { - STANDARD_VM_CHECK; - PRECONDITION(declaration != NULL); - PRECONDITION(targetType != NULL); - PRECONDITION(targetMethod != NULL); - } - CONTRACTL_END; + STANDARD_VM_CONTRACT; + _ASSERTE(declaration != NULL); + _ASSERTE(targetType != NULL); + _ASSERTE(targetMethod != NULL); // If the target method has no generic parameters there is nothing to verify if (!targetMethod->HasClassOrMethodInstantiation()) @@ -363,8 +641,8 @@ namespace _ASSERTE(field != NULL); PCCOR_SIGNATURE pSig1; - DWORD cSig1; - cxt.Declaration->GetSig(&pSig1, &cSig1); + uint32_t cSig1; + cxt.DeclarationSig.GetSignature(&pSig1, &cSig1); PCCOR_SIGNATURE pEndSig1 = pSig1 + cSig1; ModuleBase* pModule1 = cxt.Declaration->GetModule(); const Substitution* pSubst1 = NULL; @@ -477,6 +755,81 @@ namespace return false; } + void EmitTypeCheck(UINT argId, const ParamDetails& param, ILCodeStream* pDispatchCode, ILCodeStream* pReturnCode) + { + STANDARD_VM_CONTRACT; + _ASSERTE(pDispatchCode != NULL); + _ASSERTE(pReturnCode != NULL); + + if (param.Type.IsNull()) + return; + + TypeHandle th = param.Type; + DWORD localIndex = MAXDWORD; + + // We are going to emit a type check for the UnsafeAccessorTypeAttribute + // scenario. We do this because the declared signature isn't going to enforce + // type safety, so we do it here. + // + // Ensuring type safety is paramount in the UnsafeAccessorTypeAttribute scenario + // so when a byref is involved, we pass a byref to a local variable after the type + // check as opposed to simply forwarding the original byref argument. This does + // mean the byref itself isn't the same as the input byref, but the byref is now + // verifiable. + if (th.IsByRef()) + { + th = th.GetTypeParam(); + LocalDesc typedLocal{ th }; + localIndex = pDispatchCode->NewLocal(typedLocal); + + pDispatchCode->EmitLDIND_REF(); + } + else if (th.IsPointer()) + { + // Pointer types are not verifiable, so we skip the type check. + return; + } + _ASSERTE(!th.IsTypeDesc()); + + int tk; + if (!th.GetMethodTable()->ContainsGenericVariables()) + { + tk = pDispatchCode->GetToken(th); + } + else + { + SigBuilder sigBuilder; + AppendTypeToSignature(sigBuilder, th); + + uint32_t sigLen; + PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); + mdToken typeSig = pDispatchCode->GetSigToken(sig, sigLen); + tk = pDispatchCode->GetToken(th, typeSig); + } + + // Perform the type check. + // If the type is a reference type, unbox.any has the same semantics as castclass. + pDispatchCode->EmitUNBOX_ANY(tk); + + // If we have a local variable, we need to store the result + // in the local variable and load a byref to the local variable as + // the argument to the target method. Finally, we may need to update + // the byref on return. + if (localIndex != MAXDWORD) + { + pDispatchCode->EmitSTLOC(localIndex); + pDispatchCode->EmitLDLOCA(localIndex); + + // Update the byref on return, except if it is marked "in". + if (!IsPdIn(param.Attrs)) + { + pReturnCode->EmitLDARG(argId); + pReturnCode->EmitLDLOC(localIndex); + pReturnCode->EmitSTIND_REF(); + } + } + } + void GenerateAccessor( GenerationContext& cxt, DynamicResolver** resolver, @@ -496,25 +849,32 @@ namespace ILStubLinker sl( cxt.Declaration->GetModule(), - cxt.Declaration->GetSignature(), + cxt.Declaration->GetSignature(), // Must be the MethodDesc's declaration signature, not the DeclarationSig field. &genericContext, cxt.TargetMethod, (ILStubLinkerFlags)ILSTUB_LINKER_FLAG_NONE); - ILCodeStream* pCode = sl.NewCodeStream(ILStubLinker::kDispatch); + ILCodeStream* pDispatchCode = sl.NewCodeStream(ILStubLinker::kTypeCheckDispatch); + ILCodeStream* pReturnCode = sl.NewCodeStream(ILStubLinker::kUpdateByRefsReturn); // Load stub arguments. // When the target is static, the first argument is only // used to look up the target member to access and ignored // during dispatch. UINT beginIndex = cxt.IsTargetStatic ? 1 : 0; - UINT stubArgCount = cxt.DeclarationSig.NumFixedArgs(); + UINT stubArgCount = cxt.DeclarationMetaSig.NumFixedArgs(); for (UINT i = beginIndex; i < stubArgCount; ++i) - pCode->EmitLDARG(i); + { + pDispatchCode->EmitLDARG(i); + + // Perform a typecheck if the type was translated. + if (cxt.TranslatedParams != NULL) + EmitTypeCheck(i, cxt.TranslatedParams[i + 1], pDispatchCode, pReturnCode); // Index is +1 to account for the return value. + } // Provide access to the target member UINT targetArgCount = stubArgCount - beginIndex; - UINT targetRetCount = cxt.DeclarationSig.IsReturnTypeVoid() ? 0 : 1; + UINT targetRetCount = cxt.DeclarationMetaSig.IsReturnTypeVoid() ? 0 : 1; switch (cxt.Kind) { case UnsafeAccessorKind::Constructor: @@ -523,17 +883,17 @@ namespace mdToken target; if (!cxt.TargetType.HasInstantiation()) { - target = pCode->GetToken(cxt.TargetMethod); + target = pDispatchCode->GetToken(cxt.TargetMethod); } else { PCCOR_SIGNATURE sig; uint32_t sigLen; cxt.TargetTypeSig.GetSignature(&sig, &sigLen); - mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); - target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + mdToken targetTypeSigToken = pDispatchCode->GetSigToken(sig, sigLen); + target = pDispatchCode->GetToken(cxt.TargetMethod, targetTypeSigToken); } - pCode->EmitNEWOBJ(target, targetArgCount); + pDispatchCode->EmitNEWOBJ(target, targetArgCount); break; } case UnsafeAccessorKind::Method: @@ -543,7 +903,7 @@ namespace mdToken target; if (!cxt.TargetMethod->HasClassOrMethodInstantiation()) { - target = pCode->GetToken(cxt.TargetMethod); + target = pDispatchCode->GetToken(cxt.TargetMethod); } else { @@ -563,18 +923,18 @@ namespace sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); sigBuilder.AppendData(i); } - sigLen; + sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); - methodSpecSigToken = pCode->GetSigToken(sig, sigLen); + methodSpecSigToken = pDispatchCode->GetSigToken(sig, sigLen); } cxt.TargetTypeSig.GetSignature(&sig, &sigLen); - mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + mdToken targetTypeSigToken = pDispatchCode->GetSigToken(sig, sigLen); if (methodSpecSigToken == mdTokenNil) { // Create a MemberRef - target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + target = pDispatchCode->GetToken(cxt.TargetMethod, targetTypeSigToken); _ASSERTE(TypeFromToken(target) == mdtMemberRef); } else @@ -584,18 +944,18 @@ namespace MethodDesc* instantiatedTarget = MethodDesc::FindOrCreateAssociatedMethodDesc(cxt.TargetMethod, cxt.TargetType.GetMethodTable(), FALSE, methodInst, TRUE); // Create a MethodSpec - target = pCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); + target = pDispatchCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); _ASSERTE(TypeFromToken(target) == mdtMethodSpec); } } if (cxt.Kind == UnsafeAccessorKind::StaticMethod) { - pCode->EmitCALL(target, targetArgCount, targetRetCount); + pDispatchCode->EmitCALL(target, targetArgCount, targetRetCount); } else { - pCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); + pDispatchCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); } break; } @@ -605,15 +965,15 @@ namespace mdToken target; if (!cxt.TargetType.HasInstantiation()) { - target = pCode->GetToken(cxt.TargetField); + target = pDispatchCode->GetToken(cxt.TargetField); } else { // See the static field case for why this can be mdTokenNil. mdToken targetTypeSigToken = mdTokenNil; - target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); + target = pDispatchCode->GetToken(cxt.TargetField, targetTypeSigToken); } - pCode->EmitLDFLDA(target); + pDispatchCode->EmitLDFLDA(target); break; } case UnsafeAccessorKind::StaticField: @@ -621,7 +981,7 @@ namespace mdToken target; if (!cxt.TargetType.HasInstantiation()) { - target = pCode->GetToken(cxt.TargetField); + target = pDispatchCode->GetToken(cxt.TargetField); } else { @@ -634,17 +994,17 @@ namespace PCCOR_SIGNATURE sig; uint32_t sigLen; cxt.TargetTypeSig.GetSignature(&sig, &sigLen); - mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); - target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); + mdToken targetTypeSigToken = pDispatchCode->GetSigToken(sig, sigLen); + target = pDispatchCode->GetToken(cxt.TargetField, targetTypeSigToken); } - pCode->EmitLDSFLDA(target); + pDispatchCode->EmitLDSFLDA(target); break; default: _ASSERTE(!"Unknown UnsafeAccessorKind"); } // Return from the generated stub - pCode->EmitRET(); + pReturnCode->EmitRET(); // Generate all IL associated data for JIT { @@ -701,6 +1061,9 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET GenerationContext context{ kind, this }; + // Parse the signature and check for instances of UnsafeAccessorTypeAttribute. + ProcessUnsafeAccessorTypeAttributes(context); + // Parse the signature to determine the type to use: // * Constructor access - examine the return type // * Instance member access - examine type of first parameter @@ -709,17 +1072,17 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET CorElementType retCorType; TypeHandle firstArgType; CorElementType firstArgCorType = ELEMENT_TYPE_END; - retCorType = context.DeclarationSig.GetReturnType(); - retType = context.DeclarationSig.GetRetTypeHandleThrowing(); - UINT argCount = context.DeclarationSig.NumFixedArgs(); + retCorType = context.DeclarationMetaSig.GetReturnType(); + retType = context.DeclarationMetaSig.GetRetTypeHandleThrowing(); + UINT argCount = context.DeclarationMetaSig.NumFixedArgs(); if (argCount > 0) { - context.DeclarationSig.NextArg(); + context.DeclarationMetaSig.NextArg(); // Get the target type signature and resolve to a type handle. - context.TargetTypeSig = context.DeclarationSig.GetArgProps(); + context.TargetTypeSig = context.DeclarationMetaSig.GetArgProps(); (void)context.TargetTypeSig.PeekElemType(&firstArgCorType); - firstArgType = context.DeclarationSig.GetLastTypeHandleThrowing(); + firstArgType = context.DeclarationMetaSig.GetLastTypeHandleThrowing(); } // Using the kind type, perform the following: @@ -732,13 +1095,13 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // we don't know the type to construct. // Types should not be parameterized (that is, byref). // The name is defined by the runtime and should be empty. - if (context.DeclarationSig.IsReturnTypeVoid() || retType.IsByRef() || !name.IsEmpty()) + if (context.DeclarationMetaSig.IsReturnTypeVoid() || retType.IsByRef() || !name.IsEmpty()) { ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } // Get the target type signature from the return type. - context.TargetTypeSig = context.DeclarationSig.GetReturnProps(); + context.TargetTypeSig = context.DeclarationMetaSig.GetReturnProps(); context.TargetType = ValidateTargetType(retType, retCorType); if (!TrySetTargetMethod(context, ".ctor")) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), ".ctor"); @@ -768,7 +1131,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET case UnsafeAccessorKind::Field: case UnsafeAccessorKind::StaticField: // Field access requires a single argument for target type and a return type. - if (argCount != 1 || firstArgType.IsNull() || context.DeclarationSig.IsReturnTypeVoid()) + if (argCount != 1 || firstArgType.IsNull() || context.DeclarationMetaSig.IsReturnTypeVoid()) { ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } @@ -798,3 +1161,25 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET GenerateAccessor(context, resolver, methodILDecoder); return true; } + +extern "C" void* QCALLTYPE UnsafeAccessors_ResolveGenericParamToTypeHandle(MethodDesc* unsafeAccessorMethod, BOOL isMethodParam, DWORD paramIndex) +{ + QCALL_CONTRACT; + _ASSERTE(unsafeAccessorMethod != NULL); + + TypeHandle ret; + + BEGIN_QCALL; + + MethodDesc* typicalMD = unsafeAccessorMethod->LoadTypicalMethodDefinition(); + Instantiation genericParams = isMethodParam + ? typicalMD->GetMethodInstantiation() + : typicalMD->GetClassInstantiation(); + + if (0 <= paramIndex && paramIndex < genericParams.GetNumArgs()) + ret = genericParams[paramIndex]; + + END_QCALL; + + return ret.AsPtr(); +} diff --git a/src/coreclr/vm/wellknownattributes.h b/src/coreclr/vm/wellknownattributes.h index 0659f736beee54..3f7496ac1b68a9 100644 --- a/src/coreclr/vm/wellknownattributes.h +++ b/src/coreclr/vm/wellknownattributes.h @@ -36,6 +36,7 @@ enum class WellKnownAttribute : DWORD ObjectiveCTrackedTypeAttribute, InlineArrayAttribute, UnsafeAccessorAttribute, + UnsafeAccessorTypeAttribute, CountOfWellKnownAttributes }; @@ -137,6 +138,9 @@ inline const char *GetWellKnownAttributeName(WellKnownAttribute attribute) case WellKnownAttribute::UnsafeAccessorAttribute: ret = "System.Runtime.CompilerServices.UnsafeAccessorAttribute"; break; + case WellKnownAttribute::UnsafeAccessorTypeAttribute: + ret = "System.Runtime.CompilerServices.UnsafeAccessorTypeAttribute"; + break; case WellKnownAttribute::CountOfWellKnownAttributes: default: ret = nullptr; diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index d3f62e4c34b7d9..d55579d4221e22 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -908,6 +908,7 @@ + diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/UnsafeAccessorTypeAttribute.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/UnsafeAccessorTypeAttribute.cs new file mode 100644 index 00000000000000..41a4f47c60bdab --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/UnsafeAccessorTypeAttribute.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.CompilerServices +{ + /// + /// Provides access to an inaccessible type. + /// + [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.ReturnValue, AllowMultiple = false, Inherited = false)] + public sealed class UnsafeAccessorTypeAttribute : Attribute + { + /// + /// Instantiates an providing access to a type supplied by . + /// + /// A fully qualified or partially qualified type name. + /// + /// is expected to follow the same rules as if it were being + /// passed to . When unbound generics are involved they + /// should follow the IL syntax of referencing a type or method generic variables using + /// the syntax of !N or !!N respectively, where N is the zero-based index of the + /// generic parameters. The generic rules defined for + /// apply to this attribute as well, meaning the arity and type of generic parameter must match + /// the target type. + /// + /// This attribute only has behavior on parameters or return values of methods marked with . + /// + /// This attribute should only be applied to parameters or return types of methods that are + /// typed as follows: + /// + ///
    + ///
  • References should be typed as object.
  • + ///
  • Byref arguments should be typed with in, ref, or out to object.
  • + ///
  • Unmanaged pointers should be typed as void*.
  • + ///
  • Byref arguments to reference types or arrays should be typed with in, ref, or out to object.
  • + ///
  • Byref arguments to unmanaged pointer types should be typed with in, ref, or out to void*.
  • + ///
+ /// + /// Value types are not supported. + /// + /// Due to lack of variance for byrefs, returns involving byrefs are not supported. This + /// specifically means that accessors for fields of inaccessible types are not supported. + ///
+ public UnsafeAccessorTypeAttribute(string typeName) + { + TypeName = typeName; + } + + /// + /// Fully qualified or partially qualified type name to target. + /// + public string TypeName { get; } + } +} diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index 84742f4e9a03fe..00ff83a03c8995 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -14026,6 +14026,12 @@ public enum UnsafeAccessorKind Field = 3, StaticField = 4, } + [System.AttributeUsageAttribute(System.AttributeTargets.Parameter | System.AttributeTargets.ReturnValue, AllowMultiple=false, Inherited=false)] + public sealed partial class UnsafeAccessorTypeAttribute : System.Attribute + { + public UnsafeAccessorTypeAttribute(string typeName) { } + public string TypeName { get { throw null; } } + } [System.AttributeUsageAttribute(System.AttributeTargets.Struct)] public sealed partial class UnsafeValueTypeAttribute : System.Attribute { diff --git a/src/mono/mono/metadata/custom-attrs.c b/src/mono/mono/metadata/custom-attrs.c index e0e485dd3e4a44..8bb0fb414a235d 100644 --- a/src/mono/mono/metadata/custom-attrs.c +++ b/src/mono/mono/metadata/custom-attrs.c @@ -51,6 +51,7 @@ static GENERATE_GET_CLASS_WITH_CACHE (custom_attribute_typed_argument, "System.R static GENERATE_GET_CLASS_WITH_CACHE (custom_attribute_named_argument, "System.Reflection", "CustomAttributeNamedArgument"); static GENERATE_TRY_GET_CLASS_WITH_CACHE (customattribute_data, "System.Reflection", "RuntimeCustomAttributeData"); static GENERATE_TRY_GET_CLASS_WITH_CACHE (unsafe_accessor_attribute, "System.Runtime.CompilerServices", "UnsafeAccessorAttribute"); +static GENERATE_TRY_GET_CLASS_WITH_CACHE (unsafe_accessor_type_attribute, "System.Runtime.CompilerServices", "UnsafeAccessorTypeAttribute"); static MonoCustomAttrInfo* mono_custom_attrs_from_builders_handle (MonoImage *alloc_img, MonoImage *image, MonoArrayHandle cattrs, gboolean respect_cattr_visibility); @@ -2099,6 +2100,58 @@ mono_method_get_unsafe_accessor_attr_data (MonoMethod *method, int *accessor_kin return TRUE; } +gboolean +mono_method_param_get_unsafe_accessor_type_attr_data (MonoMethod *method, int param_seq, char **type_name, MonoError *error) +{ + MonoCustomAttrInfo *cinfo = mono_custom_attrs_from_param_checked (method, param_seq, error); + + if (!cinfo || !is_ok (error)) { + mono_error_cleanup (error); + return FALSE; + } + + MonoClass *unsafeAccessorType = mono_class_try_get_unsafe_accessor_type_attribute_class (); + MonoCustomAttrEntry *attr = NULL; + + for (int idx = 0; idx < cinfo->num_attrs; ++idx) { + MonoClass *ctor_class = cinfo->attrs [idx].ctor->klass; + if (ctor_class == unsafeAccessorType) { + attr = &cinfo->attrs [idx]; + break; + } + } + + if (!attr){ + if (!cinfo->cached) + mono_custom_attrs_free(cinfo); + return FALSE; + } + + MonoDecodeCustomAttr *decoded_args = mono_reflection_create_custom_attr_data_args_noalloc (m_class_get_image (attr->ctor->klass), attr->ctor, attr->data, attr->data_size, error); + + if (!is_ok (error)) { + mono_error_cleanup (error); + mono_reflection_free_custom_attr_data_args_noalloc (decoded_args); + if (!cinfo->cached) + mono_custom_attrs_free(cinfo); + return FALSE; + } + + g_assert (decoded_args->typed_args_num == 1); + const char *ptr = (const char*)decoded_args->typed_args [0]->value.primitive; + uint32_t len = mono_metadata_decode_value (ptr, &ptr); + char *name = m_method_alloc0 (method, len + 1); + memcpy (name, ptr, len); + name[len] = 0; + *type_name = (char*)name; + + mono_reflection_free_custom_attr_data_args_noalloc (decoded_args); + if (!cinfo->cached) + mono_custom_attrs_free(cinfo); + + return TRUE; +} + /** * mono_custom_attrs_from_class: */ diff --git a/src/mono/mono/metadata/marshal.c b/src/mono/mono/metadata/marshal.c index e68fba963e70e6..e6d7f80d8e7365 100644 --- a/src/mono/mono/metadata/marshal.c +++ b/src/mono/mono/metadata/marshal.c @@ -5241,6 +5241,62 @@ mono_marshal_get_array_accessor_wrapper (MonoMethod *method) return res; } +static void process_unsafe_accessor_type (MonoUnsafeAccessorKind kind, MonoMethod *accessor_method, MonoMethodSignature *tgt_sig) +{ + g_assert (accessor_method); + g_assert (tgt_sig); + + char *type_name; + MonoAssemblyLoadContext *alc = mono_alc_get_ambient (); + MonoImage *image = m_class_get_image (accessor_method->klass); + MonoType *type; + + // Iterate through all parameters. Zero is the return value, the arguments are 1-based. + for (guint16 seq = 0; seq <= tgt_sig->param_count; ++seq) { + + ERROR_DECL (error); + + type_name = NULL; + if (!mono_method_param_get_unsafe_accessor_type_attr_data (accessor_method, seq, &type_name, error)) + continue; + mono_error_assert_ok (error); + g_assert (type_name); + + type = mono_reflection_type_from_name_checked (type_name, alc, image, error); + if (!type) + continue; + mono_error_assert_ok (error); + + // Future versions of the runtime may support + // UnsafeAccessorTypeAttribute on value types. + g_assert (type->type != MONO_TYPE_VALUETYPE); + + if (seq == 0 && m_type_is_byref (tgt_sig->ret)) { + // [FIXME] UnsafeAccessorType is not supported on return that are byref, type safety issue. + return; + } + + // Check the target signature for attribute, byref and cmods state. This information + // is not contained with in the type name itself, so we may need to check the target + // signature and retain it on the new type. + MonoType *current_param = (seq == 0) ? tgt_sig->ret : tgt_sig->params [seq - 1]; + if (current_param->attrs != 0 || m_type_is_byref(current_param) || current_param->has_cmods) { + type = mono_metadata_type_dup_with_cmods(image, type, current_param); + type->byref__ = current_param->byref__; + type->attrs = current_param->attrs; + } + + // Update the target signature with the new type + if (seq == 0) { + // The return value + tgt_sig->ret = type; + } else { + // The arguments + tgt_sig->params [seq - 1] = type; + } + } +} + /* * mono_marshal_get_unsafe_accessor_wrapper: * @@ -5389,6 +5445,9 @@ mono_marshal_get_unsafe_accessor_wrapper (MonoMethod *accessor_method, MonoUnsaf } sig->pinvoke = 0; + // Parse the signature and check for instances of UnsafeAccessorTypeAttribute. + process_unsafe_accessor_type (kind, accessor_method, sig); + get_marshal_cb ()->mb_skip_visibility (mb); if (generic_wrapper || is_inflated) { diff --git a/src/mono/mono/metadata/reflection-internals.h b/src/mono/mono/metadata/reflection-internals.h index 01e3d136e0aab7..160850ae24ed5a 100644 --- a/src/mono/mono/metadata/reflection-internals.h +++ b/src/mono/mono/metadata/reflection-internals.h @@ -67,6 +67,8 @@ MONO_COMPONENT_API MonoCustomAttrInfo* mono_custom_attrs_from_method_checked (MonoMethod *method, MonoError *error); gboolean mono_method_get_unsafe_accessor_attr_data (MonoMethod *method, int *accessor_kind, char **member_name, MonoError *error); +gboolean +mono_method_param_get_unsafe_accessor_type_attr_data (MonoMethod *method, int param_seq, char **type_name, MonoError *error); MONO_COMPONENT_API MonoCustomAttrInfo* mono_custom_attrs_from_class_checked (MonoClass *klass, MonoError *error); MONO_COMPONENT_API MonoCustomAttrInfo* diff --git a/src/tests/Common/CoreCLRTestLibrary/AssertExtensions.cs b/src/tests/Common/CoreCLRTestLibrary/AssertExtensions.cs index 09f97637f90d66..4963ed0e1ededf 100644 --- a/src/tests/Common/CoreCLRTestLibrary/AssertExtensions.cs +++ b/src/tests/Common/CoreCLRTestLibrary/AssertExtensions.cs @@ -5,8 +5,11 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; +using Xunit.Sdk; + namespace Xunit { public static class AssertExtensions @@ -139,6 +142,44 @@ public static TInner ThrowsWithInnerException(Action action) return (TInner)outerException.InnerException; } + public static void ThrowsAny(Type firstExceptionType, Type secondExceptionType, Action action) + { + ThrowsAnyInternal(action, firstExceptionType, secondExceptionType); + } + + private static void ThrowsAnyInternal(Action action, params Type[] exceptionTypes) + { + try + { + action(); + } + catch (Exception e) + { + Type exceptionType = e.GetType(); + if (exceptionTypes.Any(t => t.Equals(exceptionType))) + return; + + throw new XunitException($"Expected one of: ({string.Join(", ", exceptionTypes)}) -> Actual: ({exceptionType}): {e}"); // Log message and callstack to help diagnosis + } + + throw new XunitException($"Expected one of: ({string.Join(", ", exceptionTypes)}) -> Actual: No exception thrown"); + } + + public static void ThrowsAny(Action action) + where TFirstExceptionType : Exception + where TSecondExceptionType : Exception + { + ThrowsAnyInternal(action, typeof(TFirstExceptionType), typeof(TSecondExceptionType)); + } + + public static void ThrowsAny(Action action) + where TFirstExceptionType : Exception + where TSecondExceptionType : Exception + where TThirdExceptionType : Exception + { + ThrowsAnyInternal(action, typeof(TFirstExceptionType), typeof(TSecondExceptionType), typeof(TThirdExceptionType)); + } + /// /// Tests whether the two lists are the same length and contain the same objects (using Object.Equals()) in the same order and diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.cs new file mode 100644 index 00000000000000..eb3543cae330b2 --- /dev/null +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; + +using Xunit; + +namespace PrivateLib +{ + class Class1 + { + static int StaticField; + int InstanceField = 456; + + static Class1() + { + StaticField = 123; + } + + Class1() { } + + static Class1 GetClass() => new Class1(); + + Class2 GetClass2() => new Class2(); + } + + class Class2 { } + + class GenericClass + { + List M1() => new List(); + + List M2() => new List(); + + List M3() => new List(); + + List M4() => new List(); + + bool M5(List a, List b, List c, List d) where W : T => true; + + Type M6(Dictionary a) => typeof(X); + + Type M7(Dictionary a) => typeof(Y); + + Z M8() where Z : class, new() => new Z(); + + bool M9(List> a, List b, List> c) => true; + } +} \ No newline at end of file diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.csproj b/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.csproj new file mode 100644 index 00000000000000..418f0402fb3411 --- /dev/null +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/PrivateLib.csproj @@ -0,0 +1,9 @@ + + + library + + + + + + diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index 311550810224c8..ee112b3dae2c66 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -421,10 +421,15 @@ public static void Verify_Generic_MethodConstraintEnforcement() Assert.Equal($"{typeof(ClassWithI1I2)}|{typeof(I1)}", CallMethod(new MethodWithConstraints())); Assert.Equal($"{typeof(ClassWithI1I2)}|{typeof(I1)}", CallStaticMethod(null)); - Assert.Throws(() => CallMethod_NoConstraints(new MethodWithConstraints())); - Assert.Throws(() => CallMethod_MissingConstraint(new MethodWithConstraints())); - Assert.Throws(() => CallStaticMethod_NoConstraints(null)); - Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); + + // Skip validating error cases on Mono runtime + if (TestLibrary.Utilities.IsNotMonoRuntime) + { + Assert.Throws(() => CallMethod_NoConstraints(new MethodWithConstraints())); + Assert.Throws(() => CallMethod_MissingConstraint(new MethodWithConstraints())); + Assert.Throws(() => CallStaticMethod_NoConstraints(null)); + Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); + } [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] extern static string CallMethod(MethodWithConstraints c) where V : W, I2; @@ -478,8 +483,12 @@ public static void Verify_Generic_ClassConstraintEnforcement() Assert.Equal($"{typeof(ClassWithI1I2)}|{typeof(I1)}|{typeof(ClassWithI1)}", AccessorsWithConstraints.CallMethod(new ClassWithConstraints())); Assert.Equal($"{typeof(ClassWithI1I2)}|{typeof(I1)}|{typeof(ClassWithI1)}", AccessorsWithConstraints.CallStaticMethod(null)); - Assert.Throws(() => AccessorsWithConstraints.CallMethod_MissingMethodConstraint(new ClassWithConstraints())); - Assert.Throws(() => AccessorsWithConstraints.CallStaticMethod_MissingMethodConstraint(null)); + // Skip validating error cases on Mono runtime + if (TestLibrary.Utilities.IsNotMonoRuntime) + { + Assert.Throws(() => AccessorsWithConstraints.CallMethod_MissingMethodConstraint(new ClassWithConstraints())); + Assert.Throws(() => AccessorsWithConstraints.CallStaticMethod_MissingMethodConstraint(null)); + } } class Invalid diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Types.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Types.cs new file mode 100644 index 00000000000000..7226cb051c0d60 --- /dev/null +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Types.cs @@ -0,0 +1,577 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Reflection.Metadata; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using Xunit; + +public static class StaticClass +{ + static StaticClass() + { + StaticField = 123; + } + + public static int StaticField; + public static int StaticMethod() => StaticField; +} + +class C1 { } +class C2 { } + +struct S1 { } + +class TargetClass +{ + private C2 _f1; + private readonly C2 _f2; + private TargetClass(C2 c2) + { + _f1 = c2; + _f2 = c2; + } + private TargetClass(ref C2 c2) + { + _f1 = c2; + _f2 = c2; + } + private C2 M_C1(C1 a) => _f1; + private C2 M_RC1(ref C1 a) => _f1; + private C2 M_RROC1(ref readonly C1 a) => _f1; + private ref C2 M_C1_RC2(C1 a) => ref _f1; + + private void M_ByRefs(C1 c, in C1 ic, ref C1 rc, out C1 oc) + { + Assert.Null(ic); // See caller + rc = c; + oc = c; + } + + private Type M_C1Array(C1[] c) => typeof(C1[]); + private Type M_C1Array(C1[,] c) => typeof(C1[,]); + private Type M_C1Array(C1[,,] c) => typeof(C1[,,]); + private Type M_C1Array(C1[][] c) => typeof(C1[][]); + private Type M_C1Array(C1[][][] c) => typeof(C1[][][]); + private Type M_C1Array(C1[][,] c) => typeof(C1[][,]); + + private Type M_S1Array(S1[] c) => typeof(S1[]); + private Type M_S1Array(S1[,] c) => typeof(S1[,]); + private Type M_S1Array(S1[,,] c) => typeof(S1[,,]); + private Type M_S1Array(S1[][] c) => typeof(S1[][]); + private Type M_S1Array(S1[][][] c) => typeof(S1[][][]); + private Type M_S1Array(S1[][,] c) => typeof(S1[][,]); + +#pragma warning disable CS8500 + private unsafe void M_C1Pointer(C1* c) { } +#pragma warning restore CS8500 + + private class InnerClass + { + private InnerClass() { } + private InnerClass(string _) { } + } +} + +public static unsafe class UnsafeAccessorsTestsTypes +{ + // Skip validating error cases on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_InvalidArgument() + { + Console.WriteLine($"Running {nameof(Verify_Type_InvalidArgument)}"); + + AssertExtensions.ThrowsAny(() => CallStaticMethod1(null)); + Assert.Throws(() => CallStaticMethod2(null)); + Assert.Throws(() => CallStaticMethod3(null)); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "MethodName")] + extern static ref int CallStaticMethod1([UnsafeAccessorType(null!)] object a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "MethodName")] + extern static ref int CallStaticMethod2([UnsafeAccessorType("_DoesNotExist_")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "MethodName")] + extern static ref int CallStaticMethod3([UnsafeAccessorType("S1")] object a); + } + + [Fact] + public static void Verify_Type_StaticClass() + { + Console.WriteLine($"Running {nameof(Verify_Type_StaticClass)}"); + + var f = GetStaticClassField(null); + Assert.Equal(StaticClass.StaticField, f); + Assert.Equal(StaticClass.StaticField, CallStaticClassMethod(null)); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name = "StaticField")] + extern static ref int GetStaticClassField([UnsafeAccessorType("StaticClass")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "StaticMethod")] + extern static int CallStaticClassMethod([UnsafeAccessorType("StaticClass")] object a); + } + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + extern static TargetClass CreateTargetClass([UnsafeAccessorType("C2")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + extern static TargetClass CreateTargetClass([UnsafeAccessorType("C2&")] ref object a); + + // Skip validating error cases on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_TypeCheck() + { + Console.WriteLine($"Running {nameof(Verify_Type_TypeCheck)}"); + + Assert.Throws(() => CreateTargetClass(new C1())); + Assert.Throws(() => + { + object c1 = new C1(); + CreateTargetClass(ref c1); + }); + } + + [Fact] + public static void Verify_Type_CallInstanceMethods() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallInstanceMethods)}"); + + C2 c2 = new(); + object arg = c2; + TargetClass tgt = CreateTargetClass(arg); + + arg = new C1(); + Assert.Equal(c2, CallM_C1(tgt, arg)); + Assert.Equal(c2, CallM_RC1(tgt, ref arg)); + Assert.Equal(c2, CallM_RROC1(tgt, ref arg)); + AssertExtensions.ThrowsAny(()=> CallM_C1_RC2(tgt, arg)); + + object ic = null; + object rc = null; + object oc = null; + CallM_ByRefs(tgt, arg, in ic, ref rc, out oc); + Assert.Null(ic); + Assert.Equal(arg, rc); + Assert.Equal(arg, oc); + + Assert.Equal(typeof(C1[]), CallM_C1Array(tgt, Array.Empty())); + Assert.Equal(typeof(C1[,]), CallM_C1MDArray2(tgt, new C1[1,1])); + Assert.Equal(typeof(C1[,,]), CallM_C1MDArray3(tgt, new C1[1,1,1])); + Assert.Equal(typeof(C1[][]), CallM_C1JaggedArray2(tgt, new C1[0][])); + Assert.Equal(typeof(C1[][][]), CallM_C1JaggedArray3(tgt, new C1[0][][])); + Assert.Equal(typeof(C1[][,]), CallM_C1MixedArrays(tgt, new C1[0][,])); + + Assert.Equal(typeof(S1[]), CallM_S1Array(tgt, Array.Empty())); + Assert.Equal(typeof(S1[,]), CallM_S1MDArray2(tgt, new S1[1,1])); + Assert.Equal(typeof(S1[,,]), CallM_S1MDArray3(tgt, new S1[1,1,1])); + Assert.Equal(typeof(S1[][]), CallM_S1JaggedArray2(tgt, new S1[0][])); + Assert.Equal(typeof(S1[][][]), CallM_S1JaggedArray3(tgt, new S1[0][][])); + Assert.Equal(typeof(S1[][,]), CallM_S1MixedArrays(tgt, new S1[0][,])); + + CallM_C1Pointer(tgt, null); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1")] + [return: UnsafeAccessorType("C2")] + extern static object CallM_C1(TargetClass tgt, [UnsafeAccessorType("C1")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_RC1")] + [return: UnsafeAccessorType("C2")] + extern static object CallM_RC1(TargetClass tgt, [UnsafeAccessorType("C1&")] ref object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_RROC1")] + [return: UnsafeAccessorType("C2")] + extern static object CallM_RROC1(TargetClass tgt, [UnsafeAccessorType("C1&")] ref readonly object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1_RC2")] + [return: UnsafeAccessorType("C2&")] + extern static ref object CallM_C1_RC2(TargetClass tgt, [UnsafeAccessorType("C1")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_ByRefs")] + extern static void CallM_ByRefs(TargetClass tgt, + [UnsafeAccessorType("C1")] object c, + [UnsafeAccessorType("C1&")] in object ic, + [UnsafeAccessorType("C1&")] ref object rc, + [UnsafeAccessorType("C1&")] out object oc); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1Array(TargetClass tgt, [UnsafeAccessorType("C1[]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1MDArray2(TargetClass tgt, [UnsafeAccessorType("C1[,]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1MDArray3(TargetClass tgt, [UnsafeAccessorType("C1[,,]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1JaggedArray2(TargetClass tgt, [UnsafeAccessorType("C1[][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1JaggedArray3(TargetClass tgt, [UnsafeAccessorType("C1[][][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Array")] + extern static Type CallM_C1MixedArrays(TargetClass tgt, [UnsafeAccessorType("C1[,][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1Array(TargetClass tgt, [UnsafeAccessorType("S1[]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1MDArray2(TargetClass tgt, [UnsafeAccessorType("S1[,]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1MDArray3(TargetClass tgt, [UnsafeAccessorType("S1[,,]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1JaggedArray2(TargetClass tgt, [UnsafeAccessorType("S1[][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1JaggedArray3(TargetClass tgt, [UnsafeAccessorType("S1[][][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_S1Array")] + extern static Type CallM_S1MixedArrays(TargetClass tgt, [UnsafeAccessorType("S1[,][]")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M_C1Pointer")] + extern static void CallM_C1Pointer(TargetClass tgt, [UnsafeAccessorType("C1*")] void* a); + } + + [Fact] + public static void Verify_Type_GetInstanceFields_NotSupported() + { + Console.WriteLine($"Running {nameof(Verify_Type_GetInstanceFields_NotSupported)}"); + + C2 c2 = new(); + TargetClass tgt = CreateTargetClass(c2); + + // The following calls should throw NotSupportedException. + // Mono throws MissingFieldException since throwing NotSupportedException is difficult to implement. + AssertExtensions.ThrowsAny(()=> CallField1(tgt)); + AssertExtensions.ThrowsAny(()=> CallField2(tgt)); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_f1")] + [return: UnsafeAccessorType("C2")] + extern static ref object CallField1(TargetClass tgt); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_f2")] + [return: UnsafeAccessorType("C2")] + extern static ref readonly object CallField2(TargetClass tgt); + } + + [Fact] + public static void Verify_Type_CallInnerCtorClass() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallInnerCtorClass)}"); + + object obj; + + obj = CreateInner(); + Assert.Equal("InnerClass", obj.GetType().Name); + + obj = CreateInnerString(string.Empty); + Assert.Equal("InnerClass", obj.GetType().Name); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("TargetClass+InnerClass")] + extern static object CreateInner(); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("TargetClass+InnerClass")] + extern static object CreateInnerString(string a); + } + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "GetClass")] + [return: UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] + extern static object CallGetClass([UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] object a); + + [Fact] + public static void Verify_Type_CallPrivateLibMethods() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallPrivateLibMethods)}"); + + { + object class1 = CreateClass(); + Assert.Equal("PrivateLib.Class1", class1.GetType().FullName); + } + + { + object class1 = CallGetClass(null); + Assert.Equal("PrivateLib.Class1", class1.GetType().FullName); + object listClass2 = CallGetClass2(class1); + Assert.Equal("PrivateLib.Class2", listClass2.GetType().FullName); + } + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] + extern static object CreateClass(); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "GetClass2")] + [return: UnsafeAccessorType("PrivateLib.Class2, PrivateLib")] + extern static object CallGetClass2([UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] object a); + } + + [Fact] + public static void Verify_Type_GetPrivateLibFields() + { + Console.WriteLine($"Running {nameof(Verify_Type_GetPrivateLibFields)}"); + + object class1 = CallGetClass(null); + Assert.Equal("PrivateLib.Class1", class1.GetType().FullName); + + Assert.Equal(123, GetStaticField(null)); + Assert.Equal(456, GetInstanceField(class1)); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name = "StaticField")] + extern static ref int GetStaticField([UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "InstanceField")] + extern static ref int GetInstanceField([UnsafeAccessorType("PrivateLib.Class1, PrivateLib")] object a); + } + + partial class Accessors + { + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.GenericClass`1[[!-0]], PrivateLib")] + public extern static object CreateGenericClass_InvalidGenericIndex1(); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.GenericClass`1[[!+0]], PrivateLib")] + public extern static object CreateGenericClass_InvalidGenericIndex2(); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.GenericClass`1[[!-1]], PrivateLib")] + public extern static object CreateGenericClass_InvalidGenericIndex3(); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.GenericClass`1[[!1]], PrivateLib")] + public extern static object CreateGenericClass_InvalidGenericIndex4(); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + [return: UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] + public extern static object CreateGenericClass(); + + // Class type variables + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M1")] + [return: UnsafeAccessorType("System.Collections.Generic.List`1[[!0]]")] + public extern static object CallGenericClassM1([UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object a); + + // Method type variables + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M2")] + [return: UnsafeAccessorType("System.Collections.Generic.List`1[[!!0]]")] + public extern static object CallGenericClassM2([UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object a); + + // Bound type variables + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M3")] + public extern static List CallGenericClassM3([UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M4")] + [return: UnsafeAccessorType("System.Collections.Generic.List`1[[PrivateLib.Class2, PrivateLib]]")] + public extern static object CallGenericClassM4([UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M4")] + [return: UnsafeAccessorType("System.Collections.Generic.List`1[[System.Object]]")] + public extern static object CallGenericClassM4_InvalidReturn([UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M5")] + public extern static bool CallGenericClassM5( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt, + [UnsafeAccessorType("System.Collections.Generic.List`1[[!0]]")] + object a, + [UnsafeAccessorType("System.Collections.Generic.List`1[[!!0]]")] + object b, + List c, + [UnsafeAccessorType("System.Collections.Generic.List`1[[PrivateLib.Class2, PrivateLib]]")] + object d) where W : T; + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M5")] + public extern static bool CallGenericClassM5_NoConstraint( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt, + [UnsafeAccessorType("System.Collections.Generic.List`1[[!0]]")] + object a, + [UnsafeAccessorType("System.Collections.Generic.List`1[[!!0]]")] + object b, + List c, + [UnsafeAccessorType("System.Collections.Generic.List`1[[PrivateLib.Class2, PrivateLib]]")] + object d); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M6")] + public extern static Type CallGenericClassM6( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt, + [UnsafeAccessorType("System.Collections.Generic.Dictionary`2[[!!0],[System.Int32]]")] + object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M7")] + public extern static Type CallGenericClassM7( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt, + [UnsafeAccessorType("System.Collections.Generic.Dictionary`2[[System.Int32],[!!0]]")] + object a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M8")] + [return: UnsafeAccessorType("!!0")] + public extern static object CallGenericClassM8( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt) where Z : class, new(); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M9")] + public extern static bool CallGenericClassM9( + [UnsafeAccessorType("PrivateLib.GenericClass`1[[!0]], PrivateLib")] object tgt, + [UnsafeAccessorType("System.Collections.Generic.List`1[[System.Collections.Generic.List`1[[!!0]]]]")] + object a, + [UnsafeAccessorType("System.Collections.Generic.List`1[[!!0[,,]]]")] + object b, + [UnsafeAccessorType("System.Collections.Generic.List`1[[System.Collections.Generic.List`1[[!0[][,]]]]]")] + object c); + } + + // Skip validating error cases on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_InvalidGenericTypeString() + { + Console.WriteLine($"Running {nameof(Verify_Type_InvalidGenericTypeString)}"); + + Assert.Throws(() => Accessors.CreateGenericClass_InvalidGenericIndex1()); + Assert.Throws(() => Accessors.CreateGenericClass_InvalidGenericIndex2()); + Assert.Throws(() => Accessors.CreateGenericClass_InvalidGenericIndex3()); + Assert.Throws(() => Accessors.CreateGenericClass_InvalidGenericIndex4()); + } + + private static bool TypeNameEquals(TypeName typeName1, TypeName typeName2) + { + if (typeName1.Name != typeName2.Name) + { + return false; + } + + if (typeName1.IsConstructedGenericType != typeName2.IsConstructedGenericType) + { + return false; + } + + var typeArgs1 = typeName1.GetGenericArguments(); + var typeArgs2 = typeName2.GetGenericArguments(); + if (typeArgs1.Length != typeArgs2.Length) + { + return false; + } + + for (int i = 0; i < typeArgs1.Length; i++) + { + if (!TypeNameEquals(typeArgs1[i], typeArgs2[i])) + { + return false; + } + } + + return true; + } + + // Skip private types and Generic support on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_CallPrivateLibTypeGenericParams() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallPrivateLibTypeGenericParams)}"); + + { + object genericClass = Accessors.CreateGenericClass(); + TypeName genericClassName = TypeName.Parse(genericClass.GetType().FullName); + Assert.True(TypeNameEquals(genericClassName, TypeName.Parse("PrivateLib.GenericClass`1[[System.Int32]]"))); + + object genericListT = Accessors.CallGenericClassM1(genericClass); + TypeName genericListTName = TypeName.Parse(genericListT.GetType().FullName); + Assert.True(TypeNameEquals(genericListTName, TypeName.Parse("System.Collections.Generic.List`1[[System.Int32]]"))); + + List boundListInt = Accessors.CallGenericClassM3(genericClass); + Assert.Empty(boundListInt); + + object genericListClass2 = Accessors.CallGenericClassM4(genericClass); + TypeName genericListClass2Name = TypeName.Parse(genericListClass2.GetType().FullName); + Assert.True(TypeNameEquals(genericListClass2Name, TypeName.Parse("System.Collections.Generic.List`1[[PrivateLib.Class2, PrivateLib]]"))); + + Assert.Throws(() => Accessors.CallGenericClassM4_InvalidReturn(genericClass)); + } + + { + object genericClass = Accessors.CreateGenericClass(); + TypeName genericClassName = TypeName.Parse(genericClass.GetType().FullName); + Assert.True(TypeNameEquals(genericClassName, TypeName.Parse("PrivateLib.GenericClass`1[[System.String]]"))); + + object genericListT = Accessors.CallGenericClassM1(genericClass); + TypeName genericListTName = TypeName.Parse(genericListT.GetType().FullName); + Assert.True(TypeNameEquals(genericListTName, TypeName.Parse("System.Collections.Generic.List`1[[System.String]]"))); + + List boundListInt = Accessors.CallGenericClassM3(genericClass); + Assert.Empty(boundListInt); + + object genericListClass2 = Accessors.CallGenericClassM4(genericClass); + TypeName genericListClass2Name = TypeName.Parse(genericListClass2.GetType().FullName); + Assert.True(TypeNameEquals(genericListClass2Name, TypeName.Parse("System.Collections.Generic.List`1[[PrivateLib.Class2, PrivateLib]]"))); + + Assert.Throws(() => Accessors.CallGenericClassM4_InvalidReturn(genericClass)); + } + } + + // Skip private types and Generic support on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_CallPrivateLibTypeAndMethodGenericParams() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallPrivateLibTypeAndMethodGenericParams)}"); + + { + object genericClass = Accessors.CreateGenericClass(); + TypeName genericClassName = TypeName.Parse(genericClass.GetType().FullName); + Assert.True(TypeNameEquals(genericClassName, TypeName.Parse("PrivateLib.GenericClass`1[[System.Int32]]"))); + + object genericListInt = Accessors.CallGenericClassM2(genericClass); + TypeName genericListIntName = TypeName.Parse(genericListInt.GetType().FullName); + Assert.True(TypeNameEquals(genericListIntName, TypeName.Parse("System.Collections.Generic.List`1[[System.Int32]]"))); + + object genericListString = Accessors.CallGenericClassM2(genericClass); + TypeName genericListStringName = TypeName.Parse(genericListString.GetType().FullName); + Assert.True(TypeNameEquals(genericListStringName, TypeName.Parse("System.Collections.Generic.List`1[[System.String]]"))); + + Assert.True(Accessors.CallGenericClassM5(genericClass, null, null, null, null)); + Assert.Equal(typeof(int), Accessors.CallGenericClassM6(genericClass, null)); + Assert.Equal(typeof(int), Accessors.CallGenericClassM7(genericClass, null)); + Assert.True(Accessors.CallGenericClassM9(genericClass, null, null, null)); + } + + { + object genericClass = Accessors.CreateGenericClass(); + TypeName genericClassName = TypeName.Parse(genericClass.GetType().FullName); + Assert.True(TypeNameEquals(genericClassName, TypeName.Parse("PrivateLib.GenericClass`1[[System.String]]"))); + + object genericListInt = Accessors.CallGenericClassM2(genericClass); + TypeName genericListIntName = TypeName.Parse(genericListInt.GetType().FullName); + Assert.True(TypeNameEquals(genericListIntName, TypeName.Parse("System.Collections.Generic.List`1[[System.Int32]]"))); + + object genericListString = Accessors.CallGenericClassM2(genericClass); + TypeName genericListStringName = TypeName.Parse(genericListString.GetType().FullName); + Assert.True(TypeNameEquals(genericListStringName, TypeName.Parse("System.Collections.Generic.List`1[[System.String]]"))); + + Assert.True(Accessors.CallGenericClassM5(genericClass, null, null, null, null)); + Assert.Equal(typeof(string), Accessors.CallGenericClassM6(genericClass, null)); + Assert.Equal(typeof(string), Accessors.CallGenericClassM7(genericClass, null)); + Assert.Equal(typeof(C1), Accessors.CallGenericClassM8(genericClass).GetType()); + Assert.True(Accessors.CallGenericClassM9(genericClass, null, null, null)); + } + } + + // Skip private types and Generic support on Mono runtime + [ConditionalFact(typeof(TestLibrary.Utilities), nameof(TestLibrary.Utilities.IsNotMonoRuntime))] + public static void Verify_Type_CallPrivateLibTypeAndMethodGenericParamsWithConstraints() + { + Console.WriteLine($"Running {nameof(Verify_Type_CallPrivateLibTypeAndMethodGenericParamsWithConstraints)}"); + + { + object genericClass = Accessors.CreateGenericClass(); + Assert.True(Accessors.CallGenericClassM5(genericClass, null, null, null, null)); + Assert.Throws(() => Accessors.CallGenericClassM5_NoConstraint(genericClass, null, null, null, null)); + } + + { + object genericClass = Accessors.CreateGenericClass(); + Assert.True(Accessors.CallGenericClassM5(genericClass, null, null, null, null)); + Assert.Throws(() => Accessors.CallGenericClassM5_NoConstraint(genericClass, null, null, null, null)); + } + } +} diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj index f551f9b48c2495..71db192cc50085 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj @@ -7,8 +7,10 @@ + + diff --git a/src/tools/illink/src/linker/Linker.Steps/UnsafeAccessorMarker.cs b/src/tools/illink/src/linker/Linker.Steps/UnsafeAccessorMarker.cs index d6155b8a50b5e1..38a8f101d1bf17 100644 --- a/src/tools/illink/src/linker/Linker.Steps/UnsafeAccessorMarker.cs +++ b/src/tools/illink/src/linker/Linker.Steps/UnsafeAccessorMarker.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Mono.Cecil; @@ -67,6 +68,29 @@ public void ProcessUnsafeAccessorMethod (MethodDefinition method) } } + bool TryResolveTargetType(TypeReference targetTypeReference, ICustomAttributeProvider unsafeAccessorTypeProvider, AssemblyDefinition assembly, [NotNullWhen(true)] out TypeDefinition? targetType) + { + targetType = null; + if (_context.TryResolve (targetTypeReference) is not TypeDefinition initialTargetType) + return false; + + targetType = initialTargetType; + + foreach (CustomAttribute attr in unsafeAccessorTypeProvider.CustomAttributes) { + if (attr.Constructor.DeclaringType.FullName == "System.Runtime.CompilerServices.UnsafeAccessorTypeAttribute") { + if (attr.HasConstructorArguments && attr.ConstructorArguments[0].Value is string typeName) { + TypeDefinition? newTargetType = _context.TryResolve (assembly, typeName); + if (newTargetType is null) + return false; // We can't find the target type, so there's nothing to mark. + + targetType = newTargetType; + } + } + } + + return true; + } + void ProcessConstructorAccessor (MethodDefinition method, string? name) { // A return type is required for a constructor, otherwise @@ -76,7 +100,7 @@ void ProcessConstructorAccessor (MethodDefinition method, string? name) if (method.ReturnsVoid () || method.ReturnType.IsByRefOrPointer () || !string.IsNullOrEmpty (name)) return; - if (_context.TryResolve (method.ReturnType) is not TypeDefinition targetType) + if (!TryResolveTargetType(method.ReturnType, method.MethodReturnType, method.Module.Assembly, out TypeDefinition? targetType)) return; foreach (MethodDefinition targetMethod in targetType.Methods) { @@ -97,7 +121,7 @@ void ProcessMethodAccessor (MethodDefinition method, string? name, bool isStatic name = method.Name; TypeReference targetTypeReference = method.Parameters[0].ParameterType; - if (_context.TryResolve (targetTypeReference) is not TypeDefinition targetType) + if (!TryResolveTargetType (targetTypeReference, method.Parameters[0], method.Module.Assembly, out TypeDefinition? targetType)) return; if (!isStatic && targetType.IsValueType && !targetTypeReference.IsByReference) @@ -124,7 +148,7 @@ void ProcessFieldAccessor (MethodDefinition method, string? name, bool isStatic) return; TypeReference targetTypeReference = method.Parameters[0].ParameterType; - if (_context.TryResolve (targetTypeReference) is not TypeDefinition targetType) + if (!TryResolveTargetType (targetTypeReference, method.Parameters[0], method.Module.Assembly, out TypeDefinition? targetType)) return; if (!isStatic && targetType.IsValueType && !targetTypeReference.IsByReference) diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Reflection/UnsafeAccessor.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Reflection/UnsafeAccessor.cs index 2d9aadff272fc4..58e56f120edbf2 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Reflection/UnsafeAccessor.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Reflection/UnsafeAccessor.cs @@ -21,6 +21,7 @@ public static void Main () StaticFieldAccess.Test (); InstanceFieldAccess.Test (); InheritanceTest.Test (); + PrivateTypeTest.Test (); } // Trimmer doesn't use method overload resolution for UnsafeAccessor and instead marks entire method groups (by name) @@ -892,6 +893,41 @@ public static void Test () } } + [Kept] + class PrivateTypeTest + { + class ExternalType + { + [Kept] + [KeptMember (".ctor()")] + class PrivateType + { + [Kept] + private void TargetMethod () + { + } + } + } + + [Kept] + [KeptAttributeAttribute (typeof (UnsafeAccessorAttribute))] + [UnsafeAccessor (UnsafeAccessorKind.Constructor)] + [return: KeptAttributeAttribute (typeof (UnsafeAccessorTypeAttribute))] + [return: UnsafeAccessorType ("Mono.Linker.Test.Cases.Reflection.UnsafeAccessor+PrivateTypeTest+ExternalType+PrivateType")] + extern static object TargetConstructor (); + + [Kept] + [KeptAttributeAttribute (typeof (UnsafeAccessorAttribute))] + [UnsafeAccessor (UnsafeAccessorKind.Method)] + extern static void TargetMethod ([KeptAttributeAttribute(typeof(UnsafeAccessorTypeAttribute)), UnsafeAccessorType ("Mono.Linker.Test.Cases.Reflection.UnsafeAccessor+PrivateTypeTest+ExternalType+PrivateType")] object target); + + [Kept] + public static void Test () + { + TargetMethod (TargetConstructor()); + } + } + [Kept (By = Tool.Trimmer)] // NativeAOT doesn't preserve base type if it's not used anywhere class SuperBase { } @@ -904,3 +940,19 @@ class Base : SuperBase { } class Derived : Base { } } } + +// Polyfill for UnsafeAccessorTypeAttribute until we use an LKG runtime that has it. +namespace System.Runtime.CompilerServices +{ + [Kept(By = Tool.Trimmer)] + [KeptBaseType(typeof(Attribute), By = Tool.Trimmer)] + [KeptAttributeAttribute(typeof(AttributeUsageAttribute), By = Tool.Trimmer)] + [AttributeUsage (AttributeTargets.Parameter | AttributeTargets.ReturnValue, AllowMultiple = false, Inherited = false)] + public sealed class UnsafeAccessorTypeAttribute : Attribute + { + [Kept(By = Tool.Trimmer)] + public UnsafeAccessorTypeAttribute (string typeName) + { + } + } +}