Skip to content

Commit 3579eb2

Browse files
authored
nicer error for type mutation in traced loop body (#1035)
1 parent 35d34f1 commit 3579eb2

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ function trace_for(mod, expr; track_numbers)
231231
end
232232

233233
$(ReactantCore).traced_while(
234-
cond_fn, body_fn, args; track_numbers=$(track_numbers)
234+
cond_fn, body_fn, args; track_numbers=$(track_numbers), verify_arg_names=$(QuoteNode(args_init))
235235
)
236236
end
237237
end

src/ControlFlow.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ function ReactantCore.traced_call(f::Function, args...)
99
end
1010

1111
function ReactantCore.traced_while(
12-
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number
12+
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing
1313
) where {CFn,BFn}
14-
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers)
14+
@warn verify_arg_names
15+
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names)
1516
end

src/Ops.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1746,7 +1746,7 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
17461746
end
17471747

17481748
@noinline function while_loop(
1749-
cond_fn::CFn, body_fn::BFn, args...; track_numbers
1749+
cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing
17501750
) where {CFn,BFn}
17511751
# TODO: detect and prevent mutation within the condition
17521752

@@ -1780,6 +1780,7 @@ end
17801780
do_transpose=false,
17811781
).f
17821782

1783+
@warn verify_arg_names
17831784
body_fn_compiled =
17841785
Reactant.TracedUtils.make_mlir_fn(
17851786
body_fn,
@@ -1790,6 +1791,7 @@ end
17901791
return_dialect=:stablehlo,
17911792
args_in_result=:none,
17921793
do_transpose=false,
1794+
verify_arg_names
17931795
).f
17941796

17951797
cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled)

src/TracedUtils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ function make_mlir_fn(
189189
input_shardings=nothing, # This is not meant to be used by the user.
190190
output_shardings=nothing, # This is not meant to be used by the user.
191191
runtime=nothing,
192+
verify_arg_names=nothing,
192193
)
193194
if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
194195
mlir_fn_res = make_mlir_fn(
@@ -347,6 +348,18 @@ function make_mlir_fn(
347348
if args_in_result == :mutated
348349
append!(linear_results, linear_args[mutated_args])
349350
end
351+
if !isnothing(verify_arg_names) && typeof.(linear_args) != typeof.(linear_results)
352+
@assert length(linear_args) <= length(linear_results)
353+
argis = first.(get_argidx.(linear_args))
354+
resis = Set(getindex.(get_residx.(linear_results), Ref(2)))
355+
# this can be more efficient
356+
conflicts = setdiff(resis, argis)
357+
@assert !isempty(conflicts) "Expected to have some conflicts, but none were found."
358+
359+
error("""Types do not match between function arguments and results.
360+
The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", "))
361+
""")
362+
end
350363

351364
out_tys = if do_transpose
352365
[transpose_ty(Ops.mlir_type(arg)) for arg in linear_results]

0 commit comments

Comments
 (0)