@@ -106,41 +106,45 @@ end
106
106
107
107
@inline function Base. convert (CT:: Type{CuTracedRNumber{Float64,1}} , x:: Number )
108
108
return CT (
109
+ Base. reinterpret (Core. LLVMPtr{Float64,1 },
109
110
Base. llvmcall (
110
111
(
111
- """ define double addrspace(1)* @entry(double %d) alwaysinline {
112
+ """ define i8 addrspace(1)* @entry(double %d) alwaysinline {
112
113
%a = alloca double
113
114
store atomic double %d, double* %a release, align 8
114
- %ac = addrspacecast double* %a to double addrspace(1)*
115
- ret double addrspace(1)* %ac
115
+ %bc = bitcast double* %a to i8*
116
+ %ac = addrspacecast i8* %bc to i8 addrspace(1)*
117
+ ret i8 addrspace(1)* %ac
116
118
}
117
119
""" ,
118
120
" entry" ,
119
121
),
120
- Core. LLVMPtr{Float64 ,1 },
122
+ Core. LLVMPtr{UInt8 ,1 },
121
123
Tuple{Float64},
122
124
Base. convert (Float64, x),
123
- ),
125
+ ))
124
126
)
125
127
end
126
128
127
129
@inline function Base. convert (CT:: Type{CuTracedRNumber{Float32,1}} , x:: Number )
128
130
return CT (
131
+ Base. reinterpret (Core. LLVMPtr{Float32,1 },
129
132
Base. llvmcall (
130
133
(
131
- """ define float addrspace(1)* @entry(float %d) alwaysinline {
134
+ """ define i8 addrspace(1)* @entry(float %d) alwaysinline {
132
135
%a = alloca float
133
136
store atomic float %d, float* %a release, align 4
134
- %ac = addrspacecast float* %a to float addrspace(1)*
135
- ret float addrspace(1)* %ac
137
+ %bc = bitcast float* %a to i8*
138
+ %ac = addrspacecast i8* %bc to i8 addrspace(1)*
139
+ ret i8 addrspace(1)* %ac
136
140
}
137
141
""" ,
138
142
" entry" ,
139
143
),
140
- Core. LLVMPtr{Float32 ,1 },
144
+ Core. LLVMPtr{UInt8 ,1 },
141
145
Tuple{Float32},
142
146
Base. convert (Float32, x),
143
- ),
147
+ ))
144
148
)
145
149
end
146
150
@@ -908,6 +912,14 @@ function compile(job)
908
912
if Reactant. Compiler. DUMP_LLVMIR[]
909
913
println (" cuda.jl pre vendor IR\n " , string (mod))
910
914
end
915
+
916
+ LLVM. @dispose pb = LLVM. NewPMPassBuilder () begin
917
+ LLVM. add! (pb, LLVM. NewPMModulePassManager ()) do mpm
918
+ LLVM. add! (mpm, LLVM. AlwaysInlinerPass ())
919
+ end
920
+ LLVM. run! (pb, mod, tm)
921
+ end
922
+
911
923
vendored_optimize_module! (job, mod)
912
924
if Reactant. Compiler. DUMP_LLVMIR[]
913
925
println (" cuda.jl post vendor IR\n " , string (mod))
0 commit comments