@@ -468,100 +468,29 @@ function Base.mapreduce(
468
468
dims= :,
469
469
init= nothing ,
470
470
) where {T,N}
471
- A = materialize_traced_array (A)
471
+ inp = broadcast (f, materialize_traced_array (A) )
472
472
473
- if dims isa Int
474
- dims = [dims]
475
- end
476
-
477
- op_in_T = Core. Compiler. return_type (f, Tuple{T})
473
+ dims isa Number && (dims = (dims,))
478
474
479
- if init === nothing
480
- if op === min
481
- init = typemax (op_in_T)
482
- elseif op === max
483
- init = typemin (op_in_T)
484
- else
485
- init = Base. reduce_empty (Base. BottomRF (op), op_in_T)
486
- end
487
-
488
- if typeof (init) != op_in_T
489
- op_in_T = typeof (init)
490
- A = typeof (init).(A)
491
- end
475
+ if init != = nothing && typeof (init) != unwrapped_eltype (inp)
476
+ inp = typeof (init).(inp)
492
477
end
493
478
494
- init = [TracedUtils. broadcast_to_size (init, ()). mlir_data]
495
-
496
- inp = [broadcast (f, A). mlir_data]
479
+ rdims = dims == (:) ? collect (Int64, 1 : N) : collect (Int64, dims)
497
480
498
- rdims = Int64[]
481
+ reduction_result = Ops . reduce (inp, nothing , rdims, op)
499
482
500
- if dims == (:)
501
- for i in 0 : (N - 1 )
502
- push! (rdims, i)
503
- end
483
+ reduction_result = if dims != (:)
484
+ Ops. reshape (reduction_result, Int64[i ∈ rdims ? 1 : size (A, i) for i in 1 : N])
504
485
else
505
- for i in dims
506
- push! (rdims, i - 1 )
507
- end
508
- end
509
-
510
- in_tys = [
511
- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (inp[1 ]))),
512
- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (init[1 ]))),
513
- ]
514
-
515
- fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location (), MLIR. IR. Location ()])
516
-
517
- args = (
518
- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 1 )),
519
- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 2 )),
520
- )
521
-
522
- resty = MLIR. IR. block! (fnbody) do
523
- tmp = TracedUtils. broadcast_to_size (op (args... ), ())
524
- Ops. return_ (tmp)
525
- return eltype (MLIR. IR. type (tmp. mlir_data))
486
+ TracedRNumber {unwrapped_eltype(reduction_result)} ((), reduction_result. mlir_data)
526
487
end
527
488
528
- toonedims = Int[]
529
- outdims = Int[]
530
- for i in 1 : N
531
- tmp = if in (i - 1 , rdims)
532
- 1
533
- else
534
- sz = size (A, i)
535
- push! (outdims, sz)
536
- sz
537
- end
538
- push! (toonedims, tmp)
539
- end
540
-
541
- TT = MLIR. IR. Type[MLIR. IR. TensorType (outdims, resty)]
542
-
543
- body = MLIR. IR. Region ()
544
- push! (body, fnbody)
545
- red = MLIR. Dialects. stablehlo. reduce (
546
- inp, init; result_0= TT, dimensions= MLIR. IR. DenseArrayAttribute (rdims), body
547
- )
548
-
549
- red = MLIR. IR. result (red, 1 )
550
- redT = eltype (MLIR. IR. julia_type (MLIR. IR. type (red)))
551
-
552
- if dims != (:)
553
- red = Ops. reshape (TracedRArray (red), toonedims... )
554
- else
555
- if length (outdims) == 0
556
- red = TracedRNumber {redT} ((), red)
557
- else
558
- red = TracedRArray {redT,length(outdims)} ((), red, (outdims... ,))
559
- end
560
- end
561
- return red
489
+ init === nothing && return reduction_result
490
+ return broadcast (op, reduction_result, init)
562
491
end
563
492
564
- function Base. mapreducedim ! (
493
+ function Base. _mapreducedim ! (
565
494
@nospecialize (f),
566
495
@nospecialize (op),
567
496
@nospecialize (R:: AnyTracedRArray ),
@@ -573,9 +502,9 @@ function Base.mapreducedim!(
573
502
@assert sR == 1
574
503
return i
575
504
end
505
+ isempty (A) && return R
576
506
tmp = mapreduce (f, op, A; dims= filter (! isnothing, dims))
577
- # set_mlir_data!(R, get_mlir_data(tmp))
578
- R .= op .(R, tmp) # match native Julia's behavior
507
+ R .= op .(R, tmp)
579
508
return R
580
509
end
581
510
0 commit comments