diff --git a/src/mono/mono/component/marshal-ilgen-stub.c b/src/mono/mono/component/marshal-ilgen-stub.c index 39a9b4e854a523..277261c6b6dde5 100644 --- a/src/mono/mono/component/marshal-ilgen-stub.c +++ b/src/mono/mono/component/marshal-ilgen-stub.c @@ -9,11 +9,46 @@ marshal_ilgen_available (void) return false; } +static void emit_throw_exception (MonoMarshalLightweightCallbacks* lightweight_cb, + MonoMethodBuilder* mb, const char* exc_nspace, const char* exc_name, const char* msg) +{ + lightweight_cb->mb_emit_exception (mb, exc_nspace, exc_name, msg); +} + static int -stub_emit_marshal_ilgen (EmitMarshalContext *m, int argnum, MonoType *t, - MonoMarshalSpec *spec, int conv_arg, - MonoType **conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb) +stub_emit_marshal_ilgen (EmitMarshalContext* m, int argnum, MonoType* t, + MonoMarshalSpec* spec, int conv_arg, + MonoType** conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb) { + if (spec) { + g_assert (spec->native != MONO_NATIVE_ASANY); + g_assert (spec->native != MONO_NATIVE_CUSTOM); + } + + g_assert (!m_type_is_byref(t)); + + switch (t->type) { + case MONO_TYPE_PTR: + case MONO_TYPE_I1: + case MONO_TYPE_U1: + case MONO_TYPE_I2: + case MONO_TYPE_U2: + case MONO_TYPE_I4: + case MONO_TYPE_U4: + case MONO_TYPE_I: + case MONO_TYPE_U: + case MONO_TYPE_R4: + case MONO_TYPE_R8: + case MONO_TYPE_I8: + case MONO_TYPE_U8: + case MONO_TYPE_FNPTR: + return lightweight_cb->emit_marshal_scalar (m, argnum, t, spec, conv_arg, conv_arg_type, action); + default: + emit_throw_exception (lightweight_cb, m->mb, "System", "ApplicationException", + g_strdup("Cannot marshal nonblittlable types without marshal-ilgen.")); + break; + } + return 0; } diff --git a/src/mono/mono/metadata/marshal-lightweight.c b/src/mono/mono/metadata/marshal-lightweight.c index a4fb731434b670..8871b2f4eee4ea 100644 --- a/src/mono/mono/metadata/marshal-lightweight.c +++ b/src/mono/mono/metadata/marshal-lightweight.c @@ -523,7 +523,7 @@ emit_runtime_invoke_body_ilgen (MonoMethodBuilder *mb, const char **param_names, emit_thread_force_interrupt_checkpoint (mb); emit_invoke_call (mb, method, sig, callsig, loc_res, virtual_, need_direct_wrapper); - mono_mb_emit_ldloc (mb, 0); + mono_mb_emit_ldloc (mb, loc_res); mono_mb_emit_byte (mb, CEE_RET); } diff --git a/src/mono/wasm/build/WasmApp.Native.targets b/src/mono/wasm/build/WasmApp.Native.targets index 85b5e1858cb419..8093fb92158716 100644 --- a/src/mono/wasm/build/WasmApp.Native.targets +++ b/src/mono/wasm/build/WasmApp.Native.targets @@ -3,9 +3,11 @@ + <_WasmBuildNativeCoreDependsOn> + _ScanAssembliesDecideLightweightMarshaler; _WasmAotCompileApp; _WasmStripAOTAssemblies; _PrepareForWasmBuildNative; @@ -33,7 +35,7 @@ <_MonoComponent Include="hot_reload;debugger" /> - + <_MonoComponent Include="marshal-ilgen" /> @@ -680,6 +682,16 @@ + + + + + + + + + + diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs new file mode 100644 index 00000000000000..de5e6e3762fdf7 --- /dev/null +++ b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs @@ -0,0 +1,157 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Diagnostics.CodeAnalysis; +using System.Reflection.Metadata; +using System.Reflection.Metadata.Ecma335; +using System.Collections.Immutable; +using System.IO; +using System.Linq; +using System.Text; +using System.Reflection; +using System.Reflection.PortableExecutable; +using Microsoft.Build.Framework; +using Microsoft.Build.Utilities; + +namespace MonoTargetsTasks +{ + public class MarshalingPInvokeScanner : Task + { + [Required] + public string[] Assemblies { get; set; } = Array.Empty(); + + [Output] + public string[]? IncompatibleAssemblies { get; private set; } + + public override bool Execute() + { + if (Assemblies is null || Assemblies!.Length == 0) + { + Log.LogError($"{nameof(MarshalingPInvokeScanner)}.{nameof(Assemblies)} cannot be empty"); + return false; + } + + try + { + ExecuteInternal(); + return !Log.HasLoggedErrors; + } + catch (LogAsErrorException e) + { + Log.LogError(e.Message); + return false; + } + } + + private void ExecuteInternal() + { + IncompatibleAssemblies = ScanAssemblies(Assemblies); + } + + private string[] ScanAssemblies(string[] assemblies) + { + HashSet incompatible = new HashSet(); + MinimalMarshalingTypeCompatibilityProvider mmtcp = new(Log); + foreach (string aname in assemblies) + { + if (IsAssemblyIncompatible(aname, mmtcp)) + incompatible.Add(aname); + } + + if (mmtcp.IsSecondPassNeeded) + { + foreach (string aname in assemblies) + ResolveInconclusiveTypes(incompatible, aname, mmtcp); + } + + return incompatible.ToArray(); + } + + private static string GetMethodName(MetadataReader mr, MethodDefinition md) => mr.GetString(md.Name); + + private void ResolveInconclusiveTypes(HashSet incompatible, string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp) + { + string assyName = MetadataReader.GetAssemblyName(assyPath).Name!; + HashSet inconclusiveTypes = mmtcp.GetInconclusiveTypesForAssembly(assyName); + if(inconclusiveTypes.Count == 0) + return; + + using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite); + using PEReader peReader = new PEReader(file); + MetadataReader mdtReader = peReader.GetMetadataReader(); + + SignatureDecoder decoder = new(mmtcp, mdtReader, null!); + + foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions) + { + TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle); + string fullTypeName = string.Join(":", mdtReader.GetString(typeDef.Namespace), mdtReader.GetString(typeDef.Name)); + + // This is not perfect, but should work right for enums defined in other assemblies, + // which is the only case where we use Compatibility.Inconclusive. + if (inconclusiveTypes.Contains(fullTypeName) && + mmtcp.GetTypeFromDefinition(mdtReader, typeDefHandle, 0) != Compatibility.Compatible) + { + Log.LogMessage(MessageImportance.Low, string.Format("Type {0} is marshaled and requires marshal-ilgen.", fullTypeName)); + + incompatible.Add("(unknown assembly)"); + } + } + } + + private bool IsAssemblyIncompatible(string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp) + { + using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite); + using PEReader peReader = new PEReader(file); + MetadataReader mdtReader = peReader.GetMetadataReader(); + + foreach(CustomAttributeHandle attrHandle in mdtReader.CustomAttributes) + { + CustomAttribute attr = mdtReader.GetCustomAttribute(attrHandle); + + if(attr.Constructor.Kind == HandleKind.MethodDefinition) + { + MethodDefinitionHandle mdh = (MethodDefinitionHandle)attr.Constructor; + MethodDefinition md = mdtReader.GetMethodDefinition(mdh); + TypeDefinitionHandle tdh = md.GetDeclaringType(); + TypeDefinition td = mdtReader.GetTypeDefinition(tdh); + + if(mdtReader.GetString(td.Namespace) == "System.Runtime.CompilerServices" && + mdtReader.GetString(td.Name) == "DisableRuntimeMarshallingAttribute") + return false; + } + } + + foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions) + { + TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle); + string ns = mdtReader.GetString(typeDef.Namespace); + string name = mdtReader.GetString(typeDef.Name); + + foreach(MethodDefinitionHandle mthDefHandle in typeDef.GetMethods()) + { + MethodDefinition mthDef = mdtReader.GetMethodDefinition(mthDefHandle); + if(!mthDef.Attributes.HasFlag(MethodAttributes.PinvokeImpl)) + continue; + + BlobReader sgnBlobReader = mdtReader.GetBlobReader(mthDef.Signature); + SignatureDecoder decoder = new(mmtcp, mdtReader, null!); + + MethodSignature sgn = decoder.DecodeMethodSignature(ref sgnBlobReader); + if(sgn.ReturnType == Compatibility.Incompatible || sgn.ParameterTypes.Any(p => p == Compatibility.Incompatible)) + { + Log.LogMessage(MessageImportance.Low, string.Format("Assembly {0} requires marhsal-ilgen for method {1}.{2}:{3} (first pass).", + assyPath, ns, name, mdtReader.GetString(mthDef.Name))); + + return true; + } + } + } + + return false; + } + } +} diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs new file mode 100644 index 00000000000000..8916c1d674c138 --- /dev/null +++ b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs @@ -0,0 +1,167 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Diagnostics.CodeAnalysis; +using System.Reflection.Metadata; +using System.Reflection.Metadata.Ecma335; +using System.Collections.Immutable; +using System.IO; +using System.Linq; +using System.Text; +using System.Reflection; +using System.Reflection.PortableExecutable; +using Microsoft.Build.Framework; +using Microsoft.Build.Utilities; + +namespace MonoTargetsTasks +{ + // For some valuetypes we cannot determine if they are compatible with disabled + // runtime marshaling without first resolving their base types. In this case we + // first mark the assembly as Inconclusive and do a second pass over the collected + // base type references in order to decide. If the base types are System.Enum, + // then the valuetypes are enumerations, and are compatible. + internal enum Compatibility + { + Compatible, + Incompatible, + Inconclusive + } + + internal sealed class InconclusiveCompatibilityCollection + { + private readonly Dictionary> _data = new(); + + public bool IsEmpty => _data.Count == 0; + + public void Add(string assyName, string namespaceName, string typeName) + { + HashSet? incAssyTypes; + + if(!_data.TryGetValue(assyName, out incAssyTypes)) + { + incAssyTypes = new(); + _data.Add(assyName, incAssyTypes); + } + + incAssyTypes.Add($"{namespaceName}:{typeName}"); + } + + public HashSet EnumerateForAssembly(string assyName) + { + if(_data.TryGetValue(assyName, out HashSet? incAssyTypes)) + return incAssyTypes!; + + return new HashSet(); + } + } + + internal sealed class MinimalMarshalingTypeCompatibilityProvider : ISignatureTypeProvider + { + internal MinimalMarshalingTypeCompatibilityProvider(TaskLoggingHelper log) + { + _log = log; + } + + private readonly TaskLoggingHelper _log; + + // assembly name -> set of types needed for second pass + private readonly InconclusiveCompatibilityCollection _inconclusive = new(); + + public bool IsSecondPassNeeded => !_inconclusive.IsEmpty; + public HashSet GetInconclusiveTypesForAssembly(string assyName) => _inconclusive.EnumerateForAssembly(assyName); + + public Compatibility GetArrayType(Compatibility elementType, ArrayShape shape) => Compatibility.Incompatible; + public Compatibility GetByReferenceType(Compatibility elementType) => Compatibility.Incompatible; + public Compatibility GetFunctionPointerType(MethodSignature signature) => Compatibility.Compatible; + public Compatibility GetGenericInstantiation(Compatibility genericType, ImmutableArray typeArguments) => genericType; + public Compatibility GetGenericMethodParameter(object genericContext, int index) => Compatibility.Incompatible; + public Compatibility GetGenericTypeParameter(object genericContext, int index) => Compatibility.Incompatible; + public Compatibility GetModifiedType(Compatibility modifier, Compatibility unmodifiedType, bool isRequired) => Compatibility.Incompatible; + public Compatibility GetPinnedType(Compatibility elementType) => Compatibility.Compatible; + public Compatibility GetPointerType(Compatibility elementType) => Compatibility.Compatible; + public Compatibility GetPrimitiveType(PrimitiveTypeCode typeCode) + { + return typeCode switch + { + PrimitiveTypeCode.Object => Compatibility.Incompatible, + PrimitiveTypeCode.String => Compatibility.Incompatible, + PrimitiveTypeCode.TypedReference => Compatibility.Incompatible, + _ => Compatibility.Compatible + }; + } + + public Compatibility GetSZArrayType(Compatibility elementType) => Compatibility.Incompatible; + + public Compatibility GetTypeFromDefinition(MetadataReader reader, TypeDefinitionHandle handle, byte rawTypeKind) + { + TypeDefinition typeDef = reader.GetTypeDefinition(handle); + if (reader.GetString(typeDef.Namespace) == "System" && + reader.GetString(typeDef.Name) == "Enum") + return Compatibility.Compatible; + + try + { + EntityHandle baseTypeHandle = typeDef.BaseType; + if (baseTypeHandle.Kind == HandleKind.TypeReference) + { + TypeReference baseType = reader.GetTypeReference((TypeReferenceHandle)baseTypeHandle); + if (reader.GetString(typeDef.Namespace) == "System" && + reader.GetString(baseType.Name) == "Enum") + return Compatibility.Compatible; + } + else if (baseTypeHandle.Kind == HandleKind.TypeSpecification) + { + TypeSpecification specInner = reader.GetTypeSpecification((TypeSpecificationHandle)baseTypeHandle); + return specInner.DecodeSignature(this, new object()); + } + else if (baseTypeHandle.Kind == HandleKind.TypeDefinition) + { + TypeDefinitionHandle handleInner = (TypeDefinitionHandle)baseTypeHandle; + if (handle != handleInner) + return GetTypeFromDefinition(reader, handleInner, rawTypeKind); + } + } + catch(BadImageFormatException ex) + { + _log.LogMessage(MessageImportance.Low, ex.Message); + } + + return Compatibility.Incompatible; + } + + public Compatibility GetTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind) + { + if (rawTypeKind == 0x11 /*ELEMENT_TYPE_VALUETYPE*/) + { + TypeReference typeRef = reader.GetTypeReference(handle); + EntityHandle scope = typeRef.ResolutionScope; + + if (scope.Kind == HandleKind.AssemblyReference) + { + AssemblyReferenceHandle assyRefHandle = (AssemblyReferenceHandle)typeRef.ResolutionScope; + AssemblyReference assyRef = reader.GetAssemblyReference(assyRefHandle); + + _inconclusive.Add(assyName: reader.GetString(assyRef.Name), + namespaceName: reader.GetString(typeRef.Namespace), typeName: reader.GetString(typeRef.Name)); + return Compatibility.Inconclusive; + } + else + { + throw new NotImplementedException(string.Format("Unsupported ResolutionScope kind '{0}' used in type {1}:{2}.", + scope.Kind.ToString(), reader.GetString(typeRef.Namespace), reader.GetString(typeRef.Name))); + } + } + + return Compatibility.Incompatible; + } + + public Compatibility GetTypeFromSpecification(MetadataReader reader, object genericContext, TypeSpecificationHandle handle, byte rawTypeKind) + { + TypeSpecification spec = reader.GetTypeSpecification((TypeSpecificationHandle)handle); + return spec.DecodeSignature(this, genericContext); + } + } +} diff --git a/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj b/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj index 181538f565e422..340cc4e5463b27 100644 --- a/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj +++ b/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj @@ -17,6 +17,7 @@ + @@ -30,6 +31,8 @@ + + diff --git a/src/tasks/WasmAppBuilder/PInvokeCollector.cs b/src/tasks/WasmAppBuilder/PInvokeCollector.cs new file mode 100644 index 00000000000000..0eeecabc904d66 --- /dev/null +++ b/src/tasks/WasmAppBuilder/PInvokeCollector.cs @@ -0,0 +1,250 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System; +using System.Linq; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Microsoft.Build.Framework; +using Microsoft.Build.Utilities; +using Microsoft.Build.Tasks; + +#pragma warning disable CA1067 +#pragma warning disable CS0649 +internal sealed class PInvoke : IEquatable +#pragma warning restore CA1067 +{ + public PInvoke(string entryPoint, string module, MethodInfo method) + { + EntryPoint = entryPoint; + Module = module; + Method = method; + } + + public string EntryPoint; + public string Module; + public MethodInfo Method; + public bool Skip; + + public bool Equals(PInvoke? other) + => other != null && + string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) && + string.Equals(Module, other.Module, StringComparison.Ordinal) && + string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal); + + public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}"; +} +#pragma warning restore CS0649 + +internal sealed class PInvokeComparer : IEqualityComparer +{ + public bool Equals(PInvoke? x, PInvoke? y) + { + if (x == null && y == null) + return true; + if (x == null || y == null) + return false; + + return x.Equals(y); + } + + public int GetHashCode(PInvoke pinvoke) + => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode(); +} + + +internal sealed class PInvokeCollector { + private readonly Dictionary _assemblyDisableRuntimeMarshallingAttributeCache = new(); + private TaskLoggingHelper Log { get; init; } + + public PInvokeCollector(TaskLoggingHelper log) + { + Log = log; + } + + public void CollectPInvokes(List pinvokes, List callbacks, List signatures, Type type) + { + foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + try + { + CollectPInvokesForMethod(method); + if (DoesMethodHaveCallbacks(method)) + callbacks.Add(new PInvokeCallback(method)); + } + catch (Exception ex) when (ex is not LogAsErrorException) + { + Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0, + $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'"); + } + } + + if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute")) + { + var method = type.GetMethod("Invoke"); + + if (method != null) + { + string? signature = SignatureMapper.MethodToSignature(method!); + if (signature == null) + throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'"); + + + Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'"); + signatures.Add(signature); + } + } + + void CollectPInvokesForMethod(MethodInfo method) + { + if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0) + { + var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute"); + var module = (string)dllimport.ConstructorArguments[0].Value!; + var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!; + pinvokes.Add(new PInvoke(entrypoint, module, method)); + + string? signature = SignatureMapper.MethodToSignature(method); + if (signature == null) + { + throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'"); + } + + Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'"); + signatures.Add(signature); + } + } + + bool DoesMethodHaveCallbacks(MethodInfo method) + { + if (!MethodHasCallbackAttributes(method)) + return false; + + if (TryIsMethodGetParametersUnsupported(method, out string? reason)) + { + Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0, + $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'."); + return false; + } + + if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly)) + return true; + + // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are + // blittable + bool isVoid = method.ReturnType.FullName == "System.Void"; + if (!isVoid && !IsBlittable(method.ReturnType)) + Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable."); + + foreach (var p in method.GetParameters()) + { + if (!IsBlittable(p.ParameterType)) + Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable."); + } + + return true; + } + + static bool MethodHasCallbackAttributes(MethodInfo method) + { + foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method)) + { + try + { + if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" || + cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute") + { + return true; + } + } + catch + { + // Assembly not found, ignore + } + } + + return false; + } + } + + public static bool IsBlittable(Type type) + { + if (type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum) + return true; + else + return false; + } + + private static void Error(string msg) => throw new LogAsErrorException(msg); + + private static bool HasAttribute(MemberInfo element, params string[] attributeNames) + { + foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element)) + { + try + { + for (int i = 0; i < attributeNames.Length; ++i) + { + if (cattr.AttributeType.FullName == attributeNames [i] || + cattr.AttributeType.Name == attributeNames[i]) + { + return true; + } + } + } + catch + { + // Assembly not found, ignore + } + } + return false; + } + + private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotNullWhen(true)] out string? reason) + { + try + { + method.GetParameters(); + } + catch (NotSupportedException nse) + { + reason = nse.Message; + return true; + } + catch + { + // not concerned with other exceptions + } + + reason = null; + return false; + } + + private bool HasAssemblyDisableRuntimeMarshallingAttribute(Assembly assembly) + { + if (!_assemblyDisableRuntimeMarshallingAttributeCache.TryGetValue(assembly, out var value)) + { + _assemblyDisableRuntimeMarshallingAttributeCache[assembly] = value = assembly + .GetCustomAttributesData() + .Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute"); + } + + value = assembly.GetCustomAttributesData().Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute"); + + return value; + } +} + +#pragma warning disable CS0649 +internal sealed class PInvokeCallback +{ + public PInvokeCallback(MethodInfo method) + { + Method = method; + } + + public MethodInfo Method; + public string? EntryName; +} +#pragma warning restore CS0649 diff --git a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs index b8ecfa7f8438eb..6bec3a98617805 100644 --- a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs +++ b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs @@ -35,6 +35,8 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable var pinvokes = new List(); var callbacks = new List(); + PInvokeCollector pinvokeCollector = new(Log); + var resolver = new PathAssemblyResolver(assemblies); using var mlc = new MetadataLoadContext(resolver, "System.Private.CoreLib"); @@ -46,7 +48,7 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable Log.LogMessage(MessageImportance.Low, $"Loading {asmPath} to scan for pinvokes"); var a = mlc.LoadFromAssemblyPath(asmPath); foreach (var type in a.GetTypes()) - CollectPInvokes(pinvokes, callbacks, signatures, type); + pinvokeCollector.CollectPInvokes(pinvokes, callbacks, signatures, type); } string tmpFileName = Path.GetTempFileName(); @@ -71,111 +73,6 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable return signatures; } - private void CollectPInvokes(List pinvokes, List callbacks, List signatures, Type type) - { - foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) - { - try - { - CollectPInvokesForMethod(method); - if (DoesMethodHaveCallbacks(method)) - callbacks.Add(new PInvokeCallback(method)); - } - catch (Exception ex) when (ex is not LogAsErrorException) - { - Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0, - $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'"); - } - } - - if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute")) - { - var method = type.GetMethod("Invoke"); - - if (method != null) - { - string? signature = SignatureMapper.MethodToSignature(method!); - if (signature == null) - throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'"); - - - Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'"); - signatures.Add(signature); - } - } - - void CollectPInvokesForMethod(MethodInfo method) - { - if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0) - { - var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute"); - var module = (string)dllimport.ConstructorArguments[0].Value!; - var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!; - pinvokes.Add(new PInvoke(entrypoint, module, method)); - - string? signature = SignatureMapper.MethodToSignature(method); - if (signature == null) - { - throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'"); - } - - Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'"); - signatures.Add(signature); - } - } - - bool DoesMethodHaveCallbacks(MethodInfo method) - { - if (!MethodHasCallbackAttributes(method)) - return false; - - if (TryIsMethodGetParametersUnsupported(method, out string? reason)) - { - Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0, - $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'."); - return false; - } - - if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly)) - return true; - - // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are - // blittable - bool isVoid = method.ReturnType.FullName == "System.Void"; - if (!isVoid && !IsBlittable(method.ReturnType)) - Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable."); - - foreach (var p in method.GetParameters()) - { - if (!IsBlittable(p.ParameterType)) - Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable."); - } - - return true; - } - - static bool MethodHasCallbackAttributes(MethodInfo method) - { - foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method)) - { - try - { - if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" || - cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute") - { - return true; - } - } - catch - { - // Assembly not found, ignore - } - } - - return false; - } - } - private static bool HasAttribute(MemberInfo element, params string[] attributeNames) { foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element)) @@ -516,55 +413,3 @@ private static bool IsBlittable(Type type) private static void Error(string msg) => throw new LogAsErrorException(msg); } - -#pragma warning disable CA1067 -internal sealed class PInvoke : IEquatable -#pragma warning restore CA1067 -{ - public PInvoke(string entryPoint, string module, MethodInfo method) - { - EntryPoint = entryPoint; - Module = module; - Method = method; - } - - public string EntryPoint; - public string Module; - public MethodInfo Method; - public bool Skip; - - public bool Equals(PInvoke? other) - => other != null && - string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) && - string.Equals(Module, other.Module, StringComparison.Ordinal) && - string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal); - - public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}"; -} - -internal sealed class PInvokeComparer : IEqualityComparer -{ - public bool Equals(PInvoke? x, PInvoke? y) - { - if (x == null && y == null) - return true; - if (x == null || y == null) - return false; - - return x.Equals(y); - } - - public int GetHashCode(PInvoke pinvoke) - => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode(); -} - -internal sealed class PInvokeCallback -{ - public PInvokeCallback(MethodInfo method) - { - Method = method; - } - - public MethodInfo Method; - public string? EntryName; -}