@@ -551,7 +551,7 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
551
551
end
552
552
553
553
function overloaded_mapreduce (
554
- @nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= nothing
554
+ @nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= Base . _InitialValue ()
555
555
)
556
556
res = unwrapped_broadcast (f, A)
557
557
# This means we are unable to use the optimized dispatches. For now we will
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
568
568
@nospecialize (op),
569
569
@nospecialize (A:: AnyTracedRArray{T,N} );
570
570
dims= :,
571
- init= nothing ,
571
+ init= Base . _InitialValue () ,
572
572
) where {T,N}
573
573
A = materialize_traced_array (A)
574
574
@@ -589,7 +589,7 @@ function overloaded_mapreduce(
589
589
590
590
res = @opcall reduce (reduce_input, reduce_init, dims, op)
591
591
592
- init != = nothing && (res = op .(res, init))
592
+ ( init isa Base . _InitialValue || init === nothing ) || (res = op .(res, init))
593
593
594
594
if original_dims isa Colon
595
595
@assert size (res) == () " expected size of result to be (), got $(size (res)) "
@@ -677,6 +677,8 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
677
677
# Special case a union{} return so we can see the better error message
678
678
if ElType === Union{}
679
679
fn (map (first_scalar, bc. args)... )
680
+ elseif ElType == Any
681
+ ElType = eltype (fn (map (first_scalar, bc. args)... ))
680
682
end
681
683
@assert ElType != Any && ElType != Union{}
682
684
sim = similar (bc, ElType)
@@ -1231,16 +1233,25 @@ function overloaded_map(f, x::AbstractArray, xs::AbstractArray...)
1231
1233
@assert allequal ((axes (x), axes .(xs)... )) " Expected axes of all inputs to map to be \
1232
1234
equal"
1233
1235
1236
+ needs_unrolling = falses (length (xs) + 1 )
1234
1237
inputs = ()
1235
- for input in ( x, xs... )
1238
+ for (i, input) in enumerate (( x, xs... ) )
1236
1239
if input isa AnyTracedRArray
1237
1240
input = Reactant. materialize_traced_array (input)
1238
- else
1241
+ elseif eltype (input) <: Reactant.ReactantPrimitive
1239
1242
input = Reactant. promote_to (TracedRArray{eltype (input),ndims (input)}, input)
1243
+ else
1244
+ needs_unrolling[i] = true
1240
1245
end
1241
1246
inputs = (inputs... , input)
1242
1247
end
1243
1248
1249
+ @assert allequal (needs_unrolling) " All inputs to `overloaded_map` must be \
1250
+ unrolled or none of them. Open an issue."
1251
+ if needs_unrolling[1 ]
1252
+ length (inputs) == 1 && return unrolled_map (f, only (inputs))
1253
+ return unrolled_map (splat (f), zip (inputs... ))
1254
+ end
1244
1255
return TracedUtils. elem_apply (f, inputs... )
1245
1256
end
1246
1257
@@ -1321,14 +1332,14 @@ function scan_impl!(
1321
1332
output:: AnyTracedRArray{T,N} ,
1322
1333
input:: AnyTracedRArray{T,N} ;
1323
1334
dims:: Integer ,
1324
- init= nothing ,
1335
+ init= Base . _InitialValue () ,
1325
1336
) where {T,N}
1326
1337
@assert dims > 0 " dims must be a positive integer"
1327
1338
@assert axes (output) == axes (input) " output and input must have the same shape"
1328
1339
1329
1340
dims > ndims (input) && return copyto! (output, input)
1330
1341
1331
- if init === nothing
1342
+ if init isa Base . _InitialValue
1332
1343
op_in_T = Core. Compiler. return_type (op, Tuple{T,T})
1333
1344
op_in_T === Union{} && (op_in_T = T)
1334
1345
init = __default_init (T, op)
@@ -1494,27 +1505,44 @@ struct BroadcastIterator{F}
1494
1505
f:: F
1495
1506
end
1496
1507
1497
- (fn:: BroadcastIterator )(args... ) = Reactant . call_with_reactant ( fn. f, (args... ,))
1508
+ (fn:: BroadcastIterator )(args... ) = fn. f ( (args... ,))
1498
1509
1499
1510
function unwrapped_broadcast (f:: F , x:: Base.Iterators.Zip ) where {F}
1500
1511
min_length = Base. inferencebarrier (minimum)(length, x. is)
1501
1512
itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
1502
- if any (Base. Fix2 (isa, AnyTracedRArray), itrs)
1503
- return (BroadcastIterator (f)). (itrs... )
1504
- else
1505
- fn = BroadcastIterator (f)
1506
- return [fn (Base. Fix2 (getindex, i).(itrs). .. ) for i in 1 : min_length]
1507
- end
1513
+ any (Base. Fix2 (isa, AnyTracedRArray), itrs) || return unrolled_map (f, x)
1514
+ return broadcast (BroadcastIterator (f), itrs... )
1508
1515
end
1509
1516
1510
1517
function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate ) where {F}
1511
- if x. itr isa AnyTracedRArray
1512
- return (BroadcastIterator (f)). (1 : length (x. itr), x. itr)
1513
- else
1514
- return [f ((i, x. itr[i])) for i in 1 : length (x. itr)]
1515
- end
1518
+ x. itr isa AnyTracedRArray || return unrolled_map (f, x)
1519
+ return broadcast (
1520
+ BroadcastIterator (f), Reactant. promote_to (TracedRArray, 1 : length (x. itr)), x. itr
1521
+ )
1516
1522
end
1517
1523
1518
- unwrapped_broadcast (f:: F , xs:: Vector ) where {F} = [f (x) for x in xs]
1524
+ unwrapped_broadcast (f:: F , xs) where {F} = unrolled_map (f, xs)
1525
+
1526
+ # TODO : once traced_call supports internal mutations, we can use traced_call here
1527
+ # TODO : we should overload this for Slices and use mapslices instead
1528
+ function unrolled_map (f:: F , itr) where {F}
1529
+ y = Reactant. call_with_reactant (iterate, itr)
1530
+ y === nothing && return []
1531
+
1532
+ first, state = y
1533
+ res_first = Reactant. call_with_reactant (f, first)
1534
+ result = [res_first]
1535
+
1536
+ while true
1537
+ y = Reactant. call_with_reactant (iterate, itr, state)
1538
+ y === nothing && break
1539
+
1540
+ val, state = y
1541
+ res = Reactant. call_with_reactant (f, val)
1542
+ push! (result, res)
1543
+ end
1544
+
1545
+ return result
1546
+ end
1519
1547
1520
1548
end
0 commit comments