Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,191 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Select(static (vtable, ct) => GeneratePopulateVTableMethod(vtable));

context.RegisterConcatenatedSyntaxOutputs(populateVTable, "PopulateVTable.g.cs");

IncrementalValuesProvider<MemberDeclarationSyntax> iIUnknownInterfaceTypeImplementation =
nativeToManagedStubContexts
.Collect()
.SelectMany(static (data, ct) => data.GroupBy(stub => stub.ContainingSyntaxContext))
.Select(static (context, ct) => GenerateIIUnknownInterfaceTypeImplementation(context.Key, context.Count()));

context.RegisterConcatenatedSyntaxOutputs(iIUnknownInterfaceTypeImplementation, "IIUnknownInterfaceTypeImplementation.g.cs");
}

private static MemberDeclarationSyntax GenerateIIUnknownInterfaceTypeImplementation(ContainingSyntaxContext context, int methodsCount)
{
// static Guid IIUnknownInterfaceType.Iid => new Guid("00000000-0000-0000-0000-000000000000");
var iid = PropertyDeclaration(List<AttributeListSyntax>(),
TokenList(Token(SyntaxKind.StaticKeyword)),
ParseTypeName("System.Guid"),
ExplicitInterfaceSpecifier(ParseName(TypeNames.IIUnknownInterfaceType)),
Identifier("Iid"),
AccessorList(
SingletonList(
AccessorDeclaration(SyntaxKind.GetAccessorDeclaration, Block(
SingletonList(
ReturnStatement(
ObjectCreationExpression(ParseTypeName("System.Guid"))
.WithArgumentList(
ArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("00000000-0000-0000-0000-000000000000")))))))))))));
// private static readonly void** m_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<InterfaceName>), sizeof(void*) * <3 + numberOfInterfaceMethods>);
var m_vtable = FieldDeclaration(
List<AttributeListSyntax>(),
TokenList(
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.PrivateKeyword),
Token(SyntaxKind.ReadOnlyKeyword)),
VariableDeclaration(
PointerType(PointerType(ParseTypeName("void"))),
SingletonSeparatedList(
VariableDeclarator("m_vtable")
.WithInitializer(
EqualsValueClause(
CastExpression(
PointerType(PointerType(ParseTypeName("void"))),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.RuntimeHelpers),
IdentifierName("AllocateTypeAssociatedMemory")))
.WithArgumentList(
ArgumentList(
SeparatedList<ArgumentSyntax>(NodeOrTokenList(
Argument(
TypeOfExpression(
IdentifierName(context.ContainingSyntax[0].Identifier))),
Token(SyntaxKind.CommaToken),
Argument(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(
PointerType(ParseTypeName("void"))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(3 + methodsCount))))))))))))));
// static void** IIUnknownInterfaceType.ManagedVirtualMethodTable
// {
// get
// {
// if (m_vtable[0] == null)
// {
// nint v0, v1, v2;
// ComWrappers.GetIUnknownImpl(out v0, out v1, out v2);
// m_vtable[0] = (void*)v0;
// m_vtable[1] = (void*)v1;
// m_vtable[2] = (void*)v2;
// Native.PopulateManagedVirtualMethodTable(m_vtable);
// }
// return m_vtable;
// }
// }
var managedVirtualMethodTableProperty = PropertyDeclaration(
List<AttributeListSyntax>(),
TokenList(Token(SyntaxKind.StaticKeyword)),
PointerType(PointerType(ParseTypeName("void"))),
ExplicitInterfaceSpecifier(ParseName(TypeNames.IIUnknownInterfaceType)),
Identifier("ManagedVirtualMethodTable"),
AccessorList(
SingletonList(
AccessorDeclaration(SyntaxKind.GetAccessorDeclaration, Block(
List(new StatementSyntax[]
{
// if (m_vtable[0] == null)
IfStatement(
BinaryExpression(
SyntaxKind.EqualsExpression,
ElementAccessExpression(
IdentifierName("m_vtable"))
.WithArgumentList(
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))))),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(
List(new StatementSyntax[]
{
// nint v0, v1, v2;
LocalDeclarationStatement(VariableDeclaration(ParseTypeName("nint"),
SeparatedList(new VariableDeclaratorSyntax[] {
VariableDeclarator(Identifier("v0"), null, null),
VariableDeclarator(Identifier("v1"), null, null),
VariableDeclarator(Identifier("v2"), null, null),
}
))),
// ComWrappers.GetIUnknownImpl(out v0, out v1, out v2);
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.ComWrappers),
IdentifierName("GetIUnknownImpl")))
.WithArgumentList(
ArgumentList(
SeparatedList<ArgumentSyntax>(NodeOrTokenList(
Argument(IdentifierName("v0"))
.WithRefKindKeyword(Token(SyntaxKind.OutKeyword)),
Token(SyntaxKind.CommaToken),
Argument(IdentifierName("v1"))
.WithRefKindKeyword(Token(SyntaxKind.OutKeyword)),
Token(SyntaxKind.CommaToken),
Argument(IdentifierName("v2"))
.WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))))),
// m_vtable[0] = (void*)v0;
ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName("m_vtable"),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(0)))))),
CastExpression(
PointerType(
ParseTypeName("void")),
IdentifierName("v0")))),
// m_vtable[1] = (void*)v1;
ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName("m_vtable"),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(1)))))),
CastExpression(
PointerType(
ParseTypeName("void")),
IdentifierName("v1")))),
// m_vtable[2] = (void*)v2;
ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName("m_vtable"),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(2)))))),
CastExpression(
PointerType(
ParseTypeName("void")),
IdentifierName("v2")))),
// Native.PopulateManagedVirtualMethodTable(m_vtable);
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("Native"),
IdentifierName("PopulateUnmanagedVirtualMethodTable")))
.WithArgumentList(
ArgumentList(
SingletonSeparatedList(
Argument(
IdentifierName("m_vtable"))))))
}))),
ReturnStatement(IdentifierName("m_vtable"))
}))))));

return context.WrapMembersInContainingSyntaxWithUnsafeModifierWithInterfaceOnInnerType(ParseTypeName(TypeNames.IIUnknownInterfaceType), iid, m_vtable, managedVirtualMethodTableProperty);
}

private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, CancellationToken ct)
Expand Down Expand Up @@ -327,7 +512,7 @@ private static MemberDeclarationSyntax GenerateNativeInterfaceMetadata(Containin
.WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)))
.AddParameterListParameters(
Parameter(Identifier(VTableParameterName))
.WithType(GenericName(TypeNames.System_Span).AddTypeArgumentListArguments(IdentifierName("nint"))));
.WithType(PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))))));

private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<ContainingSyntaxContext, IncrementalMethodStubGenerationContext> vtableMethods)
{
Expand All @@ -345,7 +530,7 @@ private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<Co
ElementAccessExpression(
IdentifierName(VTableParameterName))
.AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(method.VtableIndexData.Index)))),
CastExpression(IdentifierName("nint"),
CastExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))),
CastExpression(functionPointerType,
PrefixUnaryExpression(SyntaxKind.AddressOfExpression,
IdentifierName($"ABI_{method.StubMethodSyntaxTemplate.Identifier}")))))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,38 @@ public MemberDeclarationSyntax WrapMembersInContainingSyntaxWithUnsafeModifier(p
}
return wrappedMember;
}
public MemberDeclarationSyntax WrapMembersInContainingSyntaxWithUnsafeModifierWithInterfaceOnInnerType(TypeSyntax interfaceName, params MemberDeclarationSyntax[] members)
{
bool addedUnsafe = false;
MemberDeclarationSyntax? wrappedMember = null;
int i = 0;
foreach (var containingType in ContainingSyntax)
{
TypeDeclarationSyntax type = TypeDeclaration(containingType.TypeKind, containingType.Identifier)
.WithModifiers(containingType.Modifiers)
.AddMembers(wrappedMember is not null ? new[] { wrappedMember } : members);
if (i == 0)
{
type = type.WithBaseList(BaseList(SeparatedList<BaseTypeSyntax>(new[] {
SimpleBaseType(interfaceName)
})));
}
if (!addedUnsafe)
{
type = type.WithModifiers(type.Modifiers.AddToModifiers(SyntaxKind.UnsafeKeyword));
}
if (containingType.TypeParameters is not null)
{
type = type.AddTypeParameterListParameters(containingType.TypeParameters.Parameters.ToArray());
}
wrappedMember = type;
i++;
}
if (ContainingNamespace is not null)
{
wrappedMember = NamespaceDeclaration(ParseName(ContainingNamespace)).AddMembers(wrappedMember);
}
return wrappedMember;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ private static ExpressionStatementSyntax GenerateStatementForManagedInvoke(Bound
{
return ExpressionStatement(invoke);
}
_ = 0;

return ExpressionStatement(
AssignmentExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static void RegisterDiagnostics(this IncrementalGeneratorInitializationCo
});
}

public static void RegisterConcatenatedSyntaxOutputs<TNode>(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TNode> nodes, string fileName)
public static void RegisterConcatenatedSyntaxOutputs<TNode>(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TNode> nodes, string fileName, string constantSource = "")
where TNode : SyntaxNode
{
IncrementalValueProvider<ImmutableArray<string>> generatedMethods = nodes
Expand All @@ -64,6 +64,7 @@ public static void RegisterConcatenatedSyntaxOutputs<TNode>(this IncrementalGene
StringBuilder source = new();
// Mark in source that the file is auto-generated.
source.AppendLine("// <auto-generated/>");
source.AppendLine(constantSource);
foreach (string generated in generatedSources)
{
source.AppendLine(generated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,8 @@ public static string MarshalEx(InteropGenerationOptions options)

public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";
public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
public const string RuntimeHelpers = "System.Runtime.CompilerServices.RuntimeHelpers";
public const string ComWrappers = "System.Runtime.InteropServices.ComWrappers";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace System.Runtime.InteropServices.Marshalling
// This type implements the logic to get the managed object from the unmanaged "this" pointer.
// If we decide to not expose the VTable source generator, we don't need to expose this and we can just inline the logic
// into the generated code in the source generator.
internal sealed unsafe class ComWrappersUnwrapper : IUnmanagedObjectUnwrapper
public sealed unsafe class ComWrappersUnwrapper : IUnmanagedObjectUnwrapper
{
public static object GetObjectForUnmanagedWrapper(void* ptr)
{
Expand Down
Loading