module ReactantCUDAExt

- using CUDA
using Reactant: Reactant, TracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber
using Reactant.Compiler: raising
- using ReactantCore: @trace
+ using Reactant.Ops: @opcall
+
+ using Adapt: Adapt, adapt
+ using CUDA: CUDA, CuDim, DenseCuArray, unsafe_cached_load
+
using GPUCompiler: GPUCompiler
using KernelAbstractions: KernelAbstractions
- import KernelAbstractions as KA
using LLVM: LLVM
- using Libdl
-
- using Reactant.Ops: @opcall

- const ReactantKernelAbstractionsExt = Base.get_extension(
-     Reactant, :ReactantKernelAbstractionsExt
- )
- const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend
+ using PrecompileTools: @setup_workload, @compile_workload

- using Adapt
+ const KA = KernelAbstractions

Reactant.is_extension_loaded(::Val{:CUDA}) = true

@@ -64,9 +60,7 @@ function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A}
    return @inbounds unsafe_load(RN.ptr, 1, Val(align))
end

- function Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number}
-     return Base.convert(T, Base.getindex(RN))
- end
+ Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number} = convert(T, getindex(RN))

for jlop in (
    :(Base.min),
@@ -89,17 +83,15 @@ for jlop in (
    end
end

- @inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = Base.ifelse(cond, a, b[])
- @inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = Base.ifelse(cond, a[], b)
+ @inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[])
+ @inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b)
@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) =
-     Base.ifelse(cond, a[], b[])
- @inline Base.ifelse(cond::CuTracedRNumber, a, b) = Base.ifelse(cond[], a, b)
- @inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) =
-     Base.ifelse(cond[], a[], b)
- @inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) =
-     Base.ifelse(cond[], a, b[])
+     ifelse(cond, a[], b[])
+ @inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b)
+ @inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b)
+ @inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[])
@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) =
-     Base.ifelse(cond[], a[], b[])
+     ifelse(cond[], a[], b[])

Base.@constprop :aggressive @inline Base.:^(
    a::CuTracedRNumber{T,A}, b::Integer
@@ -140,7 +132,7 @@
            ),
            Core.LLVMPtr{UInt8,1},
            Tuple{Float64},
-             Base.convert(Float64, x),
+             convert(Float64, x),
        ),
    ),
)
@@ -164,7 +156,7 @@
            ),
            Core.LLVMPtr{UInt8,1},
            Tuple{Float32},
-             Base.convert(Float32, x),
+             convert(Float32, x),
        ),
    ),
)
@@ -181,7 +173,7 @@ Base.@nospecializeinfer function Base.promote_rule(
    @nospecialize(a::Type{<:CuTracedRNumber{T}}),
    @nospecialize(b::Type{<:CuTracedRNumber{T2}})
) where {T,T2}
-     return Base.promote_rule(T, T2)
+     return promote_rule(T, T2)
end
Base.@nospecializeinfer function Base.promote_rule(
    ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
@@ -199,7 +191,7 @@ Base.@nospecializeinfer function Base.promote_rule(
    if T == T2
        return T
    else
-         return Base.promote_rule(T, T2)
+         return promote_rule(T, T2)
    end
end
Base.@nospecializeinfer function Base.promote_rule(
@@ -208,7 +200,7 @@ Base.@nospecializeinfer function Base.promote_rule(
    if T == T2
        return T
    else
-         return Base.promote_rule(T, T2)
+         return promote_rule(T, T2)
    end
end

@@ -506,9 +498,7 @@ function threads_to_workgroupsize(threads, ndrange)
    end
end

- function ReactantKernelAbstractionsExt.ka_with_reactant(
-     ndrange, workgroupsize, obj, args...
- )
+ function Reactant.ka_with_reactant(ndrange, workgroupsize, obj, args...)
    backend = KA.backend(obj)

    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
@@ -588,7 +578,7 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
    return res
end

- import Reactant.TracedRNumberOverrides.TracedStepRangeLen
+ import Reactant.TracedStepRangeLen

function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
    return TracedStepRangeLen(
@@ -1481,7 +1471,7 @@ end
# In Julia v1.11.3 precompiling this module caches bad code:
# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
@static if !Sys.isapple()
-     Reactant.PrecompileTools.@setup_workload begin
+     @setup_workload begin
        Reactant.initialize_dialect()

        if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
@@ -1492,7 +1482,7 @@ end
            error("Unsupported runtime: $(Reactant.XLA.REACTANT_XLA_RUNTIME)")
        end

-         Reactant.PrecompileTools.@compile_workload begin
+         @compile_workload begin
            @static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
                function square_kernel!(x)
                    i = CUDA.threadIdx().x
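For context, the `@compile_workload` kernel that the last hunk touches is cut off above. A minimal sketch of the kind of workload this extension precompiles is shown below; the `square!` wrapper, the array contents, and the use of `Reactant.@jit` are illustrative assumptions, not taken from the diff.

using CUDA, Reactant

function square_kernel!(x)
    # One thread per element; square each entry in place.
    i = CUDA.threadIdx().x
    @inbounds x[i] *= x[i]
    return nothing
end

function square!(x)
    # Launch a single block with one thread per element.
    CUDA.@cuda blocks = 1 threads = length(x) square_kernel!(x)
    return nothing
end

x = Reactant.to_rarray(Float32.(1:64))
Reactant.@jit square!(x)  # traces the kernel launch and compiles it through Reactant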