Skip to content

Commit 6009cca

Browse files
authored
fix: more aggressive checks for buffer donation (#1110)
fix: generator usage
1 parent f30526b commit 6009cca

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

src/Compiler.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,13 +1884,37 @@ function codegen_flatten!(
18841884
return flatten_names, flatten_code, resharded_inputs
18851885
end
18861886

1887-
function donate_argument!(donated_args_mask, carg, i::Int, donated_buffers, path)
1887+
function donate_argument!(
1888+
donated_args_mask,
1889+
carg::Union{ConcretePJRTNumber,ConcretePJRTArray},
1890+
i::Int,
1891+
donated_buffers,
1892+
path,
1893+
)
1894+
if donated_args_mask[i]
1895+
buffers = Tuple(d.buffer for d in carg.data)
1896+
if buffers in donated_buffers
1897+
error("Donated buffer $(carg.data) is already marked as donated. Can't donate \
1898+
the same buffer multiple times. The argument is present at $(path)")
1899+
end
1900+
push!(donated_buffers, buffers)
1901+
Reactant.mark_donated!(carg)
1902+
end
1903+
end
1904+
1905+
function donate_argument!(
1906+
donated_args_mask,
1907+
carg::Union{ConcreteIFRTNumber,ConcreteIFRTArray},
1908+
i::Int,
1909+
donated_buffers,
1910+
path,
1911+
)
18881912
if donated_args_mask[i]
1889-
if carg.data in donated_buffers
1913+
if carg.data.buffer in donated_buffers
18901914
error("Donated buffer $(carg.data) is already marked as donated. Can't donate \
18911915
the same buffer multiple times. The argument is present at $(path)")
18921916
end
1893-
push!(donated_buffers, carg.data)
1917+
push!(donated_buffers, carg.data.buffer)
18941918
Reactant.mark_donated!(carg)
18951919
end
18961920
end
@@ -2502,9 +2526,9 @@ function compile(f, args; sync=false, kwargs...)
25022526
fname = gensym(Symbol(Symbol(f), :_reactant))
25032527

25042528
donated_buffers_set = if XLA.runtime(client) isa Val{:PJRT}
2505-
:(Base.IdSet{NTuple{<:Any,XLA.PJRT.AsyncBuffer}}())
2529+
:(Base.IdSet{NTuple{<:Any,XLA.PJRT.Buffer}}())
25062530
else
2507-
:(Base.IdSet{XLA.IFRT.AsyncArray}())
2531+
:(Base.IdSet{XLA.IFRT.Array}())
25082532
end
25092533

25102534
body = quote

0 commit comments

Comments
 (0)