Skip to content

Commit 8040a3e

Browse files
authored
feat: support generators + force calling into interp (#1665)
* feat: support generators + force calling into interp * fix: overloaded_map
1 parent 6b9f46b commit 8040a3e

File tree

9 files changed

+139
-38
lines changed

9 files changed

+139
-38
lines changed

src/ConcreteRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ function Base.mapreduce(
650650
@nospecialize(op),
651651
@nospecialize(A::AbstractConcreteArray{T,N});
652652
dims=:,
653-
init=nothing,
653+
init=Base._InitialValue(),
654654
) where {T,N}
655655
fn = compile(CallMapReduce(f, op, dims, init), (A,))
656656
return fn(A)

src/Ops.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ macro opcall(expr)
4646

4747
# Generate location info at the callsite
4848
location_expr = :($(mlir_stacktrace)(
49-
joinpath(string(var"#self#"), $(string(func))),
49+
if @isdefined(var"#self#")
50+
joinpath(string(var"#self#"), $(string(func)))
51+
else
52+
$(string(func))
53+
end,
5054
$(string(__source__.file)),
5155
$(__source__.line),
5256
))
@@ -2575,7 +2579,7 @@ end
25752579
seen_cache = Reactant.OrderedIdDict()
25762580
Reactant.make_tracer(
25772581
seen_cache,
2578-
args,
2582+
fnwrapped ? (f, args) : args,
25792583
(), # we have to insert something here, but we remove it immediately below.
25802584
Reactant.TracedTrack;
25812585
toscalar=false,

src/Overlay.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,10 @@ end
156156
end
157157

158158
@reactant_overlay @noinline function Base.mapreduce(
159-
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
159+
f,
160+
op,
161+
A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
162+
kwargs...,
160163
)
161164
if use_overlayed_version(A)
162165
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)

src/Reactant.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ function _parent end
6060
_parent_type(::Type{Array}) = Array
6161
_parent_type(::Type{Array{T}}) where {T} = Array{T}
6262
_parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N}
63+
_parent_type(::Type{<:Slices{P}}) where {P} = P
6364

6465
include("accelerators/Accelerators.jl")
6566

@@ -179,10 +180,15 @@ include("TracedRArray.jl")
179180
include("ConcreteRArray.jl")
180181

181182
use_overlayed_version(x) = false
182-
use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is)
183+
function use_overlayed_version(x::F) where {F<:Function}
184+
return use_overlayed_version(getfield.(Ref(x), fieldnames(F)))
185+
end
186+
use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter))
187+
use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is)
183188
use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr)
184-
use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter)
185-
use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter))
189+
use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x)
190+
use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter)
191+
use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter))
186192
use_overlayed_version(::TracedRArray) = true
187193
use_overlayed_version(::TracedRNumber) = true
188194
use_overlayed_version(::Number) = false
@@ -195,6 +201,14 @@ function use_overlayed_version(x::AbstractArray)
195201
return use_overlayed_version(a)
196202
end
197203

204+
## We avoid calling into `any` to avoid triggering the `any` overlay
205+
function looped_any(f::F, itr) where {F}
206+
@inbounds for x in itr
207+
f(x) && return true
208+
end
209+
return false
210+
end
211+
198212
# StdLib Overloads
199213
include("stdlibs/LinearAlgebra.jl")
200214
include("stdlibs/Random.jl")

src/TracedRArray.jl

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
551551
end
552552

553553
function overloaded_mapreduce(
554-
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
554+
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
555555
)
556556
res = unwrapped_broadcast(f, A)
557557
# This means we are unable to use the optimized dispatches. For now we will
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
568568
@nospecialize(op),
569569
@nospecialize(A::AnyTracedRArray{T,N});
570570
dims=:,
571-
init=nothing,
571+
init=Base._InitialValue(),
572572
) where {T,N}
573573
A = materialize_traced_array(A)
574574

@@ -589,7 +589,7 @@ function overloaded_mapreduce(
589589

590590
res = @opcall reduce(reduce_input, reduce_init, dims, op)
591591

592-
init !== nothing && (res = op.(res, init))
592+
(init isa Base._InitialValue || init === nothing) || (res = op.(res, init))
593593

594594
if original_dims isa Colon
595595
@assert size(res) == () "expected size of result to be (), got $(size(res))"
@@ -677,6 +677,8 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
677677
# Special case a union{} return so we can see the better error message
678678
if ElType === Union{}
679679
fn(map(first_scalar, bc.args)...)
680+
elseif ElType == Any
681+
ElType = eltype(fn(map(first_scalar, bc.args)...))
680682
end
681683
@assert ElType != Any && ElType != Union{}
682684
sim = similar(bc, ElType)
@@ -1231,16 +1233,25 @@ function overloaded_map(f, x::AbstractArray, xs::AbstractArray...)
12311233
@assert allequal((axes(x), axes.(xs)...)) "Expected axes of all inputs to map to be \
12321234
equal"
12331235

1236+
needs_unrolling = falses(length(xs) + 1)
12341237
inputs = ()
1235-
for input in (x, xs...)
1238+
for (i, input) in enumerate((x, xs...))
12361239
if input isa AnyTracedRArray
12371240
input = Reactant.materialize_traced_array(input)
1238-
else
1241+
elseif eltype(input) <: Reactant.ReactantPrimitive
12391242
input = Reactant.promote_to(TracedRArray{eltype(input),ndims(input)}, input)
1243+
else
1244+
needs_unrolling[i] = true
12401245
end
12411246
inputs = (inputs..., input)
12421247
end
12431248

1249+
@assert allequal(needs_unrolling) "All inputs to `overloaded_map` must be \
1250+
unrolled or none of them. Open an issue."
1251+
if needs_unrolling[1]
1252+
length(inputs) == 1 && return unrolled_map(f, only(inputs))
1253+
return unrolled_map(splat(f), zip(inputs...))
1254+
end
12441255
return TracedUtils.elem_apply(f, inputs...)
12451256
end
12461257

@@ -1321,14 +1332,14 @@ function scan_impl!(
13211332
output::AnyTracedRArray{T,N},
13221333
input::AnyTracedRArray{T,N};
13231334
dims::Integer,
1324-
init=nothing,
1335+
init=Base._InitialValue(),
13251336
) where {T,N}
13261337
@assert dims > 0 "dims must be a positive integer"
13271338
@assert axes(output) == axes(input) "output and input must have the same shape"
13281339

13291340
dims > ndims(input) && return copyto!(output, input)
13301341

1331-
if init === nothing
1342+
if init isa Base._InitialValue
13321343
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
13331344
op_in_T === Union{} && (op_in_T = T)
13341345
init = __default_init(T, op)
@@ -1494,27 +1505,44 @@ struct BroadcastIterator{F}
14941505
f::F
14951506
end
14961507

1497-
(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
1508+
(fn::BroadcastIterator)(args...) = fn.f((args...,))
14981509

14991510
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
15001511
min_length = Base.inferencebarrier(minimum)(length, x.is)
15011512
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
1502-
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
1503-
return (BroadcastIterator(f)).(itrs...)
1504-
else
1505-
fn = BroadcastIterator(f)
1506-
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
1507-
end
1513+
any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
1514+
return broadcast(BroadcastIterator(f), itrs...)
15081515
end
15091516

15101517
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
1511-
if x.itr isa AnyTracedRArray
1512-
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
1513-
else
1514-
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
1515-
end
1518+
x.itr isa AnyTracedRArray || return unrolled_map(f, x)
1519+
return broadcast(
1520+
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
1521+
)
15161522
end
15171523

1518-
unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
1524+
unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs)
1525+
1526+
# TODO: once traced_call supports internal mutations, we can use traced_call here
1527+
# TODO: we should overload this for Slices and use mapslices instead
1528+
function unrolled_map(f::F, itr) where {F}
1529+
y = Reactant.call_with_reactant(iterate, itr)
1530+
y === nothing && return []
1531+
1532+
first, state = y
1533+
res_first = Reactant.call_with_reactant(f, first)
1534+
result = [res_first]
1535+
1536+
while true
1537+
y = Reactant.call_with_reactant(iterate, itr, state)
1538+
y === nothing && break
1539+
1540+
val, state = y
1541+
res = Reactant.call_with_reactant(f, val)
1542+
push!(result, res)
1543+
end
1544+
1545+
return result
1546+
end
15191547

15201548
end

src/TracedUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ function finalize_mlir_fn(
660660
skipped_results = Reactant.TracedType[]
661661
for (k, v) in seen_results
662662
v isa Reactant.TracedType || continue
663-
if any(Base.Fix1(===, k), skipped_args)
663+
if Reactant.looped_any(Base.Fix1(===, k), skipped_args)
664664
push!(skipped_results, v)
665665

666666
_, argpath = get_argidx(v, argprefix)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct CallWithReactant{F}
1+
struct CallWithReactant{F} <: Function
22
f::F
33
end
44

test/autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ end
132132

133133
@testset "Forward Gradient" begin
134134
x = Reactant.to_rarray(3.1 * ones(2, 2))
135-
res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x)
135+
res = @jit gw(x)
136136
# TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132
137137
# to make sure this gets merged as a tracedrarray
138138
@test res isa Tuple{<:Enzyme.TupleArray{<:ConcreteRNumber{Float64},(2, 2),4,2}}

test/basic.jl

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ end
926926

927927
ra = Reactant.to_rarray(x)
928928
@jit dip!(ra)
929-
ra[:a] (2.7 * 2) * ones(4)
929+
@test ra[:a] (2.7 * 3.1) * ones(4)
930930
end
931931

932932
@testset "@code_xla" begin
@@ -1429,7 +1429,10 @@ end
14291429
end
14301430

14311431
zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
1432+
zip_iterator2(a, b) = mapreduce(splat(.-), +, zip(a, b))
14321433
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
1434+
enumerate_iterator2(a) = mapreduce(splat(.-), +, enumerate(a))
1435+
mapreduce_vector(a) = mapreduce(-, +, a)
14331436

14341437
function nested_mapreduce_zip(x, y)
14351438
return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
@@ -1445,44 +1448,65 @@ function nested_mapreduce_hcat(x, y)
14451448
end
14461449
end
14471450

1451+
function f_generator(points, params)
1452+
return sum(params * point for point in points)
1453+
end
1454+
14481455
@testset "Base.Iterators" begin
14491456
@testset "zip" begin
14501457
N = 10
1451-
a = range(1.0, 5.0; length=N)
1452-
x = range(10.0, 15.0; length=N + 2)
1458+
a = collect(range(1.0, 5.0; length=N))
1459+
x = collect(range(10.0, 15.0; length=N + 2))
14531460
x_ra = Reactant.to_rarray(x)
14541461

14551462
@test @jit(zip_iterator(a, x_ra)) zip_iterator(a, x)
1463+
1464+
a = [rand(Float32, 2, 3) for _ in 1:10]
1465+
x = [rand(Float32, 2, 3) for _ in 1:10]
1466+
a_ra = Reactant.to_rarray(a)
1467+
x_ra = Reactant.to_rarray(x)
1468+
1469+
@test @jit(zip_iterator2(a_ra, x_ra)) zip_iterator2(a, x)
14561470
end
14571471

14581472
@testset "enumerate" begin
1459-
x = range(1.0, 5.0; length=10)
1473+
x = collect(range(1.0, 5.0; length=10))
14601474
x_ra = Reactant.to_rarray(x)
14611475

14621476
@test @jit(enumerate_iterator(x_ra)) enumerate_iterator(x)
1477+
1478+
x = [rand(Float32, 2, 3) for _ in 1:10]
1479+
x_ra = Reactant.to_rarray(x)
1480+
1481+
@test @jit(enumerate_iterator2(x_ra)) enumerate_iterator2(x)
14631482
end
14641483

14651484
@testset "nested mapreduce" begin
14661485
x = rand(Float32, 4, 3)
14671486
y = rand(Float32, 4, 3)
1468-
14691487
x_ra = Reactant.to_rarray(x)
14701488
y_ra = Reactant.to_rarray(y)
1471-
14721489
@test @jit(nested_mapreduce_zip(x_ra, y_ra)) nested_mapreduce_zip(x, y)
14731490
end
1474-
14751491
@testset "nested mapreduce hcat" begin
14761492
x = rand(Float32, 4, 3)
14771493
y = rand(Float32, 4, 3)
1478-
14791494
x_ra = Reactant.to_rarray(x)
14801495
y_ra = Reactant.to_rarray(y)
14811496

14821497
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) nested_mapreduce_hcat(x, y)
14831498
end
14841499
end
14851500

1501+
@testset "Base.Generator" begin
1502+
points = eachcol(rand(Float32, 2, 6))
1503+
params = rand(Float32, 4, 2)
1504+
points_ra = Reactant.to_rarray(points)
1505+
params_ra = Reactant.to_rarray(params)
1506+
1507+
@test @jit(f_generator(points_ra, params_ra)) f_generator(points, params)
1508+
end
1509+
14861510
@testset "compilation cache" begin
14871511
if Reactant.PersistentCompileCache.autotune_cache_enabled() &&
14881512
contains(string(Reactant.devices()[1]), "CUDA")
@@ -1574,3 +1598,31 @@ end
15741598
x_ra = Reactant.to_rarray(x)
15751599
@test @jit(clamp!(x_ra, 0.5, Inf32)) clamp!(x, 0.5, Inf32)
15761600
end
1601+
1602+
mapped_sub(xs...) = stack(map(-, xs...))
1603+
1604+
@testset "map of slices" begin
1605+
# We shouldn't be using `elem_apply` in this case and instead unroll the map
1606+
# our passes will fuse them backup if needed.
1607+
@testset "Vector of Slices" begin
1608+
x_full = rand(Float32, 10, 5, 3)
1609+
y_full = rand(Float32, 10, 5, 3)
1610+
x = [view(x_full, :, i, :) for i in 1:size(x_full, 2)]
1611+
y = [view(y_full, :, i, :) for i in 1:size(y_full, 2)]
1612+
x_ra = Reactant.to_rarray(x)
1613+
y_ra = Reactant.to_rarray(y)
1614+
1615+
@test @jit(mapped_sub(x_ra, y_ra)) mapped_sub(x, y) atol = 1e-5 rtol = 1e-5
1616+
end
1617+
1618+
@testset "Slices" begin
1619+
x_full = rand(Float32, 10, 5)
1620+
1621+
@testset "ColumnSlices" begin
1622+
x_sliced = eachcol(x_full)
1623+
x_ra = Reactant.to_rarray(x_sliced)
1624+
1625+
@test @jit(mapped_sub(x_ra)) mapped_sub(x_sliced) atol = 1e-5 rtol = 1e-5
1626+
end
1627+
end
1628+
end

0 commit comments

Comments
 (0)