@@ -12,34 +12,37 @@ using ..Reactant:
12
12
Ops,
13
13
MLIR
14
14
15
- using .. TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data!
15
+ using ReactantCore: ReactantCore
16
+ using ReactantCore: materialize_traced_array
17
+
18
+ using .. TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
16
19
17
20
using LinearAlgebra
18
21
19
22
# Various Wrapper Arrays defined in LinearAlgebra
20
- function TracedUtils . materialize_traced_array (
23
+ function ReactantCore . materialize_traced_array (
21
24
x:: Transpose{TracedRNumber{T},<:AnyTracedRArray}
22
25
) where {T}
23
- px = TracedUtils . materialize_traced_array (parent (x))
26
+ px = materialize_traced_array (parent (x))
24
27
A = ndims (px) == 1 ? reshape (px, :, 1 ) : px
25
28
return permutedims (A, (2 , 1 ))
26
29
end
27
30
28
- function TracedUtils . materialize_traced_array (
31
+ function ReactantCore . materialize_traced_array (
29
32
x:: Adjoint{TracedRNumber{T},<:AnyTracedRArray}
30
33
) where {T}
31
34
return Ops. conj (
32
35
materialize_traced_array (transpose (materialize_traced_array (parent (x))))
33
36
)
34
37
end
35
38
36
- function TracedUtils . materialize_traced_array (
39
+ function ReactantCore . materialize_traced_array (
37
40
x:: Diagonal{TracedRNumber{T},<:AnyTracedRVector}
38
41
) where {T}
39
42
return diagm (materialize_traced_array (parent (x)))
40
43
end
41
44
42
- function TracedUtils . materialize_traced_array (
45
+ function ReactantCore . materialize_traced_array (
43
46
x:: Tridiagonal{TracedRNumber{T},<:AnyTracedRVector}
44
47
) where {T}
45
48
return diagm (- 1 => x. dl, 0 => x. d, 1 => x. du)
48
51
for (AT, comp) in ((:LowerTriangular , " GE" ), (:UpperTriangular , " LE" ))
49
52
uAT = Symbol (:Unit , AT)
50
53
@eval begin
51
- function TracedUtils . materialize_traced_array (
54
+ function ReactantCore . materialize_traced_array (
52
55
x:: $ (AT){TracedRNumber{T},<: AnyTracedRMatrix }
53
56
) where {T}
54
57
m, n = size (x)
55
- px = TracedUtils . materialize_traced_array (parent (x))
58
+ px = materialize_traced_array (parent (x))
56
59
row_idxs = Ops. iota (Int, [m, n]; iota_dimension= 1 )
57
60
col_idxs = Ops. iota (Int, [m, n]; iota_dimension= 2 )
58
61
indicator = Ops. compare (row_idxs, col_idxs; comparison_direction= $ (comp))
59
62
return Ops. select (indicator, px, zero (px))
60
63
end
61
64
62
- function TracedUtils . materialize_traced_array (
65
+ function ReactantCore . materialize_traced_array (
63
66
x:: $ (uAT){TracedRNumber{T},<: AnyTracedRMatrix }
64
67
) where {T}
65
68
m, n = size (x)
66
- px = TracedUtils . materialize_traced_array (parent (x))
69
+ px = materialize_traced_array (parent (x))
67
70
row_idxs = Ops. iota (Int, [m, n]; iota_dimension= 1 )
68
71
col_idxs = Ops. iota (Int, [m, n]; iota_dimension= 2 )
69
72
nondiag_indicator = Ops. compare (row_idxs, col_idxs; comparison_direction= " NE" )
@@ -73,7 +76,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
73
76
end
74
77
end
75
78
76
- function TracedUtils . materialize_traced_array (
79
+ function ReactantCore . materialize_traced_array (
77
80
x:: Symmetric{TracedRNumber{T},<:AnyTracedRMatrix}
78
81
) where {T}
79
82
m, n = size (x)
0 commit comments