Commit 9d852a2

test: code quality testing (#1668)
* fix: improper qualified accesses
* fix: use explicit imports
* fix: actually use explicit imports
* test: add qa testing via Aqua and ExplicitImports
* fix: more ambiguities fixed
* fix: more ambiguities fixed
* fix: more ambiguities fixed
* chore: overlay getindex
* fix: indexing
* fix: more indexing fixes
* feat: all ambiguities have been fixed
* fix: type conversion
* fix: lu factorization
* fix: min compats
* fix: missing getindex for unitrange
* fix: more bug fixes
* fix: getindex ambiguities
* chore: fmt
* chore: fmt
* fix: for 1.10
* fix: use explicit imports
* docs: openssl version mismatch

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 05eb248 commit 9d852a2

55 files changed: +1704 −1101 lines changed
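The headline change, per the commit message, is new QA testing via Aqua and ExplicitImports. A minimal sketch of what such a test file typically looks like with those packages' documented APIs (the actual test file in this commit may differ):

    using Aqua, ExplicitImports, Reactant, Test

    @testset "Code quality (Aqua + ExplicitImports)" begin
        # Aqua checks method ambiguities, unbound type parameters,
        # undefined exports, stale deps, type piracy, and more.
        Aqua.test_all(Reactant)

        # ExplicitImports' check functions return `nothing` on success
        # and throw otherwise, so `=== nothing` works as a predicate.
        @test ExplicitImports.check_no_implicit_imports(Reactant) === nothing
        @test ExplicitImports.check_no_stale_explicit_imports(Reactant) === nothing
    end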

.github/workflows/downgrade.yml

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ jobs:
       push!(dev_pks, Pkg.PackageSpec(; path))
     end
     Pkg.develop(dev_pks)
-    Pkg.test(; coverage="user")
+    Pkg.test(; coverage="user", allow_reresolve=false)
   shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
   id: run_tests
   env:
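For context (my reading of the change, not text from the commit): `allow_reresolve=false` makes `Pkg.test` error instead of re-resolving the manifest, which matters in a downgrade job where dependencies have deliberately been resolved to their minimum compat versions. A sketch:

    using Pkg

    # The downgrade job pins the manifest to minimum compat versions; a
    # silent re-resolve during testing could upgrade them and defeat the
    # point of the job, so fail loudly instead.
    Pkg.test(; coverage="user", allow_reresolve=false)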

Project.toml

Lines changed: 8 additions & 4 deletions
@@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -36,13 +37,13 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
 GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
 Random123 = "74087812-796a-5b5d-8853-05524746bad3"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
@@ -64,6 +65,7 @@ ReactantOffsetArraysExt = "OffsetArrays"
 ReactantOneHotArraysExt = "OneHotArrays"
 ReactantPythonCallExt = "PythonCall"
 ReactantRandom123Ext = "Random123"
+ReactantSparseArraysExt = "SparseArrays"
 ReactantSpecialFunctionsExt = "SpecialFunctions"
 ReactantStatisticsExt = "Statistics"
 ReactantYaoBlocksExt = "YaoBlocks"
@@ -77,8 +79,8 @@ CUDA = "5.6"
 DLFP8Types = "0.1"
 Downloads = "1.6"
 EnumX = "1"
-Enzyme = "0.13.72"
-EnzymeCore = "0.8.11"
+Enzyme = "0.13.74"
+EnzymeCore = "0.8.13"
 FillArrays = "1.13"
 Float8s = "0.1"
 Functors = "0.5"
@@ -88,14 +90,15 @@ HTTP = "1.10.15"
 KernelAbstractions = "0.9.30"
 LLVM = "9.1"
 LLVMOpenMP_jll = "18.1.7"
+Libdl = "1.10"
 LinearAlgebra = "1.10"
 MPI = "0.20"
 NNlib = "0.9.26"
 OffsetArrays = "1"
 OneHotArrays = "0.2.10"
 OrderedCollections = "1"
 PrecompileTools = "1.2"
-Preferences = "1.4"
+Preferences = "1.4.3"
 PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
@@ -104,6 +107,7 @@ Reactant_jll = "0.0.240"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
+SparseArrays = "1.10"
 SpecialFunctions = "2.4"
 Statistics = "1.10"
 YaoBlocks = "0.13, 0.14"
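The new `ReactantSparseArraysExt = "SparseArrays"` entry uses Julia's package-extension mechanism: SparseArrays is added as a weak dependency and mapped to an extension module that loads automatically once both Reactant and SparseArrays are in the environment. A skeletal sketch of such an extension module (contents illustrative, not the actual source from this commit):

    module ReactantSparseArraysExt

    # The parent package and the weak dependency whose presence
    # triggers loading of this module.
    using Reactant: Reactant
    using SparseArrays: SparseArrays, SparseMatrixCSC

    # Methods that teach Reactant about sparse types would be defined
    # here, e.g. overloads dispatching on SparseMatrixCSC.

    end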

docs/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -13,3 +14,4 @@ ReactantCore = {path = "../lib/ReactantCore"}
 [compat]
 Documenter = "1.4.1"
 DocumenterVitepress = "0.2"
+OpenSSL_jll = "=3.0.16"
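Note: the `=3.0.16` compat entry uses Pkg's equality specifier, pinning OpenSSL_jll to exactly that version rather than any semver-compatible release; per the commit message, this works around an OpenSSL version mismatch in the docs build.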

ext/ReactantArrayInterfaceExt.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 module ReactantArrayInterfaceExt

 using ArrayInterface: ArrayInterface
-using Reactant: Reactant, RArray, AbstractConcreteNumber, AnyTracedRArray, Ops
+using Reactant: Reactant, RArray, AbstractConcreteNumber, AnyTracedRArray

 ArrayInterface.can_setindex(::Type{<:RArray}) = false
 ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false

ext/ReactantCUDAExt.jl

Lines changed: 24 additions & 34 deletions
@@ -1,23 +1,19 @@
 module ReactantCUDAExt

-using CUDA
 using Reactant: Reactant, TracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber
 using Reactant.Compiler: raising
-using ReactantCore: @trace
+using Reactant.Ops: @opcall
+
+using Adapt: Adapt, adapt
+using CUDA: CUDA, CuDim, DenseCuArray, unsafe_cached_load
+
 using GPUCompiler: GPUCompiler
 using KernelAbstractions: KernelAbstractions
-import KernelAbstractions as KA
 using LLVM: LLVM
-using Libdl
-
-using Reactant.Ops: @opcall

-const ReactantKernelAbstractionsExt = Base.get_extension(
-    Reactant, :ReactantKernelAbstractionsExt
-)
-const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend
+using PrecompileTools: @setup_workload, @compile_workload

-using Adapt
+const KA = KernelAbstractions

 Reactant.is_extension_loaded(::Val{:CUDA}) = true

@@ -64,9 +60,7 @@ function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A}
     return @inbounds unsafe_load(RN.ptr, 1, Val(align))
 end

-function Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number}
-    return Base.convert(T, Base.getindex(RN))
-end
+Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number} = convert(T, getindex(RN))

 for jlop in (
     :(Base.min),
@@ -89,17 +83,15 @@ for jlop in (
     end
 end

-@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = Base.ifelse(cond, a, b[])
-@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = Base.ifelse(cond, a[], b)
+@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[])
+@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b)
 @inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) =
-    Base.ifelse(cond, a[], b[])
-@inline Base.ifelse(cond::CuTracedRNumber, a, b) = Base.ifelse(cond[], a, b)
-@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) =
-    Base.ifelse(cond[], a[], b)
-@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) =
-    Base.ifelse(cond[], a, b[])
+    ifelse(cond, a[], b[])
+@inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b)
+@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b)
+@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[])
 @inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) =
-    Base.ifelse(cond[], a[], b[])
+    ifelse(cond[], a[], b[])

 Base.@constprop :aggressive @inline Base.:^(
     a::CuTracedRNumber{T,A}, b::Integer
@@ -140,7 +132,7 @@ end
             ),
             Core.LLVMPtr{UInt8,1},
             Tuple{Float64},
-            Base.convert(Float64, x),
+            convert(Float64, x),
         ),
     ),
 )
@@ -164,7 +156,7 @@
             ),
             Core.LLVMPtr{UInt8,1},
             Tuple{Float32},
-            Base.convert(Float32, x),
+            convert(Float32, x),
         ),
     ),
 )
@@ -181,7 +173,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     @nospecialize(a::Type{<:CuTracedRNumber{T}}),
     @nospecialize(b::Type{<:CuTracedRNumber{T2}})
 ) where {T,T2}
-    return Base.promote_rule(T, T2)
+    return promote_rule(T, T2)
 end
 Base.@nospecializeinfer function Base.promote_rule(
     ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
@@ -199,7 +191,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     if T == T2
         return T
     else
-        return Base.promote_rule(T, T2)
+        return promote_rule(T, T2)
     end
 end
 Base.@nospecializeinfer function Base.promote_rule(
@@ -208,7 +200,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     if T == T2
         return T
     else
-        return Base.promote_rule(T, T2)
+        return promote_rule(T, T2)
     end
 end

@@ -506,9 +498,7 @@ function threads_to_workgroupsize(threads, ndrange)
     end
 end

-function ReactantKernelAbstractionsExt.ka_with_reactant(
-    ndrange, workgroupsize, obj, args...
-)
+function Reactant.ka_with_reactant(ndrange, workgroupsize, obj, args...)
     backend = KA.backend(obj)

     ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
@@ -588,7 +578,7 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
     return res
 end

-import Reactant.TracedRNumberOverrides.TracedStepRangeLen
+import Reactant.TracedStepRangeLen

 function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
     return TracedStepRangeLen(
@@ -1481,7 +1471,7 @@ end
 # In Julia v1.11.3 precompiling this module caches bad code:
 # <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
 @static if !Sys.isapple()
-    Reactant.PrecompileTools.@setup_workload begin
+    @setup_workload begin
         Reactant.initialize_dialect()

         if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
@@ -1492,7 +1482,7 @@ end
             error("Unsupported runtime: $(Reactant.XLA.REACTANT_XLA_RUNTIME)")
         end

-        Reactant.PrecompileTools.@compile_workload begin
+        @compile_workload begin
             @static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
                 function square_kernel!(x)
                     i = CUDA.threadIdx().x
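A recurring pattern in this file's diff is dropping redundant `Base.` qualification inside method bodies: names like `ifelse`, `convert`, and `promote_rule` are exported from Base, so the bare name already refers to the same generic function, and only the method definition itself needs the `Base.` prefix. A self-contained illustration (the `Wrapped` type is hypothetical, not from the commit):

    # Hypothetical wrapper type, for illustration only.
    struct Wrapped{T}
        x::T
    end

    # The definition must qualify `Base.ifelse` to extend Base's function,
    # but the call on the right-hand side can use the bare exported name.
    Base.ifelse(cond::Bool, a::Wrapped, b::Wrapped) = ifelse(cond, a.x, b.x)

    @assert ifelse(true, Wrapped(1), Wrapped(2)) == 1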

ext/ReactantKernelAbstractionsExt.jl

Lines changed: 8 additions & 11 deletions
@@ -1,14 +1,13 @@
 module ReactantKernelAbstractionsExt

-using Reactant
-
-import KernelAbstractions as KA
+using Reactant: Reactant

 using Adapt: Adapt
+using KernelAbstractions: KernelAbstractions

-## back-end
+const KA = KernelAbstractions

-export ReactantBackend
+## back-end

 # ToDo: Include XLA client, device and sharding in ReactantBackend struct, to
 # support more complex applications? If so, need to adapt implementation of
@@ -26,7 +25,7 @@ function Base.getproperty(x::ReactantBackend, sym::Symbol)
 end

 function KA.allocate(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
-    return ConcreteRArray{T}(undef, dims)
+    return Reactant.ConcreteRArray{T}(undef, dims)
 end

 function KA.zeros(b::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
@@ -103,23 +102,21 @@ end

 function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
     if Reactant.precompiling()
-        @code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...)
+        Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...)
     else
-        @jit tokw(ndrange, workgroupsize, obj, args...)
+        Reactant.@jit tokw(ndrange, workgroupsize, obj, args...)
     end
     return nothing
 end

-function ka_with_reactant end # defined in the CUDA extension
-
 Reactant.@reactant_overlay @noinline Base.@nospecializeinfer function (
     obj::KA.Kernel{ReactantBackend}
 )(
     args...; ndrange=nothing, workgroupsize=nothing
 )
     @nospecialize
     return Reactant.call_with_reactant(
-        ka_with_reactant, ndrange, workgroupsize, obj, args...
+        Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
     )
 end

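The removed stub `function ka_with_reactant end` has evidently moved into Reactant itself: this extension now calls `Reactant.ka_with_reactant`, and the CUDA extension's diff above attaches a method to it. Declaring an empty generic function in the parent package and letting extensions add methods is the standard way to share a function across extensions. A runnable toy sketch of the pattern, with stand-in module names:

    module ParentPkg                  # stands in for Reactant
    function ka_with_reactant end     # empty generic function; no methods yet
    end

    module SomeExt                    # stands in for ReactantCUDAExt
    using ..ParentPkg
    # The extension attaches the backend-specific method to the parent's stub.
    ParentPkg.ka_with_reactant(ndrange, workgroupsize, obj, args...) =
        (ndrange, workgroupsize, obj, args)
    end

    @assert ParentPkg.ka_with_reactant(16, 8, :kernel) == (16, 8, :kernel, ())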

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@ for (jlop, hloop) in (
 end

 function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
-    x = T.(Reactant.materialize_traced_array(x))
+    x = T.(materialize_traced_array(x))
     max_ = maximum(x; dims)
     diff = exp.(x .- max_)
     # TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581
@@ -22,7 +22,7 @@ function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) whe
 end

 function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) where {T}
-    x = T.(Reactant.materialize_traced_array(x))
+    x = T.(materialize_traced_array(x))
     max_ = maximum(x; dims)
     diff = x .- max_
     # TOOD: re-enable conditional once https://github.com/EnzymeAD/Reactant.jl/issues/1581

ext/ReactantNNlibExt/ReactantNNlibExt.jl

Lines changed: 4 additions & 7 deletions
@@ -1,16 +1,13 @@
 module ReactantNNlibExt

-using NNlib
-using GPUArraysCore: @allowscalar
 using Reactant:
     Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber, @reactant_overlay
-
-using Reactant.TracedUtils:
-    TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data!
+using Reactant.TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
 using Reactant.Ops: @opcall
+using ReactantCore: materialize_traced_array, @trace

-using ReactantCore: @trace
-using LinearAlgebra: LinearAlgebra, triu
+using NNlib: NNlib, DenseConvDims
+using GPUArraysCore: @allowscalar
 using Statistics: mean

 include("Overlay.jl")

ext/ReactantOffsetArraysExt.jl

Lines changed: 5 additions & 7 deletions
@@ -1,8 +1,7 @@
 module ReactantOffsetArraysExt

-using OffsetArrays
-using OffsetArrays: OffsetArray, OffsetVector
-using Reactant: Reactant, MLIR, Ops, TracedRArray, AbstractConcreteArray
+using OffsetArrays: OffsetArrays, OffsetArray, OffsetVector
+using Reactant: Reactant, MLIR, Ops, TracedRArray, TracedRNumber, AbstractConcreteArray

 Base.@nospecializeinfer function Reactant.traced_type_inner(
     @nospecialize(OA::Type{<:OffsetArray}),
@@ -45,22 +44,21 @@ function Base.getindex(
 end

 function Base.getindex(
-    x::OffsetVector{Reactant.TracedRNumber{T},Reactant.TracedRArray{T,1}},
-    indices::Base.OneTo{Int},
+    x::OffsetVector{TracedRNumber{T},Reactant.TracedRArray{T,1}}, indices::Base.OneTo{Int}
 ) where {T}
     offset_indices = indices .- x.offsets[1]
     return getindex(parent(x), offset_indices)
 end

 parentindex(r::OffsetArrays.IdOffsetRange, i) = i .- r.offset
 function Base.getindex(
-    a::OffsetArray{<:Reactant.TracedRNumber,N}, indices::Vararg{Union{Int,AbstractArray},N}
+    a::OffsetArray{<:TracedRNumber,N}, indices::Vararg{Union{Int,AbstractArray},N}
 ) where {N}
     J = map(parentindex, axes(a), indices)
     return parent(a)[J...]
 end

-function Base.getindex(a::OffsetVector{<:Reactant.TracedRNumber}, indices::Int)
+function Base.getindex(a::OffsetVector{<:TracedRNumber}, indices::Int)
     J = parentindex(Base.axes1(a), indices)
     return parent(a)[J]
 end
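The `parentindex` helper above encodes the core OffsetArrays identity: an offset index i maps to parent index i − offset. A quick standalone check of that arithmetic, independent of Reactant's traced types:

    using OffsetArrays

    v = OffsetVector([10, 20, 30], -1)   # indices 0:2 wrap parent indices 1:3
    offset = v.offsets[1]                # -1, same field the diff reads
    # The identity used by `parentindex`: i - offset recovers the parent index.
    @assert v[0] == parent(v)[0 - offset] == 10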
