|
| 1 | +module ReactantFillArraysExt |
| 2 | + |
| 3 | +using Reactant: Reactant, TracedUtils, TracedRNumber, Ops, Sharding, unwrapped_eltype |
| 4 | +using ReactantCore: ReactantCore |
| 5 | +using FillArrays: FillArrays, AbstractFill, Fill, Ones, Zeros, OneElement |
| 6 | +using GPUArraysCore: @allowscalar |
| 7 | + |
| 8 | +# Tracing |
# Both methods hand the queried type back unchanged, i.e. a FillArrays type is
# reported as its own parent type.
function Reactant._parent_type(T::Type{<:AbstractFill})
    return T
end

function Reactant._parent_type(T::Type{<:OneElement})
    return T
end
| 11 | + |
# Generate one `traced_type_inner` method per fill wrapper. Only the element type
# parameter is rewritten; the dimensionality and axes parameters pass through as-is.
for FillT in (Fill, Ones, Zeros)
    @eval Base.@nospecializeinfer function Reactant.traced_type_inner(
        @nospecialize(FA::Type{$(FillT){T,N,Axes}}),
        seen,
        mode::Reactant.TraceMode,
        @nospecialize(track_numbers::Type),
        @nospecialize(sharding),
        @nospecialize(runtime)
    ) where {T,N,Axes}
        # The element type is traced with `Number` as the tracking type, regardless
        # of the `track_numbers` argument received by this method.
        traced_eltype = Reactant.traced_type_inner(
            T, seen, mode, Number, sharding, runtime
        )
        return $(FillT){traced_eltype,N,Axes}
    end
end
| 27 | + |
# Trace a `Fill` by tracing its single stored value and rebuilding the wrapper
# around the traced value with the original axes.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen, @nospecialize(prev::Fill{T,N,Axes}), @nospecialize(path), mode; kwargs...
) where {T,N,Axes}
    # `track_numbers=Number` is passed after `kwargs...`, so it takes precedence over
    # any `track_numbers` forwarded by the caller. The value is addressed at path
    # component 1.
    traced_value = Reactant.make_tracer(
        seen, prev.value, (path..., 1), mode; kwargs..., track_numbers=Number
    )
    return Fill(traced_value, prev.axes)
end
| 38 | + |
# Trace a `Ones`: no value field is read here, so only the element type is mapped
# through `traced_type_inner` and a new `Ones` is built over the same axes.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen,
    @nospecialize(prev::Ones{T,N,Axes}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    kwargs...,
) where {T,N,Axes}
    traced_eltype = Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime)
    return Ones(traced_eltype, prev.axes)
end
| 52 | + |
# Trace a `Zeros`: no value field is read here, so only the element type is mapped
# through `traced_type_inner` and a new `Zeros` is built over the same axes.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen,
    @nospecialize(prev::Zeros{T,N,Axes}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(runtime = nothing),
    kwargs...,
) where {T,N,Axes}
    traced_eltype = Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime)
    return Zeros(traced_eltype, prev.axes)
end
| 66 | + |
# Traced counterpart of a `OneElement` type: the element type parameter is traced,
# while the dimensionality, index and axes parameters are kept unchanged.
Base.@nospecializeinfer function Reactant.traced_type_inner(
    @nospecialize(FA::Type{OneElement{T,N,I,A}}),
    seen,
    mode::Reactant.TraceMode,
    @nospecialize(track_numbers::Type),
    @nospecialize(sharding),
    @nospecialize(runtime)
) where {T,N,I,A}
    # The element type is traced with `Number` as the tracking type, regardless of
    # the `track_numbers` argument received by this method.
    traced_eltype = Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime)
    return OneElement{traced_eltype,N,I,A}
end
| 80 | + |
# Trace a `OneElement` by tracing its stored value; the index and axes are reused
# untouched when rebuilding the wrapper.
Base.@nospecializeinfer function Reactant.make_tracer(
    seen, @nospecialize(prev::OneElement{T,N,I,A}), @nospecialize(path), mode; kwargs...
) where {T,N,I,A}
    # `track_numbers=Number` is passed after `kwargs...`, so it takes precedence over
    # any `track_numbers` forwarded by the caller. The value is addressed at path
    # component 1.
    traced_val = Reactant.make_tracer(
        seen, prev.val, (path..., 1), mode; kwargs..., track_numbers=Number
    )
    return OneElement(traced_val, prev.ind, prev.axes)
end
| 92 | + |
| 93 | +# Materialize into a dense array |
# Materialize a `Fill`: promote the stored value to a `TracedRNumber` of the
# unwrapped element type, then broadcast that scalar to the array's size.
function ReactantCore.materialize_traced_array(x::Fill{T}) where {T}
    scalar = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x.value)
    return TracedUtils.broadcast_to_size(scalar, size(x))
end
| 99 | + |
# Materialize a `Ones`: broadcast the value 1, constructed in the unwrapped element
# type, to the array's size.
function ReactantCore.materialize_traced_array(x::Ones{T}) where {T}
    unit = unwrapped_eltype(T)(1)
    return TracedUtils.broadcast_to_size(unit, size(x))
end
| 103 | + |
# Materialize a `Zeros`: broadcast the value 0, constructed in the unwrapped element
# type, to the array's size.
function ReactantCore.materialize_traced_array(x::Zeros{T}) where {T}
    zero_scalar = unwrapped_eltype(T)(0)
    return TracedUtils.broadcast_to_size(zero_scalar, size(x))
end
| 107 | + |
# Materialize a `OneElement`: start from an all-zero array of the unwrapped element
# type and write the stored value at the stored index.
function ReactantCore.materialize_traced_array(x::OneElement{T}) where {T}
    y = TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), size(x))
    # `x.ind` is not guaranteed to lie inside `axes(x)`: the `OneElement` constructor
    # does not validate it. NOTE(review): FillArrays appears to treat an out-of-range
    # index as an all-zero array -- confirm against the FillArrays documentation.
    # Skipping the write in that case keeps the zero result instead of attempting an
    # out-of-bounds `setindex!`.
    if checkbounds(Bool, x, x.ind...)
        @allowscalar setindex!(y, x.val, x.ind...)
    end
    return y
end
| 113 | + |
| 114 | +# some functions to avoid bad performance |
# For every FillArrays wrapper whose element type is a `TracedRNumber`, `similar`
# returns a zero scalar of the unwrapped requested element type broadcast to `dims`.
for FillT in (Fill, Ones, Zeros, OneElement)
    @eval function Base.similar(
        x::$(FillT){<:TracedRNumber}, ::Type{T}, dims::Dims
    ) where {T}
        zero_scalar = unwrapped_eltype(T)(0)
        return TracedUtils.broadcast_to_size(zero_scalar, dims)
    end
end
| 120 | + |
| 121 | +end |
0 commit comments