|
147 | 147 | x = [:a, :b, :a]
|
148 | 148 | @test @jit(unique(x)) == [:a, :b]
|
149 | 149 | end
|
| 150 | + |
| 151 | +@testset "custom trace path" begin |
| 152 | + struct MockTestCustomPath{T} |
| 153 | + x::T |
| 154 | + end |
| 155 | + |
| 156 | + function Reactant.Compiler.make_tracer( |
| 157 | + seen, prev::MockTestCustomPath, path, mode; kwargs... |
| 158 | + ) |
| 159 | + custom_path = Reactant.append_path(path, (; custom_id=1)) |
| 160 | + traced_x = Reactant.make_tracer(seen, prev.x, custom_path, mode; kwargs...) |
| 161 | + return MockTestCustomPath(traced_x) |
| 162 | + end |
| 163 | + |
| 164 | + function Reactant.traced_getfield( |
| 165 | + x::MockTestCustomPath, fld::@NamedTuple{custom_id::Int} |
| 166 | + ) |
| 167 | + return if fld.custom_id == 1 |
| 168 | + x.x |
| 169 | + else |
| 170 | + error("this is awkward... shouldn't have reach here") |
| 171 | + end |
| 172 | + end |
| 173 | + |
| 174 | + function Reactant.Compiler.create_result( |
| 175 | + tocopy::MockTestCustomPath, path, result_stores |
| 176 | + ) |
| 177 | + custom_path = Reactant.append_path(path, (; custom_id=1)) |
| 178 | + res_x = Reactant.Compiler.create_result(tocopy.x, custom_path, result_stores) |
| 179 | + return :($MockTestCustomPath($res_x)) |
| 180 | + end |
| 181 | + |
| 182 | + fcustom_path(x) = MockTestCustomPath(x.x) |
| 183 | + |
| 184 | + x = MockTestCustomPath(ones(Int)) |
| 185 | + xre = MockTestCustomPath(Reactant.to_rarray(x.x)) |
| 186 | + |
| 187 | + y = @jit fcustom_path(xre) |
| 188 | + @test y isa MockTestCustomPath |
| 189 | + @test y.x isa Reactant.RArray |
| 190 | + @test y.x == fcustom_path(x).x |
| 191 | +end |
0 commit comments