@@ -1884,13 +1884,37 @@ function codegen_flatten!(
1884
1884
return flatten_names, flatten_code, resharded_inputs
1885
1885
end
1886
1886
1887
- function donate_argument! (donated_args_mask, carg, i:: Int , donated_buffers, path)
1887
+ function donate_argument! (
1888
+ donated_args_mask,
1889
+ carg:: Union{ConcretePJRTNumber,ConcretePJRTArray} ,
1890
+ i:: Int ,
1891
+ donated_buffers,
1892
+ path,
1893
+ )
1894
+ if donated_args_mask[i]
1895
+ buffers = Tuple (d. buffer for d in carg. data)
1896
+ if buffers in donated_buffers
1897
+ error (" Donated buffer $(carg. data) is already marked as donated. Can't donate \
1898
+ the same buffer multiple times. The argument is present at $(path) " )
1899
+ end
1900
+ push! (donated_buffers, buffers)
1901
+ Reactant. mark_donated! (carg)
1902
+ end
1903
+ end
1904
+
1905
+ function donate_argument! (
1906
+ donated_args_mask,
1907
+ carg:: Union{ConcreteIFRTNumber,ConcreteIFRTArray} ,
1908
+ i:: Int ,
1909
+ donated_buffers,
1910
+ path,
1911
+ )
1888
1912
if donated_args_mask[i]
1889
- if carg. data in donated_buffers
1913
+ if carg. data. buffer in donated_buffers
1890
1914
error (" Donated buffer $(carg. data) is already marked as donated. Can't donate \
1891
1915
the same buffer multiple times. The argument is present at $(path) " )
1892
1916
end
1893
- push! (donated_buffers, carg. data)
1917
+ push! (donated_buffers, carg. data. buffer )
1894
1918
Reactant. mark_donated! (carg)
1895
1919
end
1896
1920
end
@@ -2502,9 +2526,9 @@ function compile(f, args; sync=false, kwargs...)
2502
2526
fname = gensym (Symbol (Symbol (f), :_reactant ))
2503
2527
2504
2528
donated_buffers_set = if XLA. runtime (client) isa Val{:PJRT }
2505
- :(Base. IdSet {NTuple{<:Any,XLA.PJRT.AsyncBuffer }} ())
2529
+ :(Base. IdSet {NTuple{<:Any,XLA.PJRT.Buffer }} ())
2506
2530
else
2507
- :(Base. IdSet {XLA.IFRT.AsyncArray } ())
2531
+ :(Base. IdSet {XLA.IFRT.Array } ())
2508
2532
end
2509
2533
2510
2534
body = quote
0 commit comments