Skip to content

Commit f4af26a

Browse files
Promote rule of cutraced (#1257)
* Promote rule of cutraced * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent c4360af commit f4af26a

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,22 +146,74 @@ Base.one(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = one(T)
146146
Base.zero(a::CuTracedRNumber) = zero(a[])
147147
Base.zero(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = zero(T)
148148

149-
function Base.promote_rule(
150-
::Type{<:CuTracedRNumber{T}}, ::Type{<:CuTracedRNumber{T2}}
149+
Base.@nospecializeinfer function Base.promote_rule(
150+
@nospecialize(a::Type{<:CuTracedRNumber{T}}),
151+
@nospecialize(b::Type{<:CuTracedRNumber{T2}})
151152
) where {T,T2}
152153
return Base.promote_rule(T, T2)
153154
end
154-
function Base.promote_rule(::Type{Any}, ::Type{<:CuTracedRNumber})
155+
Base.@nospecializeinfer function Base.promote_rule(
156+
::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
157+
)
155158
return Any
156159
end
157-
function Base.promote_rule(::Type{<:CuTracedRNumber}, ::Type{Any})
160+
Base.@nospecializeinfer function Base.promote_rule(
161+
@nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any}
162+
)
158163
return Any
159164
end
160-
function Base.promote_rule(::Type{T2}, ::Type{<:CuTracedRNumber{T}}) where {T,T2}
161-
return Base.promote_rule(T, T2)
165+
Base.@nospecializeinfer function Base.promote_rule(
166+
@nospecialize(T2::Type), @nospecialize(b::Type{<:CuTracedRNumber{T}})
167+
) where {T}
168+
if T == T2
169+
return T
170+
else
171+
return Base.promote_rule(T, T2)
172+
end
162173
end
163-
function Base.promote_rule(::Type{<:CuTracedRNumber{T}}, ::Type{T2}) where {T,T2}
164-
return Base.promote_rule(T, T2)
174+
Base.@nospecializeinfer function Base.promote_rule(
175+
@nospecialize(a::Type{<:CuTracedRNumber{T}}), @nospecialize(T2::Type)
176+
) where {T}
177+
if T == T2
178+
return T
179+
else
180+
return Base.promote_rule(T, T2)
181+
end
182+
end
183+
184+
Base.@nospecializeinfer function Reactant.promote_traced_type(
185+
@nospecialize(a::Type{<:CuTracedRNumber{T,A}}),
186+
@nospecialize(b::Type{<:CuTracedRNumber{T2,A}})
187+
) where {T,T2,A}
188+
return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
189+
end
190+
Base.@nospecializeinfer function Reactant.promote_traced_type(
191+
::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
192+
)
193+
return Any
194+
end
195+
Base.@nospecializeinfer function Reactant.promote_traced_type(
196+
@nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any}
197+
)
198+
return Any
199+
end
200+
Base.@nospecializeinfer function Reactant.promote_traced_type(
201+
@nospecialize(T2::Type), ::Type{<:CuTracedRNumber{T,A}}
202+
) where {T,A}
203+
if T == T2
204+
return CuTracedRNumber{T,A}
205+
else
206+
return CuTracedRNumber{Reactant.promote_trace_type(T, T2),A}
207+
end
208+
end
209+
Base.@nospecializeinfer function Reactant.promote_traced_type(
210+
::Type{<:CuTracedRNumber{T,A}}, @nospecialize(T2::Type)
211+
) where {T,A}
212+
if T == T2
213+
return CuTracedRNumber{T,A}
214+
else
215+
return CuTracedRNumber{Reactant.promote_trace_type(T, T2),A}
216+
end
165217
end
166218

167219
function Base.show(io::IO, a::AT) where {AT<:CuTracedArray}

src/Reactant.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ unwrapped_eltype(::TracedRNumber{T}) where {T} = T
9494
unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T)
9595
unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T)
9696

97+
promote_traced_type(a::Type, b::Type) = Base.promote_type(a, b)
98+
9799
aos_to_soa(x::AbstractArray) = x
98100

99101
aos_to_soa(x::TracedRArray) = x

test/integration/cuda.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@ using Reactant
33
using Test
44
using CUDA
55

6+
const ReactantCUDAExt = Base.get_extension(Reactant, :ReactantCUDAExt)
7+
8+
@testset "Promote CuTraced" begin
9+
TFT = ReactantCUDAExt.CuTracedRNumber{Float64,1}
10+
FT = Float64
11+
@test Reactant.promote_traced_type(TFT, FT) == TFT
12+
@test Base.promote_type(TFT, FT) == FT
13+
end
14+
615
function square_kernel!(x, y)
716
i = threadIdx().x
817
x[i] *= y[i]

0 commit comments

Comments
 (0)