Skip to content

Commit 1aa8be2

Browse files
authored
Revert "fix: apply init values after reduction" (#887)
* Revert "fix: apply init values after reduction (#881)" This reverts commit 6936dbe. * Bump version to v0.2.40
1 parent 55574ac commit 1aa8be2

File tree

5 files changed

+111
-85
lines changed

5 files changed

+111
-85
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.2.39"
4+
version = "0.2.40"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/ConcreteRArray.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,14 @@ buffer_on_cpu(::Any) = true
393393
buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data)
394394
buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data)
395395

396+
function Ops.constant(x::AbstractConcreteArray; kwargs...)
397+
return Ops.constant(Base.convert(Array, x); kwargs...)
398+
end
399+
400+
function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T}
401+
return Ops.constant(Base.convert(T, x); kwargs...)
402+
end
403+
396404
function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N}
397405
return ConcretePJRTArray(
398406
zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding
@@ -456,9 +464,3 @@ function Base.mapreducedim!(
456464
fn(f, op, R, A)
457465
return R
458466
end
459-
460-
function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray)
461-
fn = compile(Base.map!, (f, R, A))
462-
fn(f, R, A)
463-
return R
464-
end

src/Ops.jl

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,6 @@ end
122122
end
123123
end
124124

125-
@noinline function constant(
126-
x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
127-
) where {T,N}
128-
return constant(collect(x); location)
129-
end
130-
131-
@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...)
132-
return constant(Base.convert(Array, x); kwargs...)
133-
end
134-
135125
@noinline function constant(
136126
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
137127
) where {T<:Number}
@@ -140,10 +130,6 @@ end
140130
return TracedRNumber{T}((), res.mlir_data)
141131
end
142132

143-
@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T}
144-
return constant(Base.convert(T, x); kwargs...)
145-
end
146-
147133
function fill(
148134
v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
149135
)
@@ -391,7 +377,7 @@ end
391377
end
392378

393379
# shape ops
394-
function reshape(x::TracedRArray, dims::Integer...; kwargs...)
380+
function reshape(x::TracedRArray, dims...; kwargs...)
395381
return reshape(x, collect(dims); kwargs...)
396382
end
397383

@@ -2394,7 +2380,7 @@ end
23942380
x::TracedRArray{T},
23952381
init_values::TracedRNumber{T},
23962382
dimensions::Vector{Int},
2397-
fn::Function;
2383+
fn::Function,
23982384
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
23992385
)
24002386
@@ -2426,43 +2412,25 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24262412
- **CPU version & Julia's `reduce`**:
24272413
- Reduce along dimension 1 → `[(15) (21); (18) (24)]`
24282414
- Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
2429-
2415+
24302416
- **GPU version**:
24312417
- Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
24322418
- Reduce along dimension 3 → `[37 49]`
24332419
"""
24342420
@noinline function reduce(
24352421
x::TracedRArray{T},
2436-
init_values::Union{TracedRNumber{T},Nothing},
2422+
init_values::TracedRNumber{T},
24372423
dimensions::Vector{Int},
2438-
fn::Function;
2424+
fn::Function,
24392425
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
24402426
) where {T}
2441-
elT = T
2442-
if init_values === nothing
2443-
if fn === min || fn === Base.FastMath.min_fast
2444-
init = typemax(elT)
2445-
elseif fn === max || fn === Base.FastMath.max_fast
2446-
init = typemin(elT)
2447-
else
2448-
init = Base.reduce_empty(Base.BottomRF(fn), elT)
2449-
end
2450-
2451-
initT = unwrapped_eltype(typeof(init))
2452-
if initT != elT # Bool, etc. reductions
2453-
elT = promote_type(initT, elT)
2454-
x = elT.(x)
2455-
end
2456-
init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init)
2457-
end
2458-
24592427
reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))
24602428

2461-
result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape)
2429+
result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
24622430

24632431
sample_inputs = [
2464-
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
2465-
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
2432+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2433+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
24662434
]
24672435

24682436
func =
@@ -2476,8 +2444,14 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24762444
return_dialect=:stablehlo,
24772445
).f
24782446
@assert MLIR.IR.nregions(func) == 1
2479-
ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type"))
2480-
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor<i1>"
2447+
fn_name = String(
2448+
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
2449+
)
2450+
ftype_attr = MLIR.IR.attr(func, "function_type")
2451+
ftype = MLIR.IR.Type(ftype_attr)
2452+
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
2453+
"$fn return type is not tensor<i1>"
2454+
)
24812455
fn = MLIR.IR.Region()
24822456
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
24832457
MLIR.IR.rmfromparent!(func)
@@ -2495,7 +2469,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24952469
),
24962470
)
24972471

2498-
return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape)
2472+
return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
24992473
end
25002474

25012475
end # module Ops

src/TracedRArray.jl

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -468,29 +468,100 @@ function Base.mapreduce(
468468
dims=:,
469469
init=nothing,
470470
) where {T,N}
471-
inp = broadcast(f, materialize_traced_array(A))
471+
A = materialize_traced_array(A)
472472

473-
dims isa Number && (dims = (dims,))
473+
if dims isa Int
474+
dims = [dims]
475+
end
476+
477+
op_in_T = Core.Compiler.return_type(f, Tuple{T})
478+
479+
if init === nothing
480+
if op === min
481+
init = typemax(op_in_T)
482+
elseif op === max
483+
init = typemin(op_in_T)
484+
else
485+
init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
486+
end
474487

475-
if init !== nothing && typeof(init) != unwrapped_eltype(inp)
476-
inp = typeof(init).(inp)
488+
if typeof(init) != op_in_T
489+
op_in_T = typeof(init)
490+
A = typeof(init).(A)
491+
end
477492
end
478493

479-
rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims)
494+
init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
495+
496+
inp = [broadcast(f, A).mlir_data]
480497

481-
reduction_result = Ops.reduce(inp, nothing, rdims, op)
498+
rdims = Int64[]
482499

483-
reduction_result = if dims != (:)
484-
Ops.reshape(reduction_result, Int64[i rdims ? 1 : size(A, i) for i in 1:N])
500+
if dims == (:)
501+
for i in 0:(N - 1)
502+
push!(rdims, i)
503+
end
485504
else
486-
TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data)
505+
for i in dims
506+
push!(rdims, i - 1)
507+
end
487508
end
488509

489-
init === nothing && return reduction_result
490-
return broadcast(op, reduction_result, init)
510+
in_tys = [
511+
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
512+
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
513+
]
514+
515+
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
516+
517+
args = (
518+
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
519+
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
520+
)
521+
522+
resty = MLIR.IR.block!(fnbody) do
523+
tmp = TracedUtils.broadcast_to_size(op(args...), ())
524+
Ops.return_(tmp)
525+
return eltype(MLIR.IR.type(tmp.mlir_data))
526+
end
527+
528+
toonedims = Int[]
529+
outdims = Int[]
530+
for i in 1:N
531+
tmp = if in(i - 1, rdims)
532+
1
533+
else
534+
sz = size(A, i)
535+
push!(outdims, sz)
536+
sz
537+
end
538+
push!(toonedims, tmp)
539+
end
540+
541+
TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
542+
543+
body = MLIR.IR.Region()
544+
push!(body, fnbody)
545+
red = MLIR.Dialects.stablehlo.reduce(
546+
inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
547+
)
548+
549+
red = MLIR.IR.result(red, 1)
550+
redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
551+
552+
if dims != (:)
553+
red = Ops.reshape(TracedRArray(red), toonedims...)
554+
else
555+
if length(outdims) == 0
556+
red = TracedRNumber{redT}((), red)
557+
else
558+
red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
559+
end
560+
end
561+
return red
491562
end
492563

493-
function Base._mapreducedim!(
564+
function Base.mapreducedim!(
494565
@nospecialize(f),
495566
@nospecialize(op),
496567
@nospecialize(R::AnyTracedRArray),
@@ -502,11 +573,9 @@ function Base._mapreducedim!(
502573
@assert sR == 1
503574
return i
504575
end
505-
506-
isempty(A) && return R
507-
508576
tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
509-
R .= op.(R, tmp)
577+
# set_mlir_data!(R, get_mlir_data(tmp))
578+
R .= op.(R, tmp) # match native Julia's behavior
510579
return R
511580
end
512581

@@ -1015,11 +1084,4 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
10151084
return (values, linear_indices)
10161085
end
10171086

1018-
Base.map(f, x::AnyTracedRArray) = f.(x)
1019-
1020-
function Base.map!(f, y::AnyTracedRArray, x::AbstractArray)
1021-
y .= f.(x)
1022-
return y
1023-
end
1024-
10251087
end

test/basic.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -938,15 +938,3 @@ end
938938
rv
939939
)
940940
end
941-
942-
@testset "mapreduce with init" begin
943-
x = reshape(collect(Float32, 1:12), 3, 4)
944-
x_ra = Reactant.to_rarray(x)
945-
946-
init = 3.0
947-
init_ra = Reactant.to_rarray(init; track_numbers=Number)
948-
949-
fn(x, init; kwargs...) = sum(x; init, kwargs...)
950-
951-
@test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2)
952-
end

0 commit comments

Comments
 (0)