1
+ using Reactant, Test, Adapt
2
+
3
+ struct MyGrid{FT,AT} <: AbstractVector{FT}
4
+ data:: AT
5
+ radius:: FT
6
+ end
7
+
8
+ Adapt. parent (x:: MyGrid ) = x. data
9
+
10
+ Base. getindex (x:: MyGrid , args... ) = Base. getindex (x. data, args... )
11
+
12
+ Base. size (x:: MyGrid ) = Base. size (x. data)
13
+
14
+ function Base. show (io:: IOty , X:: MyGrid ) where {IOty<: Union{IO,IOContext} }
15
+ print (io, Core. Typeof (X), " (" )
16
+ if Adapt. parent (X) != = X
17
+ Base. show (io, Adapt. parent (X))
18
+ end
19
+ return print (io, " )" )
20
+ end
21
+
22
+ Base. @nospecializeinfer function Reactant. traced_type_inner (
23
+ @nospecialize (OA:: Type{MyGrid{FT,AT}} ),
24
+ seen,
25
+ mode:: Reactant.TraceMode ,
26
+ @nospecialize (track_numbers:: Type ),
27
+ @nospecialize (sharding),
28
+ @nospecialize (runtime)
29
+ ) where {FT,AT}
30
+ FT2 = Reactant. traced_type_inner (FT, seen, mode, track_numbers, sharding, runtime)
31
+ AT2 = Reactant. traced_type_inner (AT, seen, mode, track_numbers, sharding, runtime)
32
+
33
+ for NF in (AT2,)
34
+ FT2 = Reactant. promote_traced_type (FT2, eltype (NF))
35
+ end
36
+
37
+ res = MyGrid{FT2,AT2}
38
+ return res
39
+ end
40
+
41
+ @inline Reactant. make_tracer (seen, @nospecialize (prev:: MyGrid ), args... ; kwargs... ) =
42
+ Reactant. make_tracer_via_immutable_constructor (seen, prev, args... ; kwargs... )
43
+
44
+ struct MyGrid2{FT,AT} <: AbstractVector{FT}
45
+ data:: AT
46
+ radius:: FT
47
+ bar:: FT
48
+ end
49
+
50
+ Adapt. parent (x:: MyGrid2 ) = x. data
51
+
52
+ Base. getindex (x:: MyGrid2 , args... ) = Base. getindex (x. data, args... )
53
+
54
+ Base. size (x:: MyGrid2 ) = Base. size (x. data)
55
+
56
+ function Base. show (io:: IOty , X:: MyGrid2 ) where {IOty<: Union{IO,IOContext} }
57
+ print (io, Core. Typeof (X), " (" )
58
+ if Adapt. parent (X) != = X
59
+ Base. show (io, Adapt. parent (X))
60
+ end
61
+ return print (io, " )" )
62
+ end
63
+
64
+ Base. @nospecializeinfer function Reactant. traced_type_inner (
65
+ @nospecialize (OA:: Type{MyGrid2{FT,AT}} ),
66
+ seen,
67
+ mode:: Reactant.TraceMode ,
68
+ @nospecialize (track_numbers:: Type ),
69
+ @nospecialize (sharding),
70
+ @nospecialize (runtime)
71
+ ) where {FT,AT}
72
+ FT2 = Reactant. traced_type_inner (FT, seen, mode, track_numbers, sharding, runtime)
73
+ AT2 = Reactant. traced_type_inner (AT, seen, mode, track_numbers, sharding, runtime)
74
+
75
+ for NF in (AT2,)
76
+ FT2 = Reactant. promote_traced_type (FT2, eltype (NF))
77
+ end
78
+
79
+ res = MyGrid2{FT2,AT2}
80
+ return res
81
+ end
82
+
83
+ @inline Reactant. make_tracer (seen, @nospecialize (prev:: MyGrid2 ), args... ; kwargs... ) =
84
+ Reactant. make_tracer_via_immutable_constructor (seen, prev, args... ; kwargs... )
85
+
86
+ function update! (g)
87
+ @allowscalar g. data[1 ] = g. radius
88
+ return nothing
89
+ end
90
+
91
+ function selfreturn (g)
92
+ return g
93
+ end
94
+
95
+ function call_update! (g)
96
+ @trace update! (g)
97
+ end
98
+
99
+ function call_selfreturn (g)
100
+ @trace selfreturn (g)
101
+ end
102
+
103
+ @testset " Custom construction" begin
104
+ g = MyGrid ([3.14 , 1.59 ], 2.7 )
105
+ rg = Reactant. to_rarray (g)
106
+
107
+ @jit update! (rg)
108
+ @test convert (Array, rg. data) == [2.7 , 1.59 ]
109
+
110
+ rg = Reactant. to_rarray (g)
111
+ res = @jit selfreturn (rg)
112
+ @test convert (Array, res. data) == [3.14 , 1.59 ]
113
+ @test res. radius == 2.7
114
+ @show typeof (res)
115
+ @test typeof (res. radius) <: ConcreteRNumber
116
+
117
+ rg = Reactant. to_rarray (g)
118
+
119
+ @jit call_update! (rg)
120
+ @test convert (Array, rg. data) == [2.7 , 1.59 ]
121
+
122
+ rg = Reactant. to_rarray (g)
123
+ res = @jit call_selfreturn (rg)
124
+ @test convert (Array, res. data) == [3.14 , 1.59 ]
125
+ @test res. radius == 2.7
126
+ @show typeof (res)
127
+ @test typeof (res. radius) <: ConcreteRNumber
128
+ end
129
+
130
+ @testset " Custom construction2 " begin
131
+ g = Ref (MyGrid ([3.14 , 1.59 ], 2.7 ))
132
+ g = (g, g)
133
+
134
+ rg = Reactant. to_rarray (g)
135
+ res = @jit selfreturn (rg)
136
+ @test convert (Array, res[1 ][]. data) == [3.14 , 1.59 ]
137
+ @test convert (Array, res[2 ][]. data) == [3.14 , 1.59 ]
138
+ @test res[1 ][]. data == res[2 ][]. data
139
+ end
0 commit comments