@@ -46,7 +46,12 @@ struct CuTracedRNumber{T,A} <: Number
46
46
end
47
47
end
48
48
49
- CuTracedRNumber {T,A} (val:: Number ) where {T,A} = convert (CuTracedRNumber{T,A}, val)
49
+ Base. @nospecializeinfer Reactant. is_traced_number (
50
+ @nospecialize (T:: Type{<:CuTracedRNumber} )
51
+ ) = true
52
+ Reactant. unwrapped_eltype (:: Type{<:CuTracedRNumber{T}} ) where {T} = T
53
+
54
+ @inline CuTracedRNumber {T,A} (val:: Number ) where {T,A} = convert (CuTracedRNumber{T,A}, val)
50
55
51
56
function Base. getindex (RN:: CuTracedRNumber{T,A} ) where {T,A}
52
57
align = alignment (RN)
@@ -99,13 +104,13 @@ Base.OneTo(x::CuTracedRNumber{<:Integer}) = Base.OneTo(x[])
99
104
end
100
105
end
101
106
102
- function Base. convert (CT:: Type{CuTracedRNumber{Float64,1}} , x:: Number )
107
+ @inline function Base. convert (CT:: Type{CuTracedRNumber{Float64,1}} , x:: Number )
103
108
return CT (
104
109
Base. llvmcall (
105
110
(
106
111
""" define double addrspace(1)* @entry(double %d) alwaysinline {
107
112
%a = alloca double
108
- store double %d, double* %a
113
+ store atomic double %d, double* %a release, align 8
109
114
%ac = addrspacecast double* %a to double addrspace(1)*
110
115
ret double addrspace(1)* %ac
111
116
}
@@ -119,13 +124,13 @@ function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
119
124
)
120
125
end
121
126
122
- function Base. convert (CT:: Type{CuTracedRNumber{Float32,1}} , x:: Number )
127
+ @inline function Base. convert (CT:: Type{CuTracedRNumber{Float32,1}} , x:: Number )
123
128
return CT (
124
129
Base. llvmcall (
125
130
(
126
131
""" define float addrspace(1)* @entry(float %d) alwaysinline {
127
132
%a = alloca float
128
- store float %d, float* %a
133
+ store atomic float %d, float* %a release, align 4
129
134
%ac = addrspacecast float* %a to float addrspace(1)*
130
135
ret float addrspace(1)* %ac
131
136
}
@@ -1070,6 +1075,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
1070
1075
# linearize kernel arguments
1071
1076
seen = Reactant. OrderedIdDict ()
1072
1077
kernelargsym = gensym (" kernelarg" )
1078
+
1073
1079
for (i, prev) in enumerate (Any[func. f, args... ])
1074
1080
Reactant. make_tracer (seen, prev, (kernelargsym, i), Reactant. NoStopTracedTrack)
1075
1081
end
0 commit comments