Skip to content

Commit 0760f4b

Browse files
fix: correct dims handling in mapreducedim! (#728)
* feat: add sign dispatches * fix: correct dims handling in mapreducedim! * Update test/basic.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b70d614 commit 0760f4b

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

src/TracedRArray.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,14 @@ function Base.mapreducedim!(
529529
@nospecialize(R::TracedRArray),
530530
A::Base.AbstractArrayOrBroadcasted,
531531
)
532-
tmp = TracedUtils.broadcast_to_size(
533-
Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)
534-
)
535-
R.mlir_data = broadcast(op, R, tmp).mlir_data
532+
@assert length(size(R)) == length(size(A))
533+
dims = map(enumerate(zip(size(R), size(A)))) do (i, (sR, sA))
534+
sR == sA && return nothing
535+
@assert sR == 1
536+
return i
537+
end
538+
tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
539+
set_mlir_data!(R, get_mlir_data(tmp))
536540
return R
537541
end
538542

test/basic.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,3 +850,30 @@ end
850850

851851
@test @jit(fn(x)) fn(Array(x))
852852
end
853+
854+
function fntest1(x)
855+
y = similar(x, 1, 1, 8)
856+
sum!(y, x)
857+
return y
858+
end
859+
860+
function fntest2(x)
861+
y = similar(x, 2, 1, 8)
862+
sum!(y, x)
863+
return y
864+
end
865+
866+
function fntest3(x)
867+
y = similar(x, 2, 1, 1)
868+
sum!(abs2, y, x)
869+
return y
870+
end
871+
872+
@testset "mapreducedim!" begin
873+
x = reshape(collect(Float32, 1:64), 2, 4, 8) ./ 64
874+
x_ra = Reactant.to_rarray(x)
875+
876+
@test Array(@jit(fntest1(x_ra))) fntest1(x)
877+
@test Array(@jit(fntest2(x_ra))) fntest2(x)
878+
@test Array(@jit(fntest3(x_ra))) fntest3(x)
879+
end

0 commit comments

Comments
 (0)