Skip to content

Commit 9796318

Browse files
Fix atomic store in cuda (#1263)
* Fix atomic store in cuda
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update cuda.jl
* Update cuda.jl
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent cec7f8a commit 9796318

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ struct CuTracedRNumber{T,A} <: Number
4646
end
4747
end
4848

49-
CuTracedRNumber{T,A}(val::Number) where {T,A} = convert(CuTracedRNumber{T,A}, val)
49+
Base.@nospecializeinfer Reactant.is_traced_number(
50+
@nospecialize(T::Type{<:CuTracedRNumber})
51+
) = true
52+
Reactant.unwrapped_eltype(::Type{<:CuTracedRNumber{T}}) where {T} = T
53+
54+
@inline CuTracedRNumber{T,A}(val::Number) where {T,A} = convert(CuTracedRNumber{T,A}, val)
5055

5156
function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A}
5257
align = alignment(RN)
@@ -99,13 +104,13 @@ Base.OneTo(x::CuTracedRNumber{<:Integer}) = Base.OneTo(x[])
99104
end
100105
end
101106

102-
function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
107+
@inline function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
103108
return CT(
104109
Base.llvmcall(
105110
(
106111
"""define double addrspace(1)* @entry(double %d) alwaysinline {
107112
%a = alloca double
108-
store double %d, double* %a
113+
store atomic double %d, double* %a release, align 8
109114
%ac = addrspacecast double* %a to double addrspace(1)*
110115
ret double addrspace(1)* %ac
111116
}
@@ -119,13 +124,13 @@ function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
119124
)
120125
end
121126

122-
function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number)
127+
@inline function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number)
123128
return CT(
124129
Base.llvmcall(
125130
(
126131
"""define float addrspace(1)* @entry(float %d) alwaysinline {
127132
%a = alloca float
128-
store float %d, float* %a
133+
store atomic float %d, float* %a release, align 4
129134
%ac = addrspacecast float* %a to float addrspace(1)*
130135
ret float addrspace(1)* %ac
131136
}
@@ -1070,6 +1075,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
10701075
# linearize kernel arguments
10711076
seen = Reactant.OrderedIdDict()
10721077
kernelargsym = gensym("kernelarg")
1078+
10731079
for (i, prev) in enumerate(Any[func.f, args...])
10741080
Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
10751081
end

src/Tracing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ struct VisitedObject
1212
id::Int
1313
end
1414

15+
is_traced_number(x::Type) = false
16+
Base.@nospecializeinfer is_traced_number(@nospecialize(T::Type{<:TracedRNumber})) = true
17+
1518
function traced_type_inner end
1619

1720
Base.@nospecializeinfer function traced_type_inner(
@@ -1018,7 +1021,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor(
10181021
end
10191022
FT = fieldtype(TT, i)
10201023
if mode != TracedToTypes && !(Core.Typeof(xi2) <: FT)
1021-
if FT <: TracedRNumber && xi2 isa unwrapped_eltype(FT)
1024+
if is_traced_number(FT) && xi2 isa unwrapped_eltype(FT)
10221025
xi2 = FT(xi2)
10231026
xi2 = Core.Typeof(xi2)((newpath,), xi2.mlir_data)
10241027
seen[xi2] = xi2
@@ -1139,7 +1142,7 @@ Base.@nospecializeinfer function make_tracer_unknown(
11391142
end
11401143
FT = fieldtype(TT, i)
11411144
if mode != TracedToTypes && !(Core.Typeof(xi2) <: FT)
1142-
if FT <: TracedRNumber && xi2 isa unwrapped_eltype(FT)
1145+
if is_traced_number(FT) && xi2 isa unwrapped_eltype(FT)
11431146
xi2 = FT(xi2)
11441147
xi2 = Core.Typeof(xi2)((newpath,), xi2.mlir_data)
11451148
seen[xi2] = xi2

test/integration/cuda.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,22 @@ end
203203
@jit searchsorted!(A, B)
204204
@test all(Array(A) .≈ 311)
205205
end
206+
207+
function convert_mul_kernel!(Gu, w::FT) where {FT}
208+
r = FT(0.5) * w
209+
@inbounds Gu[1, 1, 1] = r
210+
return nothing
211+
end
212+
213+
function convert_mul!(Gu, w)
214+
@cuda blocks = 1 threads = 1 convert_mul_kernel!(Gu, w)
215+
return nothing
216+
end
217+
218+
@testset "Convert mul" begin
219+
w = Reactant.ConcreteRNumber(0.6)
220+
Gu = Reactant.to_rarray(ones(24, 24, 24))
221+
@jit convert_mul!(Gu, w)
222+
Gui = Array((Gu))
223+
@test Gui[1] ≈ 0.3
224+
end

0 commit comments

Comments
 (0)