@@ -195,20 +195,18 @@ end
195
195
@test res2 ≈ 4 * 3 * 3.1 ^ 2
196
196
end
197
197
198
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
199
- @testset " Seed initialization of Complex arrays on matmul: Issue #593" begin
198
+ @testset " Seed initialization of Complex arrays on matmul: Issue #593" begin
199
+ df (x, y) = Enzyme. gradient (ReverseWithPrimal, * , x, y)
200
+ @test begin
200
201
a = ones (ComplexF64, 2 , 2 )
201
202
b = 2.0 * ones (ComplexF64, 2 , 2 )
202
203
a_re = Reactant. to_rarray (a)
203
204
b_re = Reactant. to_rarray (b)
204
- df (x, y) = Enzyme. gradient (ReverseWithPrimal, * , x, y)
205
- @test begin
206
- res = @jit df (a_re, b_re) # before, this segfaulted
207
- (res. val ≈ 4 ones (2 , 2 )) &&
208
- (res. derivs[1 ] ≈ 4 ones (2 , 2 )) &&
209
- (res. derivs[2 ] ≈ 2 ones (2 , 2 ))
210
- end
211
- end
205
+ res = @jit df (a_re, b_re) # before, this segfaulted
206
+ (res. val ≈ 4 ones (2 , 2 )) &&
207
+ (res. derivs[1 ] ≈ 4 ones (2 , 2 )) &&
208
+ (res. derivs[2 ] ≈ 2 ones (2 , 2 ))
209
+ end skip = contains (string (Reactant. devices ()[1 ]), " TPU" )
212
210
end
213
211
214
212
@testset " onehot" begin
257
255
258
256
@testset " seed" begin
259
257
x = Reactant. to_rarray (rand (2 , 2 ))
260
- st = (; rng= Reactant. ConcreteRNG ())
258
+ st = (; rng= Reactant. ReactantRNG ())
261
259
262
260
@test begin
263
261
hlo = @code_hlo gradient_fn (x, st)
0 commit comments