Skip to content

Commit aead9df

Browse files
jumerckxwsmosesgiordano
authored
allow @trace track_numbers=false ... (#1026)
* allow `@trace track_numbers=false ...` * add kwarg * version bumps --------- Co-authored-by: William S. Moses <[email protected]> Co-authored-by: Mosè Giordano <[email protected]>
1 parent 79e509a commit aead9df

File tree

5 files changed

+51
-28
lines changed

5 files changed

+51
-28
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Preferences = "1.4"
8585
PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
88-
ReactantCore = "0.1.6"
88+
ReactantCore = "0.1.7"
8989
Reactant_jll = "0.0.100"
9090
Scratch = "1.2"
9191
Sockets = "1.10"

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,25 @@ function fn(x)
128128
end
129129
```
130130
"""
131-
macro trace(expr)
131+
macro trace(args...)
132+
track_numbers = true
133+
expr = first(args)
134+
if length(args) > 1 && Meta.isexpr(args[1], :(=))
135+
tn_expr = args[1]
136+
tn_expr.args[1] == :track_numbers ||
137+
error("@trace supports setting track_numbers, but got $(tn_expr)")
138+
139+
track_numbers = tn_expr.args[2]
140+
expr = only(args[2:end])
141+
else
142+
expr = only(args)
143+
end
144+
track_numbers = track_numbers ? Number : Union{}
132145
expr = macroexpand(__module__, expr)
146+
133147
if Meta.isexpr(expr, :(=))
134148
if Meta.isexpr(expr.args[2], :if)
135-
return esc(trace_if_with_returns(__module__, expr))
149+
return esc(trace_if_with_returns(__module__, expr; track_numbers))
136150
end
137151
end
138152
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
@@ -142,12 +156,12 @@ macro trace(expr)
142156
call = Expr(:call, fname, args...)
143157
return esc(trace_call(__module__, call))
144158
end
145-
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
146-
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
159+
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr; track_numbers))
160+
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr; track_numbers)))
147161
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
148162
end
149163

150-
function trace_for(mod, expr)
164+
function trace_for(mod, expr; track_numbers)
151165
Meta.isexpr(expr, :for, 2) || error("expected for expr")
152166
assign, body = expr.args
153167

@@ -216,7 +230,9 @@ function trace_for(mod, expr)
216230
($counter + 1, results_...)
217231
end
218232

219-
$(ReactantCore).traced_while(cond_fn, body_fn, args)
233+
$(ReactantCore).traced_while(
234+
cond_fn, body_fn, args; track_numbers=$(track_numbers)
235+
)
220236
end
221237
end
222238

@@ -232,9 +248,9 @@ function trace_for(mod, expr)
232248
end
233249

234250
# ... = if ... style expressions
235-
function trace_if_with_returns(mod, expr)
251+
function trace_if_with_returns(mod, expr; track_numbers)
236252
new_expr, _, all_check_vars = trace_if(
237-
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
253+
mod, expr.args[2]; store_last_line=expr.args[1], depth=1, track_numbers
238254
)
239255
cond_name = first(all_check_vars)
240256
original_cond = expr.args[2].args[1]
@@ -249,7 +265,7 @@ function trace_if_with_returns(mod, expr)
249265
end
250266
end
251267

252-
function trace_if(mod, expr; store_last_line=nothing, depth=0)
268+
function trace_if(mod, expr; store_last_line=nothing, depth=0, track_numbers)
253269
discard_vars_from_expansion = []
254270
original_expr = expr
255271

@@ -260,7 +276,9 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
260276
expr = MacroTools.prewalk(expr) do x
261277
counter += 1
262278
if x isa Expr && x.head == :if && counter > 1
263-
ex_new, dv, _ = trace_if(mod, x; store_last_line, depth=depth + 1)
279+
ex_new, dv, _ = trace_if(
280+
mod, x; store_last_line, depth=depth + 1, track_numbers
281+
)
264282
append!(discard_vars_from_expansion, dv)
265283
return ex_new
266284
end
@@ -300,7 +318,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
300318
if !(expr.args[3] isa Expr) || expr.args[3].head != :elseif
301319
expr.args[3], [], nothing
302320
else
303-
trace_if(mod, expr.args[3]; store_last_line, depth=depth + 1)
321+
trace_if(mod, expr.args[3]; store_last_line, depth=depth + 1, track_numbers)
304322
end
305323
elseif length(expr.args) == 2
306324
tmp_expr = []
@@ -388,7 +406,8 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
388406
$(cond_name),
389407
$(true_branch_fn_name),
390408
$(false_branch_fn_name),
391-
($(all_input_vars...),),
409+
($(all_input_vars...),);
410+
track_numbers=$(track_numbers),
392411
)
393412
end
394413

@@ -458,7 +477,7 @@ function remove_shortcircuiting(expr)
458477
end
459478

460479
# Generate this dummy function and later we remove it during tracing
461-
function traced_if(cond, true_fn, false_fn, args)
480+
function traced_if(cond, true_fn, false_fn, args; track_numbers)
462481
return cond ? true_fn(args) : false_fn(args)
463482
end
464483

src/ControlFlow.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
function ReactantCore.traced_if(
2-
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
2+
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args; track_numbers=Number
33
) where {TFn,FFn}
4-
return Ops.if_condition(cond, true_fn, false_fn, args...)
4+
return Ops.if_condition(cond, true_fn, false_fn, args...; track_numbers)
55
end
66

77
function ReactantCore.traced_call(f::Function, args...)
88
return Ops.call(f, args...)
99
end
1010

11-
function ReactantCore.traced_while(cond_fn::CFn, body_fn::BFn, args) where {CFn,BFn}
12-
return Ops.while_loop(cond_fn, body_fn, args...)
11+
function ReactantCore.traced_while(
12+
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number
13+
) where {CFn,BFn}
14+
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers)
1315
end

src/Ops.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,7 +1745,9 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
17451745
)
17461746
end
17471747

1748-
@noinline function while_loop(cond_fn::CFn, body_fn::BFn, args...) where {CFn,BFn}
1748+
@noinline function while_loop(
1749+
cond_fn::CFn, body_fn::BFn, args...; track_numbers
1750+
) where {CFn,BFn}
17491751
# TODO: detect and prevent mutation within the condition
17501752

17511753
# Make all the args traced or concrete
@@ -1754,7 +1756,7 @@ end
17541756
traced_args = Vector{Any}(undef, N)
17551757
for i in 1:N
17561758
@inbounds traced_args[i] = Reactant.make_tracer(
1757-
seen_args, args[i], (), Reactant.NoStopTracedTrack; track_numbers=Number
1759+
seen_args, args[i], (), Reactant.NoStopTracedTrack; track_numbers
17581760
)
17591761
end
17601762

@@ -1809,7 +1811,7 @@ end
18091811
end
18101812

18111813
@noinline function if_condition(
1812-
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args...
1814+
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args...; track_numbers
18131815
) where {TFn,FFn}
18141816
true_fn_names = (gensym(:true_fn_args), gensym(:true_result), gensym(:true_fn_resargs))
18151817
false_fn_names = (
@@ -1828,14 +1830,14 @@ end
18281830
args[i],
18291831
(true_fn_names[1], i),
18301832
Reactant.TracedSetPath;
1831-
track_numbers=Number,
1833+
track_numbers,
18321834
)
18331835
@inbounds fb_traced_args[i] = Reactant.make_tracer(
18341836
fb_seen_args,
18351837
args[i],
18361838
(false_fn_names[1], i),
18371839
Reactant.TracedSetPath;
1838-
track_numbers=Number,
1840+
track_numbers,
18391841
)
18401842
end
18411843

@@ -1899,15 +1901,15 @@ end
18991901
tb_result,
19001902
(true_fn_names[2],),
19011903
Reactant.NoStopTracedTrack;
1902-
track_numbers=Number,
1904+
track_numbers,
19031905
)
19041906
for i in eachindex(tb_linear_args)
19051907
Reactant.make_tracer(
19061908
seen_true_results,
19071909
tb_linear_args[i],
19081910
(true_fn_names[3], i),
19091911
Reactant.NoStopTracedTrack;
1910-
track_numbers=Number,
1912+
track_numbers,
19111913
)
19121914
end
19131915

@@ -1964,15 +1966,15 @@ end
19641966
fb_result,
19651967
(false_fn_names[2],),
19661968
Reactant.NoStopTracedTrack;
1967-
track_numbers=Number,
1969+
track_numbers,
19681970
)
19691971
for i in eachindex(fb_linear_args)
19701972
Reactant.make_tracer(
19711973
seen_false_results,
19721974
fb_linear_args[i],
19731975
(false_fn_names[3], i),
19741976
Reactant.NoStopTracedTrack;
1975-
track_numbers=Number,
1977+
track_numbers,
19761978
)
19771979
end
19781980

0 commit comments

Comments
 (0)