Skip to content

Commit 3cb288e

Browse files
authored
feat: add a reactant.donated attr to donated args (#947)
1 parent c913bff commit 3cb288e

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/Compiler.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,17 @@ function compile_mlir!(
10551055
MLIR.API.mlirOperationDestroy(compiled_f.operation)
10561056
compiled_f.operation = MLIR.API.MlirOperation(C_NULL)
10571057

1058+
# Add a `donated` attr to the function arguments. This doesn't affect XLA, but lets us
1059+
# check which arguments were donated.
1060+
preserved_args_idx = last.(preserved_args)
1061+
for (i, arg) in enumerate(linear_args)
1062+
if i preserved_args_idx
1063+
MLIR.API.mlirFuncSetArgAttr(
1064+
func3, i - 1, "reactant.donated", MLIR.IR.UnitAttribute()
1065+
)
1066+
end
1067+
end
1068+
10581069
return Reactant.TracedUtils.CompiledMlirFnResult(
10591070
fnwrapped,
10601071
func3,

test/buffer_donation.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ end
1717
b = Reactant.to_rarray(3 * ones(2, 2))
1818
@jit(donate_fill_x_with_2(a, b))
1919
@test convert(Array, a) == 2 * ones(2, 2)
20+
hlo = @code_hlo(donate_fill_x_with_2(a, b))
21+
@test length(findall("reactant.donated", repr(hlo))) == 1
2022

2123
(; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[3]
2224
preserved_args_idx = last.(preserved_args)
@@ -26,6 +28,8 @@ end
2628
b = Reactant.to_rarray(3 * ones(2, 2))
2729
@jit(donate_inplace_mul(a, b))
2830
@test convert(Array, a) == 6 * ones(2, 2)
31+
hlo = @code_hlo(donate_inplace_mul(a, b))
32+
@test length(findall("reactant.donated", repr(hlo))) == 1
2933

3034
(; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[3]
3135
preserved_args_idx = last.(preserved_args)

0 commit comments

Comments
 (0)