Skip to content

Commit 79f9dc5

Browse files
authored
chore: updates for new JLL (#1190)
* fix: minor fixes * fix: run passes 3rd time * chore: bump jll * test: simple RNN test for whileOp autodiff
1 parent 164e68e commit 79f9dc5

File tree

4 files changed

+48
-20
lines changed

4 files changed

+48
-20
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.74"
4+
version = "0.2.75"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.9"
90-
Reactant_jll = "0.0.146"
90+
Reactant_jll = "0.0.147"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

src/Compiler.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ function optimization_passes(;
779779
"transpose_wrap",
780780
"transpose_extend",
781781
"transpose_rotate",
782+
"transpose_dynamic_slice",
782783
],
783784
)
784785
if AGGRESSIVE_PROPAGATION[]
@@ -1248,6 +1249,7 @@ function compile_mlir!(
12481249
"enzyme-batch",
12491250
opt_passes2,
12501251
enzyme_pass,
1252+
opt_passes2,
12511253
"canonicalize",
12521254
"remove-unnecessary-enzyme-ops",
12531255
"enzyme-simplify-math",
@@ -1269,6 +1271,7 @@ function compile_mlir!(
12691271
"enzyme-batch",
12701272
opt_passes2,
12711273
enzyme_pass,
1274+
opt_passes2,
12721275
"canonicalize",
12731276
"remove-unnecessary-enzyme-ops",
12741277
"enzyme-simplify-math",
@@ -1287,6 +1290,7 @@ function compile_mlir!(
12871290
"enzyme-batch",
12881291
opt_passes2,
12891292
enzyme_pass,
1293+
opt_passes2,
12901294
"canonicalize",
12911295
"remove-unnecessary-enzyme-ops",
12921296
"enzyme-simplify-math",
@@ -1307,6 +1311,7 @@ function compile_mlir!(
13071311
"enzyme-batch",
13081312
opt_passes2,
13091313
enzyme_pass,
1314+
opt_passes2,
13101315
"canonicalize",
13111316
"remove-unnecessary-enzyme-ops",
13121317
"enzyme-simplify-math",
@@ -1326,6 +1331,7 @@ function compile_mlir!(
13261331
"enzyme-batch",
13271332
opt_passes2,
13281333
enzyme_pass,
1334+
opt_passes2,
13291335
"canonicalize",
13301336
"remove-unnecessary-enzyme-ops",
13311337
"enzyme-simplify-math",
@@ -1397,16 +1403,7 @@ function compile_mlir!(
13971403
error("Invalid optimize option: $(Meta.quot(optimize))")
13981404
end
13991405

1400-
# HACK: remove with next JLL
14011406
if !(optimize isa String)
1402-
if transpose_propagate === :up
1403-
run_pass_pipeline!(
1404-
mod,
1405-
"enzyme-hlo-generate-td{patterns=transpose_while},transform-interpreter,enzyme-hlo-remove-transform",
1406-
"transpose_while",
1407-
)
1408-
end
1409-
14101407
if optimize ∉ (:none, :just_batch, :canonicalize) &&
14111408
(transpose_propagate === :up || reshape_propagate === :up)
14121409
# We tried propagating reshapes and transposes up. If at this point we are left with
@@ -1425,7 +1422,6 @@ function compile_mlir!(
14251422
end
14261423

14271424
# Now we resolve paddings if `optimize_then_pad`
1428-
prepad_fnname = fnname
14291425
if optimize_then_pad
14301426
padded_inputs = IdDict()
14311427
has_padded_inputs = false

src/stdlibs/Random.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ using ..Reactant:
1414
Reactant,
1515
TracedUtils,
1616
Ops,
17-
ConcretePJRTArray,
18-
ConcretePJRTNumber,
17+
AbstractConcreteArray,
18+
AbstractConcreteNumber,
1919
unwrapped_eltype
2020
using Random: Random, AbstractRNG
2121

@@ -48,29 +48,29 @@ end
4848
end
4949

5050
@noinline function Random.seed!(rng::ConcreteRNG, seed::Number)
51-
seed isa ConcretePJRTNumber && (seed = unwrapped_eltype(seed)(seed))
51+
seed isa AbstractConcreteNumber && (seed = unwrapped_eltype(seed)(seed))
5252
seed = reinterpret(UInt64, Random.hash_seed(seed))
53-
return Random.seed!(rng, ConcretePJRTArray(seed))
53+
return Random.seed!(rng, Reactant.to_rarray(seed))
5454
end
5555

5656
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{<:Integer})
5757
return Random.seed!(rng, seed)
5858
end
5959

6060
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{UInt64})
61-
return Random.seed!(rng, ConcretePJRTArray(seed))
61+
return Random.seed!(rng, Reactant.to_rarray(seed))
6262
end
6363

64-
@noinline function Random.seed!(rng::ConcreteRNG, seed::ConcretePJRTArray{UInt64,1})
64+
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractConcreteArray{UInt64,1})
6565
rng.seed = seed
6666
return rng
6767
end
6868

6969
Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm)
7070
Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
7171

72-
@noinline ConcreteRNG() = ConcreteRNG(ConcretePJRTArray(make_seed()))
73-
@noinline ConcreteRNG(seed::ConcretePJRTArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")
72+
@noinline ConcreteRNG() = ConcreteRNG(Reactant.to_rarray(make_seed()))
73+
@noinline ConcreteRNG(seed::AbstractConcreteArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")
7474

7575
@noinline default_rng() = ConcreteRNG()
7676

test/nn/lux.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ function loss_function(model, x, y, ps, st)
55
return CrossEntropyLoss()(y_hat, y)
66
end
77

8+
function loss_function(model, x, ps, st)
9+
y_hat, _ = model(x, ps, st)
10+
return sum(abs2, y_hat)
11+
end
12+
813
function gradient_loss_function(model, x, y, ps, st)
914
dps = Enzyme.make_zero(ps)
1015
_, res = Enzyme.autodiff(
@@ -20,6 +25,20 @@ function gradient_loss_function(model, x, y, ps, st)
2025
return res, dps
2126
end
2227

28+
function gradient_loss_function(model, x, ps, st)
29+
dps = Enzyme.make_zero(ps)
30+
_, res = Enzyme.autodiff(
31+
set_runtime_activity(ReverseWithPrimal),
32+
loss_function,
33+
Active,
34+
Const(model),
35+
Const(x),
36+
Duplicated(ps, dps),
37+
Const(st),
38+
)
39+
return res, dps
40+
end
41+
2342
@testset "Lux.jl Integration" begin
2443
# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
2544
noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
@@ -67,3 +86,16 @@ end
6786
@test_skip dps1 ≈ dps2 atol = 1e-3 rtol = 1e-2
6887
end
6988
end
89+
90+
@testset "RNN Integration" begin
91+
using Reactant, Lux, Enzyme, Random
92+
93+
model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex())
94+
ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model))
95+
96+
x = Reactant.to_rarray(rand(Float32, 4, 16, 12))
97+
98+
# This test requires running optimizations between the enzyme autodiff passes
99+
res, ∂ps = @jit gradient_loss_function(model, x, ps, st)
100+
@test res isa Reactant.ConcreteRNumber
101+
end

0 commit comments

Comments (0)