Skip to content

Commit cbeddbf

Browse files
authored
Generalize broadcast (#35)
* Generalize broadcast * functioning * slowly cleaning up * stash * try with getmap lookup * fix * fix mul * fix mul * rm unused fn
1 parent dbab658 commit cbeddbf

File tree

5 files changed

+254
-165
lines changed

5 files changed

+254
-165
lines changed

src/Reactant.jl

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end
1717
@inline mlir_type(::RArray{ElType,Shape,N}) where {ElType,Shape,N} =
1818
MLIR.IR.TensorType(Shape, MLIR.IR.Type(ElType))
1919

20+
@inline mlir_type(::Type{<:RArray{ElType,Shape,N}}) where {ElType,Shape,N} =
21+
MLIR.IR.TensorType(Shape, MLIR.IR.Type(ElType))
22+
2023
struct XLAArray{ElType,Shape,N} <: RArray{ElType,Shape,N} end
2124

2225
mutable struct ConcreteRArray{ElType,Shape,N} <: RArray{ElType,Shape,N}
@@ -208,6 +211,10 @@ using Enzyme
208211
TracedSetPath = 5
209212
end
210213

214+
@inline getmap(::Val{T}) where T = nothing
215+
@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...)
216+
@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T, T2} = T2
217+
211218
@inline is_concrete_tuple(x::T2) where {T2} =
212219
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)
213220
@inline function traced_type(val::Type{T}, seen::ST, ::Val{mode}) where {ST,T,mode}
@@ -391,17 +398,18 @@ end
391398
return IdDict{iddict_name(T),traced_type(iddict_val(T), seen, Val(mode))}
392399
end
393400

394-
if Val(T) seen
395-
return T
401+
nextTy = getmap(Val(T), seen...)
402+
if nextTy != nothing
403+
return nextTy
396404
end
397405

398-
seen = (Val(T), seen...)
406+
seen2 = (Val(T), Val(T), seen...)
399407

400408
changed = false
401409
subTys = Type[]
402410
for f in 1:fieldcount(T)
403411
subT = fieldtype(T, f)
404-
subTT = traced_type(subT, seen, Val(mode))
412+
subTT = traced_type(subT, seen2, Val(mode))
405413
changed |= subT != subTT
406414
push!(subTys, subTT)
407415
end
@@ -421,29 +429,32 @@ end
421429
end
422430

423431
TT2 = Core.apply_type(T.name.wrapper, subParms...)
432+
seen3 = (Val(T), Val(TT2), seen...)
424433
if fieldcount(T) == fieldcount(TT2)
425434
legal = true
426435
for f in 1:fieldcount(T)
427436
subT = fieldtype(T, f)
428437
subT2 = fieldtype(TT2, f)
429-
subTT = traced_type(subT, seen, Val(mode))
430-
legal &= subT2 == subTT
438+
subTT = traced_type(subT, seen3, Val(mode))
439+
if subT2 != subTT
440+
legal = false
441+
break
442+
end
431443
end
432444
if legal
433445
return TT2
434446
end
435447
end
436448

437449
name = Symbol[]
438-
439-
return NamedTuple{fieldnames(T),Tuple{subTys...}}
450+
throw(error("Cannot convert type $T, best attempt $TT2 failed"))
440451
end
441452

442453
function append_path(path, i)
443454
return (path..., i)
444455
end
445456

446-
@inline function make_tracer(seen::IdDict, prev::RT, path, mode) where {RT}
457+
@inline function make_tracer(seen::IdDict, prev::RT, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {RT}
447458
if haskey(seen, prev)
448459
return seen[prev]
449460
end
@@ -457,7 +468,7 @@ end
457468
subs = []
458469
for i in 1:nf
459470
xi = Base.getfield(prev, i)
460-
xi2 = make_tracer(seen, xi, append_path(path, i), mode)
471+
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
461472
if xi !== xi2
462473
changed = true
463474
end
@@ -468,7 +479,6 @@ end
468479
return prev
469480
end
470481
tup = (subs...,)
471-
@show TT, subs, tup
472482
return NamedTuple{TT.parameters[1],typeof(tup)}(tup)
473483
end
474484

@@ -479,7 +489,7 @@ end
479489
for i in 1:nf
480490
if isdefined(prev, i)
481491
xi = Base.getfield(prev, i)
482-
xi2 = make_tracer(seen, xi, append_path(path, i), mode)
492+
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
483493
if xi !== xi2
484494
changed = true
485495
end
@@ -502,7 +512,7 @@ end
502512
for i in 1:nf
503513
if isdefined(prev, i)
504514
xi = Base.getfield(prev, i)
505-
xi2 = make_tracer(seen, xi, append_path(path, i), mode)
515+
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
506516
if xi !== xi2
507517
changed = true
508518
end
@@ -522,7 +532,7 @@ end
522532
end
523533

524534
@inline function make_tracer(
525-
seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path, mode
535+
seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing
526536
) where {ElType,Shape,N}
527537
if mode == ArrayToConcrete
528538
return prev
@@ -540,7 +550,7 @@ end
540550
end
541551

542552
@inline function make_tracer(
543-
seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path, mode
553+
seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing
544554
) where {ElType,Shape,N}
545555
if mode == ConcreteToTraced
546556
throw("Cannot trace existing trace type")
@@ -556,7 +566,13 @@ end
556566
if haskey(seen, prev)
557567
return seen[prev]
558568
end
559-
res = TracedRArray{ElType,Shape,N}((path,), prev.mlir_data)
569+
res = if toscalar
570+
TracedRArray{ElType,(),0}((path,), nothing)
571+
elseif tobatch !== nothing
572+
TracedRArray{ElType,tobatch,length(tobatch)}((path,), prev.mlir_data)
573+
else
574+
TracedRArray{ElType,Shape,N}((path,), prev.mlir_data)
575+
end
560576
seen[prev] = res
561577
return res
562578
end
@@ -573,18 +589,18 @@ end
573589
throw("Cannot Unknown trace mode $mode")
574590
end
575591

576-
@inline function make_tracer(seen::IdDict, prev::RT, path, mode) where {RT<:AbstractFloat}
592+
@inline function make_tracer(seen::IdDict, prev::RT, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {RT<:AbstractFloat}
577593
return prev
578594
end
579595

580-
@inline function make_tracer(seen::IdDict, prev::Complex{RT}, path, mode) where {RT}
596+
@inline function make_tracer(seen::IdDict, prev::Complex{RT}, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {RT}
581597
return Complex(
582-
make_tracer(seen, prev.re, append_path(path, :re), mode),
583-
make_tracer(seen, prev.im, append_path(path, :im), mode),
598+
make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch),
599+
make_tracer(seen, prev.im, append_path(path, :im), mode; toscalar, tobatch),
584600
)
585601
end
586602

587-
@inline function make_tracer(seen::IdDict, prev::RT, path, mode) where {RT<:Array}
603+
@inline function make_tracer(seen::IdDict, prev::RT, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {RT<:Array}
588604
if haskey(seen, prev)
589605
return seen[prev]
590606
end
@@ -598,7 +614,7 @@ end
598614
for I in eachindex(prev)
599615
if isassigned(prev, I)
600616
pv = prev[I]
601-
nv = make_tracer(seen, pv, append_path(path, I), mode)
617+
nv = make_tracer(seen, pv, append_path(path, I), mode; toscalar, tobatch)
602618
if pv !== nv
603619
same = false
604620
end
@@ -612,27 +628,27 @@ end
612628
return newa
613629
end
614630

615-
@inline function make_tracer(seen::IdDict, prev::RT, path, mode) where {RT<:Tuple}
631+
@inline function make_tracer(seen::IdDict, prev::RT, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {RT<:Tuple}
616632
return (
617-
(make_tracer(seen, v, append_path(path, i), mode) for (i, v) in enumerate(prev))...,
633+
(make_tracer(seen, v, append_path(path, i), mode; toscalar, tobatch) for (i, v) in enumerate(prev))...,
618634
)
619635
end
620636

621-
@inline function make_tracer(seen::IdDict, prev::NamedTuple{A,RT}, path, mode) where {A,RT}
637+
@inline function make_tracer(seen::IdDict, prev::NamedTuple{A,RT}, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing) where {A,RT}
622638
return NamedTuple{A,traced_type(RT, (), Val(mode))}((
623639
(
624-
make_tracer(seen, Base.getfield(prev, i), append_path(path, i), mode) for
640+
make_tracer(seen, Base.getfield(prev, i), append_path(path, i), mode; toscalar, tobatch) for
625641
i in 1:length(A)
626642
)...,
627643
))
628644
end
629645

630-
@inline function make_tracer(seen::IdDict, prev::Core.Box, path, mode)
646+
@inline function make_tracer(seen::IdDict, prev::Core.Box, path::Tuple, mode::TraceMode; toscalar=false, tobatch=nothing)
631647
if haskey(seen, prev)
632648
return seen[prev]
633649
end
634650
prev2 = prev.contents
635-
tr = make_tracer(seen, prev2, append_path(path, :contents), mode)
651+
tr = make_tracer(seen, prev2, append_path(path, :contents), mode; toscalar, tobatch)
636652
if tr == prev2
637653
seen[prev] = prev
638654
return prev
@@ -1100,9 +1116,12 @@ pad_dot_general<1>(1);
11001116
"""
11011117

11021118
function compile_to_module(mod, f, args; optimize=true)
1103-
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
1104-
mod, f, args, (), "main", true
1105-
)
1119+
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results =
1120+
MLIR.IR.block!(MLIR.IR.body(mod)) do
1121+
return make_mlir_fn(
1122+
f, args, (), "main", true
1123+
)
1124+
end
11061125

11071126
concrete_seen = IdDict()
11081127

@@ -1112,6 +1131,7 @@ function compile_to_module(mod, f, args; optimize=true)
11121131

11131132
if optimize
11141133
XLA.RunPassPipeline(
1134+
opt_passes * ",enzyme-batch,"*
11151135
opt_passes *
11161136
",enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," *
11171137
opt_passes,

0 commit comments

Comments
 (0)