Skip to content

Commit f30526b

Browse files
authored
fix: codegen to fix aliasing issues (#1115)
* fix: codegen to fix aliasing issues * fix: incorrect logic * test: more tests * fix: overload deepcopy_internal
1 parent 3a51bf5 commit f30526b

File tree

3 files changed

+63
-9
lines changed

3 files changed

+63
-9
lines changed

src/Compiler.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1262,8 +1262,10 @@ function compile_mlir!(
12621262
nresults = MLIR.IR.Value[]
12631263
linear_results2 = TracedType[]
12641264
results_mask = falses(length(results))
1265+
12651266
for (i, op) in enumerate(results)
1266-
if !MLIR.IR.is_block_arg(op)
1267+
if !MLIR.IR.is_block_arg(op) ||
1268+
!Reactant.TracedUtils.has_idx(linear_results[i], :args) # new buffer
12671269
push!(nresults, op)
12681270
push!(linear_results2, linear_results[i])
12691271
results_mask[i] = true

src/ConcreteRArray.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@ for runtime in (:PJRT, :IFRT)
1717
end
1818
end
1919

20+
# copy
21+
function Base.copy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
22+
fn = Reactant.compile(copy, (x,))
23+
return fn(x)
24+
end
25+
26+
# deepcopy
27+
function Base.deepcopy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
28+
fn = Reactant.compile(copy, (x,))
29+
return fn(x)
30+
end
31+
32+
# One more reason why users shouldn't call `deepcopy`
33+
function Base.deepcopy_internal(
34+
x::Union{AbstractConcreteArray,AbstractConcreteNumber}, stackdict::IdDict
35+
)
36+
if haskey(stackdict, x)
37+
return stackdict[x]::typeof(x)
38+
end
39+
return deepcopy(x)
40+
end
41+
2042
Base.size(::AbstractConcreteNumber) = ()
2143
Base.real(x::AbstractConcreteNumber{<:Real}) = x
2244
function Base.rtoldefault(T::Type{<:AbstractConcreteNumber})
@@ -410,14 +432,10 @@ end
410432

411433
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
412434
@eval function Base.copyto!(dest::$(aType), src::$(aType))
413-
if dest.sharding == src.sharding &&
414-
XLA.device(dest) == XLA.device(src) &&
415-
XLA.client(dest) == XLA.client(src)
416-
dest.data = src.data
417-
else
418-
fn = compile(mycopyto!, (dest, src))
419-
fn(dest, src)
420-
end
435+
# We can't directly set the data field. it will alias the inner buffers without
436+
# actually copying them.
437+
fn = compile(mycopyto!, (dest, src))
438+
fn(dest, src)
421439
return dest
422440
end
423441
end

test/basic.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,3 +1117,37 @@ end
11171117
@test_throws Reactant.XLA.ReactantInternalError Reactant.XLA.allocatorstats()
11181118
end
11191119
end
1120+
1121+
@testset "copy/deepcopy" begin
1122+
for op in (copy, deepcopy)
1123+
x = Reactant.to_rarray(ones(4, 4))
1124+
if x isa Reactant.ConcretePJRTArray
1125+
orig_ptr = only(x.data).buffer.buffer
1126+
y = op(x)
1127+
@test y isa Reactant.ConcretePJRTArray
1128+
@test only(y.data).buffer.buffer != orig_ptr
1129+
@test only(x.data).buffer.buffer == orig_ptr
1130+
else
1131+
orig_ptr = x.data.buffer.buffer
1132+
y = op(x)
1133+
@test y isa Reactant.ConcreteIFRTArray
1134+
@test y.data.buffer.buffer != orig_ptr
1135+
@test x.data.buffer.buffer == orig_ptr
1136+
end
1137+
1138+
x = Reactant.to_rarray(4.0; track_numbers=Number)
1139+
if x isa Reactant.ConcretePJRTNumber
1140+
orig_ptr = only(x.data).buffer.buffer
1141+
y = op(x)
1142+
@test y isa Reactant.ConcretePJRTNumber
1143+
@test only(y.data).buffer.buffer != orig_ptr
1144+
@test only(x.data).buffer.buffer == orig_ptr
1145+
else
1146+
orig_ptr = x.data.buffer.buffer
1147+
y = op(x)
1148+
@test y isa Reactant.ConcreteIFRTNumber
1149+
@test y.data.buffer.buffer != orig_ptr
1150+
@test x.data.buffer.buffer == orig_ptr
1151+
end
1152+
end
1153+
end

0 commit comments

Comments
 (0)