Skip to content

Commit 628c576

Browse files
committed
test: fixes
1 parent 7c19933 commit 628c576

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

ext/ReactantFillArraysExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ReactantFillArraysExt
22

3-
using Reactant: Reactant, TracedUtils, TracedRNumber, Ops, Sharding, unwrapped_eltype
3+
using Reactant: Reactant, TracedRNumber, Sharding, unwrapped_eltype
44
using ReactantCore: ReactantCore
55
using FillArrays: FillArrays, AbstractFill, Fill, Ones, Zeros, OneElement
66
using GPUArraysCore: @allowscalar

test/core/wrapped_arrays.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Reactant, Test, Statistics, NNlib, LinearAlgebra
1+
using Reactant, Test, Statistics, LinearAlgebra
22

33
function view_getindex_1(x)
44
x = view(x, 2:3, 1:2, :)
@@ -77,18 +77,6 @@ end
7777
@test v1 v2
7878
end
7979

80-
function btranspose_badjoint(x)
81-
x1 = NNlib.batched_transpose(x)
82-
x2 = NNlib.batched_adjoint(x)
83-
return x1 .+ x2
84-
end
85-
86-
@testset "batched transpose/adjoint" begin
87-
x = rand(4, 2, 3)
88-
x_ra = Reactant.to_rarray(x)
89-
@test @jit(btranspose_badjoint(x_ra)) btranspose_badjoint(x)
90-
end
91-
9280
function bypass_permutedims(x)
9381
x = PermutedDimsArray(x, (2, 1, 3)) # Don't use permutedims here
9482
return view(x, 2:3, 1:2, :)

test/integration/NNlib/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,3 +792,15 @@ end
792792
hlo = repr(@code_hlo(NNlib.gather(x_ra, idxs_ra)))
793793
@test !contains(hlo, "i64>")
794794
end
795+
796+
function btranspose_badjoint(x)
797+
x1 = NNlib.batched_transpose(x)
798+
x2 = NNlib.batched_adjoint(x)
799+
return x1 .+ x2
800+
end
801+
802+
@testset "batched transpose/adjoint" begin
803+
x = rand(4, 2, 3)
804+
x_ra = Reactant.to_rarray(x)
805+
@test @jit(btranspose_badjoint(x_ra)) btranspose_badjoint(x)
806+
end

0 commit comments

Comments
 (0)