@@ -14,13 +14,13 @@ namespace Microsoft.Interop
1414 internal sealed class StatefulValueMarshalling : ICustomTypeMarshallingStrategy
1515 {
1616 internal const string MarshallerIdentifier = "marshaller" ;
17- private readonly TypeSyntax _marshallerTypeSyntax ;
17+ private readonly ManagedTypeInfo _marshallerType ;
1818 private readonly TypeSyntax _nativeTypeSyntax ;
1919 private readonly MarshallerShape _shape ;
2020
21- public StatefulValueMarshalling ( TypeSyntax marshallerTypeSyntax , TypeSyntax nativeTypeSyntax , MarshallerShape shape )
21+ public StatefulValueMarshalling ( ManagedTypeInfo marshallerType , TypeSyntax nativeTypeSyntax , MarshallerShape shape )
2222 {
23- _marshallerTypeSyntax = marshallerTypeSyntax ;
23+ _marshallerType = marshallerType ;
2424 _nativeTypeSyntax = nativeTypeSyntax ;
2525 _shape = shape ;
2626 }
@@ -140,10 +140,23 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePosit
140140 public IEnumerable < StatementSyntax > GenerateSetupStatements ( TypePositionInfo info , StubCodeContext context )
141141 {
142142 // <marshaller> = new();
143- yield return MarshallerHelpers . Declare (
144- _marshallerTypeSyntax ,
143+ LocalDeclarationStatementSyntax declaration = MarshallerHelpers . Declare (
144+ _marshallerType . Syntax ,
145145 context . GetAdditionalIdentifier ( info , MarshallerIdentifier ) ,
146146 ImplicitObjectCreationExpression ( ArgumentList ( ) , initializer : null ) ) ;
147+
148+ // For byref-like marshaller types, we'll mark them as scoped.
149+ // Byref-like types can capture references, so by default the compiler has to worry that
150+ // they could enable those references to escape the current stack frame.
151+ // In particular, this can interact poorly with the caller-allocated-buffer marshalling
152+ // support and make the simple `marshaller.FromManaged(managed, stackalloc X[i])` expression
153+ // illegal. Mark the marshaller type as scoped so the compiler knows that it won't escape.
154+ if ( _marshallerType is ValueTypeInfo { IsByRefLike : true } )
155+ {
156+ declaration = declaration . AddModifiers ( Token ( SyntaxKind . ScopedKeyword ) ) ;
157+ }
158+
159+ yield return declaration ;
147160 }
148161
149162 public IEnumerable < StatementSyntax > GeneratePinStatements ( TypePositionInfo info , StubCodeContext context )
@@ -218,28 +231,9 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
218231
219232 IEnumerable < StatementSyntax > GenerateCallerAllocatedBufferMarshalStatements ( )
220233 {
221- // TODO: Update once we can consume the scoped keword. We should be able to simplify this once we get that API.
222- string stackPtrIdentifier = context . GetAdditionalIdentifier ( info , "stackptr" ) ;
223- // <bufferElementType>* <managedIdentifier>__stackptr = stackalloc <bufferElementType>[<_bufferSize>];
224- yield return LocalDeclarationStatement (
225- VariableDeclaration (
226- PointerType ( _bufferElementType ) ,
227- SingletonSeparatedList (
228- VariableDeclarator ( stackPtrIdentifier )
229- . WithInitializer ( EqualsValueClause (
230- StackAllocArrayCreationExpression (
231- ArrayType (
232- _bufferElementType ,
233- SingletonList ( ArrayRankSpecifier ( SingletonSeparatedList < ExpressionSyntax > (
234- MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
235- _marshallerType ,
236- IdentifierName ( ShapeMemberNames . BufferSize ) )
237- ) ) ) ) ) ) ) ) ) ) ;
238-
239-
240234 ( string managedIdentifier , _ ) = context . GetIdentifiers ( info ) ;
241235
242- // <marshaller>.FromManaged(<managedIdentifier>, new Span <bufferElementType>(<stackPtrIdentifier>, < marshallerType>.BufferSize) );
236+ // <marshaller>.FromManaged(<managedIdentifier>, stackalloc <bufferElementType>[< marshallerType>.BufferSize] );
243237 yield return ExpressionStatement (
244238 InvocationExpression (
245239 MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
@@ -249,19 +243,13 @@ IEnumerable<StatementSyntax> GenerateCallerAllocatedBufferMarshalStatements()
249243 new [ ]
250244 {
251245 Argument ( IdentifierName ( managedIdentifier ) ) ,
252- Argument (
253- ObjectCreationExpression (
254- GenericName ( Identifier ( TypeNames . System_Span ) ,
255- TypeArgumentList ( SingletonSeparatedList (
256- _bufferElementType ) ) ) )
257- . WithArgumentList (
258- ArgumentList ( SeparatedList ( new ArgumentSyntax [ ]
259- {
260- Argument ( IdentifierName ( stackPtrIdentifier ) ) ,
261- Argument ( MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
246+ Argument ( StackAllocArrayCreationExpression (
247+ ArrayType (
248+ _bufferElementType ,
249+ SingletonList ( ArrayRankSpecifier ( SingletonSeparatedList < ExpressionSyntax > (
250+ MemberAccessExpression ( SyntaxKind . SimpleMemberAccessExpression ,
262251 _marshallerType ,
263- IdentifierName ( ShapeMemberNames . BufferSize ) ) )
264- } ) ) ) )
252+ IdentifierName ( ShapeMemberNames . BufferSize ) ) ) ) ) ) ) )
265253 } ) ) ) ) ;
266254 }
267255 }
0 commit comments