Skip to content

Commit af79cee

Browse files
authored
fix custom trace paths (#1261)
* fix custom trace paths * test custom trace paths * try fix test
1 parent 2cbeb46 commit af79cee

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/TracedUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ function prepare_mlir_fn_args(
500500
string(idx)
501501
end
502502
stridx *= "." * fldname
503-
aval = getfield(aval, idx)
503+
aval = Reactant.Compiler.traced_getfield(aval, idx)
504504
end
505505
end
506506
MLIR.IR.push_argument!(

test/compile.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,45 @@ end
147147
x = [:a, :b, :a]
148148
@test @jit(unique(x)) == [:a, :b]
149149
end
150+
151+
@testset "custom trace path" begin
152+
struct MockTestCustomPath{T}
153+
x::T
154+
end
155+
156+
function Reactant.Compiler.make_tracer(
157+
seen, prev::MockTestCustomPath, path, mode; kwargs...
158+
)
159+
custom_path = Reactant.append_path(path, (; custom_id=1))
160+
traced_x = Reactant.make_tracer(seen, prev.x, custom_path, mode; kwargs...)
161+
return MockTestCustomPath(traced_x)
162+
end
163+
164+
function Reactant.traced_getfield(
165+
x::MockTestCustomPath, fld::@NamedTuple{custom_id::Int}
166+
)
167+
return if fld.custom_id == 1
168+
x.x
169+
else
170+
error("this is awkward... shouldn't have reach here")
171+
end
172+
end
173+
174+
function Reactant.Compiler.create_result(
175+
tocopy::MockTestCustomPath, path, result_stores
176+
)
177+
custom_path = Reactant.append_path(path, (; custom_id=1))
178+
res_x = Reactant.Compiler.create_result(tocopy.x, custom_path, result_stores)
179+
return :($MockTestCustomPath($res_x))
180+
end
181+
182+
fcustom_path(x) = MockTestCustomPath(x.x)
183+
184+
x = MockTestCustomPath(ones(Int))
185+
xre = MockTestCustomPath(Reactant.to_rarray(x.x))
186+
187+
y = @jit fcustom_path(xre)
188+
@test y isa MockTestCustomPath
189+
@test y.x isa Reactant.RArray
190+
@test y.x == fcustom_path(x).x
191+
end

0 commit comments

Comments
 (0)