diff --git a/src/mono/mono/component/marshal-ilgen-stub.c b/src/mono/mono/component/marshal-ilgen-stub.c
index 39a9b4e854a523..277261c6b6dde5 100644
--- a/src/mono/mono/component/marshal-ilgen-stub.c
+++ b/src/mono/mono/component/marshal-ilgen-stub.c
@@ -9,11 +9,46 @@ marshal_ilgen_available (void)
return false;
}
+static void emit_throw_exception (MonoMarshalLightweightCallbacks* lightweight_cb,
+ MonoMethodBuilder* mb, const char* exc_nspace, const char* exc_name, const char* msg)
+{
+ lightweight_cb->mb_emit_exception (mb, exc_nspace, exc_name, msg);
+}
+
static int
-stub_emit_marshal_ilgen (EmitMarshalContext *m, int argnum, MonoType *t,
- MonoMarshalSpec *spec, int conv_arg,
- MonoType **conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb)
+stub_emit_marshal_ilgen (EmitMarshalContext* m, int argnum, MonoType* t,
+ MonoMarshalSpec* spec, int conv_arg,
+ MonoType** conv_arg_type, MarshalAction action, MonoMarshalLightweightCallbacks* lightweight_cb)
{
+ if (spec) {
+ g_assert (spec->native != MONO_NATIVE_ASANY);
+ g_assert (spec->native != MONO_NATIVE_CUSTOM);
+ }
+
+ g_assert (!m_type_is_byref(t));
+
+ switch (t->type) {
+ case MONO_TYPE_PTR:
+ case MONO_TYPE_I1:
+ case MONO_TYPE_U1:
+ case MONO_TYPE_I2:
+ case MONO_TYPE_U2:
+ case MONO_TYPE_I4:
+ case MONO_TYPE_U4:
+ case MONO_TYPE_I:
+ case MONO_TYPE_U:
+ case MONO_TYPE_R4:
+ case MONO_TYPE_R8:
+ case MONO_TYPE_I8:
+ case MONO_TYPE_U8:
+ case MONO_TYPE_FNPTR:
+ return lightweight_cb->emit_marshal_scalar (m, argnum, t, spec, conv_arg, conv_arg_type, action);
+ default:
+ emit_throw_exception (lightweight_cb, m->mb, "System", "ApplicationException",
+ g_strdup("Cannot marshal nonblittlable types without marshal-ilgen."));
+ break;
+ }
+
return 0;
}
diff --git a/src/mono/mono/metadata/marshal-lightweight.c b/src/mono/mono/metadata/marshal-lightweight.c
index a4fb731434b670..8871b2f4eee4ea 100644
--- a/src/mono/mono/metadata/marshal-lightweight.c
+++ b/src/mono/mono/metadata/marshal-lightweight.c
@@ -523,7 +523,7 @@ emit_runtime_invoke_body_ilgen (MonoMethodBuilder *mb, const char **param_names,
emit_thread_force_interrupt_checkpoint (mb);
emit_invoke_call (mb, method, sig, callsig, loc_res, virtual_, need_direct_wrapper);
- mono_mb_emit_ldloc (mb, 0);
+ mono_mb_emit_ldloc (mb, loc_res);
mono_mb_emit_byte (mb, CEE_RET);
}
diff --git a/src/mono/wasm/build/WasmApp.Native.targets b/src/mono/wasm/build/WasmApp.Native.targets
index 85b5e1858cb419..8093fb92158716 100644
--- a/src/mono/wasm/build/WasmApp.Native.targets
+++ b/src/mono/wasm/build/WasmApp.Native.targets
@@ -3,9 +3,11 @@
+
<_WasmBuildNativeCoreDependsOn>
+ _ScanAssembliesDecideLightweightMarshaler;
_WasmAotCompileApp;
_WasmStripAOTAssemblies;
_PrepareForWasmBuildNative;
@@ -33,7 +35,7 @@
<_MonoComponent Include="hot_reload;debugger" />
-
+
<_MonoComponent Include="marshal-ilgen" />
@@ -680,6 +682,16 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs
new file mode 100644
index 00000000000000..de5e6e3762fdf7
--- /dev/null
+++ b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MarshalingPInvokeScanner.cs
@@ -0,0 +1,157 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection.Metadata;
+using System.Reflection.Metadata.Ecma335;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Reflection;
+using System.Reflection.PortableExecutable;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+
+namespace MonoTargetsTasks
+{
+ public class MarshalingPInvokeScanner : Task
+ {
+ [Required]
+ public string[] Assemblies { get; set; } = Array.Empty();
+
+ [Output]
+ public string[]? IncompatibleAssemblies { get; private set; }
+
+ public override bool Execute()
+ {
+ if (Assemblies is null || Assemblies!.Length == 0)
+ {
+ Log.LogError($"{nameof(MarshalingPInvokeScanner)}.{nameof(Assemblies)} cannot be empty");
+ return false;
+ }
+
+ try
+ {
+ ExecuteInternal();
+ return !Log.HasLoggedErrors;
+ }
+ catch (LogAsErrorException e)
+ {
+ Log.LogError(e.Message);
+ return false;
+ }
+ }
+
+ private void ExecuteInternal()
+ {
+ IncompatibleAssemblies = ScanAssemblies(Assemblies);
+ }
+
+ private string[] ScanAssemblies(string[] assemblies)
+ {
+ HashSet incompatible = new HashSet();
+ MinimalMarshalingTypeCompatibilityProvider mmtcp = new(Log);
+ foreach (string aname in assemblies)
+ {
+ if (IsAssemblyIncompatible(aname, mmtcp))
+ incompatible.Add(aname);
+ }
+
+ if (mmtcp.IsSecondPassNeeded)
+ {
+ foreach (string aname in assemblies)
+ ResolveInconclusiveTypes(incompatible, aname, mmtcp);
+ }
+
+ return incompatible.ToArray();
+ }
+
+ private static string GetMethodName(MetadataReader mr, MethodDefinition md) => mr.GetString(md.Name);
+
+ private void ResolveInconclusiveTypes(HashSet incompatible, string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp)
+ {
+ string assyName = MetadataReader.GetAssemblyName(assyPath).Name!;
+ HashSet inconclusiveTypes = mmtcp.GetInconclusiveTypesForAssembly(assyName);
+ if(inconclusiveTypes.Count == 0)
+ return;
+
+ using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
+ using PEReader peReader = new PEReader(file);
+ MetadataReader mdtReader = peReader.GetMetadataReader();
+
+ SignatureDecoder decoder = new(mmtcp, mdtReader, null!);
+
+ foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions)
+ {
+ TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle);
+ string fullTypeName = string.Join(":", mdtReader.GetString(typeDef.Namespace), mdtReader.GetString(typeDef.Name));
+
+ // This is not perfect, but should work right for enums defined in other assemblies,
+ // which is the only case where we use Compatibility.Inconclusive.
+ if (inconclusiveTypes.Contains(fullTypeName) &&
+ mmtcp.GetTypeFromDefinition(mdtReader, typeDefHandle, 0) != Compatibility.Compatible)
+ {
+ Log.LogMessage(MessageImportance.Low, string.Format("Type {0} is marshaled and requires marshal-ilgen.", fullTypeName));
+
+ incompatible.Add("(unknown assembly)");
+ }
+ }
+ }
+
+ private bool IsAssemblyIncompatible(string assyPath, MinimalMarshalingTypeCompatibilityProvider mmtcp)
+ {
+ using FileStream file = new FileStream(assyPath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
+ using PEReader peReader = new PEReader(file);
+ MetadataReader mdtReader = peReader.GetMetadataReader();
+
+ foreach(CustomAttributeHandle attrHandle in mdtReader.CustomAttributes)
+ {
+ CustomAttribute attr = mdtReader.GetCustomAttribute(attrHandle);
+
+ if(attr.Constructor.Kind == HandleKind.MethodDefinition)
+ {
+ MethodDefinitionHandle mdh = (MethodDefinitionHandle)attr.Constructor;
+ MethodDefinition md = mdtReader.GetMethodDefinition(mdh);
+ TypeDefinitionHandle tdh = md.GetDeclaringType();
+ TypeDefinition td = mdtReader.GetTypeDefinition(tdh);
+
+ if(mdtReader.GetString(td.Namespace) == "System.Runtime.CompilerServices" &&
+ mdtReader.GetString(td.Name) == "DisableRuntimeMarshallingAttribute")
+ return false;
+ }
+ }
+
+ foreach (TypeDefinitionHandle typeDefHandle in mdtReader.TypeDefinitions)
+ {
+ TypeDefinition typeDef = mdtReader.GetTypeDefinition(typeDefHandle);
+ string ns = mdtReader.GetString(typeDef.Namespace);
+ string name = mdtReader.GetString(typeDef.Name);
+
+ foreach(MethodDefinitionHandle mthDefHandle in typeDef.GetMethods())
+ {
+ MethodDefinition mthDef = mdtReader.GetMethodDefinition(mthDefHandle);
+ if(!mthDef.Attributes.HasFlag(MethodAttributes.PinvokeImpl))
+ continue;
+
+ BlobReader sgnBlobReader = mdtReader.GetBlobReader(mthDef.Signature);
+ SignatureDecoder decoder = new(mmtcp, mdtReader, null!);
+
+ MethodSignature sgn = decoder.DecodeMethodSignature(ref sgnBlobReader);
+ if(sgn.ReturnType == Compatibility.Incompatible || sgn.ParameterTypes.Any(p => p == Compatibility.Incompatible))
+ {
+ Log.LogMessage(MessageImportance.Low, string.Format("Assembly {0} requires marhsal-ilgen for method {1}.{2}:{3} (first pass).",
+ assyPath, ns, name, mdtReader.GetString(mthDef.Name)));
+
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+ }
+}
diff --git a/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs
new file mode 100644
index 00000000000000..8916c1d674c138
--- /dev/null
+++ b/src/tasks/MonoTargetsTasks/MarshalingPInvokeScanner/MinimalMarshalingTypeCompatibilityProvider.cs
@@ -0,0 +1,167 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection.Metadata;
+using System.Reflection.Metadata.Ecma335;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Reflection;
+using System.Reflection.PortableExecutable;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+
+namespace MonoTargetsTasks
+{
+ // For some valuetypes we cannot determine if they are compatible with disabled
+ // runtime marshaling without first resolving their base types. In this case we
+ // first mark the assembly as Inconclusive and do a second pass over the collected
+ // base type references in order to decide. If the base types are System.Enum,
+ // then the valuetypes are enumerations, and are compatible.
+ internal enum Compatibility
+ {
+ Compatible,
+ Incompatible,
+ Inconclusive
+ }
+
+ internal sealed class InconclusiveCompatibilityCollection
+ {
+ private readonly Dictionary> _data = new();
+
+ public bool IsEmpty => _data.Count == 0;
+
+ public void Add(string assyName, string namespaceName, string typeName)
+ {
+ HashSet? incAssyTypes;
+
+ if(!_data.TryGetValue(assyName, out incAssyTypes))
+ {
+ incAssyTypes = new();
+ _data.Add(assyName, incAssyTypes);
+ }
+
+ incAssyTypes.Add($"{namespaceName}:{typeName}");
+ }
+
+ public HashSet EnumerateForAssembly(string assyName)
+ {
+ if(_data.TryGetValue(assyName, out HashSet? incAssyTypes))
+ return incAssyTypes!;
+
+ return new HashSet();
+ }
+ }
+
+ internal sealed class MinimalMarshalingTypeCompatibilityProvider : ISignatureTypeProvider
+ {
+ internal MinimalMarshalingTypeCompatibilityProvider(TaskLoggingHelper log)
+ {
+ _log = log;
+ }
+
+ private readonly TaskLoggingHelper _log;
+
+ // assembly name -> set of types needed for second pass
+ private readonly InconclusiveCompatibilityCollection _inconclusive = new();
+
+ public bool IsSecondPassNeeded => !_inconclusive.IsEmpty;
+ public HashSet GetInconclusiveTypesForAssembly(string assyName) => _inconclusive.EnumerateForAssembly(assyName);
+
+ public Compatibility GetArrayType(Compatibility elementType, ArrayShape shape) => Compatibility.Incompatible;
+ public Compatibility GetByReferenceType(Compatibility elementType) => Compatibility.Incompatible;
+ public Compatibility GetFunctionPointerType(MethodSignature signature) => Compatibility.Compatible;
+ public Compatibility GetGenericInstantiation(Compatibility genericType, ImmutableArray typeArguments) => genericType;
+ public Compatibility GetGenericMethodParameter(object genericContext, int index) => Compatibility.Incompatible;
+ public Compatibility GetGenericTypeParameter(object genericContext, int index) => Compatibility.Incompatible;
+ public Compatibility GetModifiedType(Compatibility modifier, Compatibility unmodifiedType, bool isRequired) => Compatibility.Incompatible;
+ public Compatibility GetPinnedType(Compatibility elementType) => Compatibility.Compatible;
+ public Compatibility GetPointerType(Compatibility elementType) => Compatibility.Compatible;
+ public Compatibility GetPrimitiveType(PrimitiveTypeCode typeCode)
+ {
+ return typeCode switch
+ {
+ PrimitiveTypeCode.Object => Compatibility.Incompatible,
+ PrimitiveTypeCode.String => Compatibility.Incompatible,
+ PrimitiveTypeCode.TypedReference => Compatibility.Incompatible,
+ _ => Compatibility.Compatible
+ };
+ }
+
+ public Compatibility GetSZArrayType(Compatibility elementType) => Compatibility.Incompatible;
+
+ public Compatibility GetTypeFromDefinition(MetadataReader reader, TypeDefinitionHandle handle, byte rawTypeKind)
+ {
+ TypeDefinition typeDef = reader.GetTypeDefinition(handle);
+ if (reader.GetString(typeDef.Namespace) == "System" &&
+ reader.GetString(typeDef.Name) == "Enum")
+ return Compatibility.Compatible;
+
+ try
+ {
+ EntityHandle baseTypeHandle = typeDef.BaseType;
+ if (baseTypeHandle.Kind == HandleKind.TypeReference)
+ {
+ TypeReference baseType = reader.GetTypeReference((TypeReferenceHandle)baseTypeHandle);
+ if (reader.GetString(typeDef.Namespace) == "System" &&
+ reader.GetString(baseType.Name) == "Enum")
+ return Compatibility.Compatible;
+ }
+ else if (baseTypeHandle.Kind == HandleKind.TypeSpecification)
+ {
+ TypeSpecification specInner = reader.GetTypeSpecification((TypeSpecificationHandle)baseTypeHandle);
+ return specInner.DecodeSignature(this, new object());
+ }
+ else if (baseTypeHandle.Kind == HandleKind.TypeDefinition)
+ {
+ TypeDefinitionHandle handleInner = (TypeDefinitionHandle)baseTypeHandle;
+ if (handle != handleInner)
+ return GetTypeFromDefinition(reader, handleInner, rawTypeKind);
+ }
+ }
+ catch(BadImageFormatException ex)
+ {
+ _log.LogMessage(MessageImportance.Low, ex.Message);
+ }
+
+ return Compatibility.Incompatible;
+ }
+
+ public Compatibility GetTypeFromReference(MetadataReader reader, TypeReferenceHandle handle, byte rawTypeKind)
+ {
+ if (rawTypeKind == 0x11 /*ELEMENT_TYPE_VALUETYPE*/)
+ {
+ TypeReference typeRef = reader.GetTypeReference(handle);
+ EntityHandle scope = typeRef.ResolutionScope;
+
+ if (scope.Kind == HandleKind.AssemblyReference)
+ {
+ AssemblyReferenceHandle assyRefHandle = (AssemblyReferenceHandle)typeRef.ResolutionScope;
+ AssemblyReference assyRef = reader.GetAssemblyReference(assyRefHandle);
+
+ _inconclusive.Add(assyName: reader.GetString(assyRef.Name),
+ namespaceName: reader.GetString(typeRef.Namespace), typeName: reader.GetString(typeRef.Name));
+ return Compatibility.Inconclusive;
+ }
+ else
+ {
+ throw new NotImplementedException(string.Format("Unsupported ResolutionScope kind '{0}' used in type {1}:{2}.",
+ scope.Kind.ToString(), reader.GetString(typeRef.Namespace), reader.GetString(typeRef.Name)));
+ }
+ }
+
+ return Compatibility.Incompatible;
+ }
+
+ public Compatibility GetTypeFromSpecification(MetadataReader reader, object genericContext, TypeSpecificationHandle handle, byte rawTypeKind)
+ {
+ TypeSpecification spec = reader.GetTypeSpecification((TypeSpecificationHandle)handle);
+ return spec.DecodeSignature(this, genericContext);
+ }
+ }
+}
diff --git a/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj b/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj
index 181538f565e422..340cc4e5463b27 100644
--- a/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj
+++ b/src/tasks/MonoTargetsTasks/MonoTargetsTasks.csproj
@@ -17,6 +17,7 @@
+
@@ -30,6 +31,8 @@
+
+
diff --git a/src/tasks/WasmAppBuilder/PInvokeCollector.cs b/src/tasks/WasmAppBuilder/PInvokeCollector.cs
new file mode 100644
index 00000000000000..0eeecabc904d66
--- /dev/null
+++ b/src/tasks/WasmAppBuilder/PInvokeCollector.cs
@@ -0,0 +1,250 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System;
+using System.Linq;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection;
+using Microsoft.Build.Framework;
+using Microsoft.Build.Utilities;
+using Microsoft.Build.Tasks;
+
+#pragma warning disable CA1067
+#pragma warning disable CS0649
+internal sealed class PInvoke : IEquatable
+#pragma warning restore CA1067
+{
+ public PInvoke(string entryPoint, string module, MethodInfo method)
+ {
+ EntryPoint = entryPoint;
+ Module = module;
+ Method = method;
+ }
+
+ public string EntryPoint;
+ public string Module;
+ public MethodInfo Method;
+ public bool Skip;
+
+ public bool Equals(PInvoke? other)
+ => other != null &&
+ string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) &&
+ string.Equals(Module, other.Module, StringComparison.Ordinal) &&
+ string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);
+
+ public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}";
+}
+#pragma warning restore CS0649
+
+internal sealed class PInvokeComparer : IEqualityComparer
+{
+ public bool Equals(PInvoke? x, PInvoke? y)
+ {
+ if (x == null && y == null)
+ return true;
+ if (x == null || y == null)
+ return false;
+
+ return x.Equals(y);
+ }
+
+ public int GetHashCode(PInvoke pinvoke)
+ => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode();
+}
+
+
+internal sealed class PInvokeCollector {
+ private readonly Dictionary _assemblyDisableRuntimeMarshallingAttributeCache = new();
+ private TaskLoggingHelper Log { get; init; }
+
+ public PInvokeCollector(TaskLoggingHelper log)
+ {
+ Log = log;
+ }
+
+ public void CollectPInvokes(List pinvokes, List callbacks, List signatures, Type type)
+ {
+ foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
+ {
+ try
+ {
+ CollectPInvokesForMethod(method);
+ if (DoesMethodHaveCallbacks(method))
+ callbacks.Add(new PInvokeCallback(method));
+ }
+ catch (Exception ex) when (ex is not LogAsErrorException)
+ {
+ Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
+ $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
+ }
+ }
+
+ if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute"))
+ {
+ var method = type.GetMethod("Invoke");
+
+ if (method != null)
+ {
+ string? signature = SignatureMapper.MethodToSignature(method!);
+ if (signature == null)
+ throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
+
+
+ Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
+ signatures.Add(signature);
+ }
+ }
+
+ void CollectPInvokesForMethod(MethodInfo method)
+ {
+ if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
+ {
+ var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
+ var module = (string)dllimport.ConstructorArguments[0].Value!;
+ var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
+ pinvokes.Add(new PInvoke(entrypoint, module, method));
+
+ string? signature = SignatureMapper.MethodToSignature(method);
+ if (signature == null)
+ {
+ throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
+ }
+
+ Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
+ signatures.Add(signature);
+ }
+ }
+
+ bool DoesMethodHaveCallbacks(MethodInfo method)
+ {
+ if (!MethodHasCallbackAttributes(method))
+ return false;
+
+ if (TryIsMethodGetParametersUnsupported(method, out string? reason))
+ {
+ Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
+ $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'.");
+ return false;
+ }
+
+ if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly))
+ return true;
+
+ // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are
+ // blittable
+ bool isVoid = method.ReturnType.FullName == "System.Void";
+ if (!isVoid && !IsBlittable(method.ReturnType))
+ Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
+
+ foreach (var p in method.GetParameters())
+ {
+ if (!IsBlittable(p.ParameterType))
+ Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
+ }
+
+ return true;
+ }
+
+ static bool MethodHasCallbackAttributes(MethodInfo method)
+ {
+ foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
+ {
+ try
+ {
+ if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
+ cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
+ {
+ return true;
+ }
+ }
+ catch
+ {
+ // Assembly not found, ignore
+ }
+ }
+
+ return false;
+ }
+ }
+
+ public static bool IsBlittable(Type type)
+ {
+ if (type.IsPrimitive || type.IsByRef || type.IsPointer || type.IsEnum)
+ return true;
+ else
+ return false;
+ }
+
+ private static void Error(string msg) => throw new LogAsErrorException(msg);
+
+ private static bool HasAttribute(MemberInfo element, params string[] attributeNames)
+ {
+ foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element))
+ {
+ try
+ {
+ for (int i = 0; i < attributeNames.Length; ++i)
+ {
+ if (cattr.AttributeType.FullName == attributeNames [i] ||
+ cattr.AttributeType.Name == attributeNames[i])
+ {
+ return true;
+ }
+ }
+ }
+ catch
+ {
+ // Assembly not found, ignore
+ }
+ }
+ return false;
+ }
+
+ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotNullWhen(true)] out string? reason)
+ {
+ try
+ {
+ method.GetParameters();
+ }
+ catch (NotSupportedException nse)
+ {
+ reason = nse.Message;
+ return true;
+ }
+ catch
+ {
+ // not concerned with other exceptions
+ }
+
+ reason = null;
+ return false;
+ }
+
+ private bool HasAssemblyDisableRuntimeMarshallingAttribute(Assembly assembly)
+ {
+ if (!_assemblyDisableRuntimeMarshallingAttributeCache.TryGetValue(assembly, out var value))
+ {
+ _assemblyDisableRuntimeMarshallingAttributeCache[assembly] = value = assembly
+ .GetCustomAttributesData()
+ .Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute");
+ }
+
+ value = assembly.GetCustomAttributesData().Any(d => d.AttributeType.Name == "DisableRuntimeMarshallingAttribute");
+
+ return value;
+ }
+}
+
+#pragma warning disable CS0649
+internal sealed class PInvokeCallback
+{
+ public PInvokeCallback(MethodInfo method)
+ {
+ Method = method;
+ }
+
+ public MethodInfo Method;
+ public string? EntryName;
+}
+#pragma warning restore CS0649
diff --git a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
index b8ecfa7f8438eb..6bec3a98617805 100644
--- a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
+++ b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
@@ -35,6 +35,8 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable
var pinvokes = new List();
var callbacks = new List();
+ PInvokeCollector pinvokeCollector = new(Log);
+
var resolver = new PathAssemblyResolver(assemblies);
using var mlc = new MetadataLoadContext(resolver, "System.Private.CoreLib");
@@ -46,7 +48,7 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable
Log.LogMessage(MessageImportance.Low, $"Loading {asmPath} to scan for pinvokes");
var a = mlc.LoadFromAssemblyPath(asmPath);
foreach (var type in a.GetTypes())
- CollectPInvokes(pinvokes, callbacks, signatures, type);
+ pinvokeCollector.CollectPInvokes(pinvokes, callbacks, signatures, type);
}
string tmpFileName = Path.GetTempFileName();
@@ -71,111 +73,6 @@ public IEnumerable Generate(string[] pinvokeModules, IEnumerable
return signatures;
}
- private void CollectPInvokes(List pinvokes, List callbacks, List signatures, Type type)
- {
- foreach (var method in type.GetMethods(BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
- {
- try
- {
- CollectPInvokesForMethod(method);
- if (DoesMethodHaveCallbacks(method))
- callbacks.Add(new PInvokeCallback(method));
- }
- catch (Exception ex) when (ex is not LogAsErrorException)
- {
- Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
- $"Could not get pinvoke, or callbacks for method '{type.FullName}::{method.Name}' because '{ex.Message}'");
- }
- }
-
- if (HasAttribute(type, "System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute"))
- {
- var method = type.GetMethod("Invoke");
-
- if (method != null)
- {
- string? signature = SignatureMapper.MethodToSignature(method!);
- if (signature == null)
- throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
-
-
- Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
- signatures.Add(signature);
- }
- }
-
- void CollectPInvokesForMethod(MethodInfo method)
- {
- if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0)
- {
- var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute");
- var module = (string)dllimport.ConstructorArguments[0].Value!;
- var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!;
- pinvokes.Add(new PInvoke(entrypoint, module, method));
-
- string? signature = SignatureMapper.MethodToSignature(method);
- if (signature == null)
- {
- throw new NotSupportedException($"Unsupported parameter type in method '{type.FullName}.{method.Name}'");
- }
-
- Log.LogMessage(MessageImportance.Low, $"Adding pinvoke signature {signature} for method '{type.FullName}.{method.Name}'");
- signatures.Add(signature);
- }
- }
-
- bool DoesMethodHaveCallbacks(MethodInfo method)
- {
- if (!MethodHasCallbackAttributes(method))
- return false;
-
- if (TryIsMethodGetParametersUnsupported(method, out string? reason))
- {
- Log.LogWarning(null, "WASM0001", "", "", 0, 0, 0, 0,
- $"Skipping callback '{method.DeclaringType!.FullName}::{method.Name}' because '{reason}'.");
- return false;
- }
-
- if (method.DeclaringType != null && HasAssemblyDisableRuntimeMarshallingAttribute(method.DeclaringType.Assembly))
- return true;
-
- // No DisableRuntimeMarshalling attribute, so check if the params/ret-type are
- // blittable
- bool isVoid = method.ReturnType.FullName == "System.Void";
- if (!isVoid && !IsBlittable(method.ReturnType))
- Error($"The return type '{method.ReturnType.FullName}' of pinvoke callback method '{method}' needs to be blittable.");
-
- foreach (var p in method.GetParameters())
- {
- if (!IsBlittable(p.ParameterType))
- Error("Parameter types of pinvoke callback method '" + method + "' needs to be blittable.");
- }
-
- return true;
- }
-
- static bool MethodHasCallbackAttributes(MethodInfo method)
- {
- foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(method))
- {
- try
- {
- if (cattr.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute" ||
- cattr.AttributeType.Name == "MonoPInvokeCallbackAttribute")
- {
- return true;
- }
- }
- catch
- {
- // Assembly not found, ignore
- }
- }
-
- return false;
- }
- }
-
private static bool HasAttribute(MemberInfo element, params string[] attributeNames)
{
foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element))
@@ -516,55 +413,3 @@ private static bool IsBlittable(Type type)
private static void Error(string msg) => throw new LogAsErrorException(msg);
}
-
-#pragma warning disable CA1067
-internal sealed class PInvoke : IEquatable
-#pragma warning restore CA1067
-{
- public PInvoke(string entryPoint, string module, MethodInfo method)
- {
- EntryPoint = entryPoint;
- Module = module;
- Method = method;
- }
-
- public string EntryPoint;
- public string Module;
- public MethodInfo Method;
- public bool Skip;
-
- public bool Equals(PInvoke? other)
- => other != null &&
- string.Equals(EntryPoint, other.EntryPoint, StringComparison.Ordinal) &&
- string.Equals(Module, other.Module, StringComparison.Ordinal) &&
- string.Equals(Method.ToString(), other.Method.ToString(), StringComparison.Ordinal);
-
- public override string ToString() => $"{{ EntryPoint: {EntryPoint}, Module: {Module}, Method: {Method}, Skip: {Skip} }}";
-}
-
-internal sealed class PInvokeComparer : IEqualityComparer
-{
- public bool Equals(PInvoke? x, PInvoke? y)
- {
- if (x == null && y == null)
- return true;
- if (x == null || y == null)
- return false;
-
- return x.Equals(y);
- }
-
- public int GetHashCode(PInvoke pinvoke)
- => $"{pinvoke.EntryPoint}{pinvoke.Module}{pinvoke.Method}".GetHashCode();
-}
-
-internal sealed class PInvokeCallback
-{
- public PInvokeCallback(MethodInfo method)
- {
- Method = method;
- }
-
- public MethodInfo Method;
- public string? EntryName;
-}