Skip to content

Commit befa399

Browse files
jumerckxglwagnergithub-actions[bot]avik-pal
authored
feat: add specialization of copyto! from julia array to concreterarray (#942)
* add specialization of copyto! from julia array to concreterarray * support subarray et. al. as well * Update src/ConcreteRArray.jl Co-authored-by: Gregory L. Wagner <[email protected]> * Update src/ConcreteRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: proper copyto * fix: restrict types --------- Co-authored-by: Gregory L. Wagner <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Avik Pal <[email protected]>
1 parent 358ff04 commit befa399

File tree

4 files changed

+41
-12
lines changed

4 files changed

+41
-12
lines changed

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,11 @@ function trace_for(mod, expr; track_numbers)
231231
end
232232

233233
$(ReactantCore).traced_while(
234-
cond_fn, body_fn, args; track_numbers=$(track_numbers), verify_arg_names=$(QuoteNode(args_init))
234+
cond_fn,
235+
body_fn,
236+
args;
237+
track_numbers=$(track_numbers),
238+
verify_arg_names=$(QuoteNode(args_init)),
235239
)
236240
end
237241
end

src/ConcreteRArray.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,18 +398,25 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteI
398398
return fn(bc.args...)
399399
end
400400

401-
# XXX: This is not necessarily correct. We need to check for sharding and also device
402-
# compatibility.
403-
function Base.copyto!(dest::AbstractConcreteArray, src::AbstractConcreteArray)
404-
dest.data = src.data
405-
return dest
406-
end
407-
408401
function mycopyto!(dest, src)
409402
dest .= src # use broadcasting instead of copyto!
410403
return nothing
411404
end
412405

406+
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
407+
@eval function Base.copyto!(dest::$(aType), src::$(aType))
408+
if dest.sharding == src.sharding &&
409+
XLA.device(dest) == XLA.device(src) &&
410+
XLA.client(dest) == XLA.client(src)
411+
dest.data = src.data
412+
else
413+
fn = compile(mycopyto!, (dest, src))
414+
fn(dest, src)
415+
end
416+
return dest
417+
end
418+
end
419+
413420
function Base.copyto!(
414421
dest::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, src::AbstractConcreteArray
415422
)
@@ -418,6 +425,22 @@ function Base.copyto!(
418425
return dest
419426
end
420427

428+
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
429+
anyaType = Symbol(:Any, aType)
430+
@eval function Base.copyto!(dest::$(anyaType), src::Array{<:ReactantPrimitive})
431+
ancestor_dest = ancestor(dest)
432+
return copyto!(
433+
dest,
434+
$(aType)(
435+
src;
436+
sharding=ancestor_dest.sharding,
437+
client=XLA.client(ancestor_dest),
438+
device=XLA.device(ancestor_dest),
439+
),
440+
)
441+
end
442+
end
443+
421444
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
422445
@eval begin
423446
function Base.copyto!(

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,7 @@ end
17911791
return_dialect=:stablehlo,
17921792
args_in_result=:none,
17931793
do_transpose=false,
1794-
verify_arg_names
1794+
verify_arg_names,
17951795
).f
17961796

17971797
cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled)

src/TracedUtils.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,11 @@ function make_mlir_fn(
356356
conflicts = setdiff(resis, argis)
357357
@assert !isempty(conflicts) "Expected to have some conflicts, but none were found."
358358

359-
error("""Types do not match between function arguments and results.
360-
The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", "))
361-
""")
359+
error(
360+
"""Types do not match between function arguments and results.
361+
The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", "))
362+
""",
363+
)
362364
end
363365

364366
out_tys = if do_transpose

0 commit comments

Comments
 (0)