Skip to content

Commit ac4082a

Browse files
authored
CUDA: try more float alwaysinline (#1265)
* CUDA: try more float alwaysinline * f64
1 parent 6da3e14 commit ac4082a

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,41 +106,45 @@ end
106106

107107
@inline function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
108108
return CT(
109+
Base.reinterpret(Core.LLVMPtr{Float64,1},
109110
Base.llvmcall(
110111
(
111-
"""define double addrspace(1)* @entry(double %d) alwaysinline {
112+
"""define i8 addrspace(1)* @entry(double %d) alwaysinline {
112113
%a = alloca double
113114
store atomic double %d, double* %a release, align 8
114-
%ac = addrspacecast double* %a to double addrspace(1)*
115-
ret double addrspace(1)* %ac
115+
%bc = bitcast double* %a to i8*
116+
%ac = addrspacecast i8* %bc to i8 addrspace(1)*
117+
ret i8 addrspace(1)* %ac
116118
}
117119
""",
118120
"entry",
119121
),
120-
Core.LLVMPtr{Float64,1},
122+
Core.LLVMPtr{UInt8,1},
121123
Tuple{Float64},
122124
Base.convert(Float64, x),
123-
),
125+
))
124126
)
125127
end
126128

127129
@inline function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number)
128130
return CT(
131+
Base.reinterpret(Core.LLVMPtr{Float32,1},
129132
Base.llvmcall(
130133
(
131-
"""define float addrspace(1)* @entry(float %d) alwaysinline {
134+
"""define i8 addrspace(1)* @entry(float %d) alwaysinline {
132135
%a = alloca float
133136
store atomic float %d, float* %a release, align 4
134-
%ac = addrspacecast float* %a to float addrspace(1)*
135-
ret float addrspace(1)* %ac
137+
%bc = bitcast float* %a to i8*
138+
%ac = addrspacecast i8* %bc to i8 addrspace(1)*
139+
ret i8 addrspace(1)* %ac
136140
}
137141
""",
138142
"entry",
139143
),
140-
Core.LLVMPtr{Float32,1},
144+
Core.LLVMPtr{UInt8,1},
141145
Tuple{Float32},
142146
Base.convert(Float32, x),
143-
),
147+
))
144148
)
145149
end
146150

@@ -908,6 +912,14 @@ function compile(job)
908912
if Reactant.Compiler.DUMP_LLVMIR[]
909913
println("cuda.jl pre vendor IR\n", string(mod))
910914
end
915+
916+
LLVM.@dispose pb = LLVM.NewPMPassBuilder() begin
917+
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
918+
LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
919+
end
920+
LLVM.run!(pb, mod, tm)
921+
end
922+
911923
vendored_optimize_module!(job, mod)
912924
if Reactant.Compiler.DUMP_LLVMIR[]
913925
println("cuda.jl post vendor IR\n", string(mod))

0 commit comments

Comments
 (0)