Skip to content

Commit 9e8eec0

Browse files
Pangorawmofeing
andauthored
implement @trace for (#255)
* implement `@trace` for * Apply suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <[email protected]> * Allow using induction variable * floating point ranges * lightspeed * import LinearAlgebra * generate 0 to N loop * ir test * clean iter * Revert "Apply suggestions from code review" This reverts commit 079ed4a. * remove precompilation warning * format and fix * loop ranges as traced numbers * fmt and add non unit step test * fmt2 * integers --------- Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
1 parent f2a91bf commit 9e8eec0

File tree

6 files changed

+299
-18
lines changed

6 files changed

+299
-18
lines changed

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ if no traced value is found inside the expression, then there is no overhead.
3131
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
3232
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
3333
of the macro needs to be before the assignment and not before the `if`)
34+
- `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers.
3435
3536
## Special Considerations
3637
@@ -81,6 +82,15 @@ end
8182
This will not compile since `y` is a `Float32` in one branch and a `Float64` in the other.
8283
You need to ensure that all branches have the same type.
8384
85+
Another example is the following for loop which changes the type of `x` between iterations.
86+
87+
```julia
88+
x = ... # ConcreteRArray{Int64, 1}
89+
for i in 1f0:0.5f0:10f0
90+
x = x .+ i # ConcreteRArray{Float32, 1}
91+
end
92+
```
93+
8494
### Certain Symbols are Reserved
8595
8696
Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
@@ -100,15 +110,84 @@ end
100110
"""
101111
macro trace(expr)
102112
expr = macroexpand(__module__, expr)
103-
if expr.head == :(=)
104-
if expr.args[2] isa Expr && expr.args[2].head == :if
113+
if Meta.isexpr(expr, :(=))
114+
if Meta.isexpr(expr.args[2], :if)
105115
return esc(trace_if_with_returns(__module__, expr))
106116
end
107117
end
108-
expr.head == :if && return esc(trace_if(__module__, expr))
118+
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
119+
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
109120
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
110121
end
111122

123+
function trace_for(mod, expr)
124+
Meta.isexpr(expr, :for, 2) || error("expected for expr")
125+
assign, body = expr.args
126+
127+
error_if_any_control_flow(body)
128+
if !Meta.isexpr(assign, :(=)) ||
129+
!(assign.args[1] isa Symbol) ||
130+
!Meta.isexpr(assign.args[2], :call) ||
131+
assign.args[2].args[1] !== :(:)
132+
error("malformed for loop assignment")
133+
end
134+
135+
induction, range = assign.args
136+
137+
counter = gensym(:i)
138+
num_iters = gensym(:num_iters)
139+
140+
start = range.args[2]
141+
step = length(range.args) == 3 ? 1 : range.args[3]
142+
limit = range.args[end]
143+
144+
body_symbols = ExpressionExplorer.compute_symbols_state(
145+
quote
146+
$(Expr(:local, assign))
147+
$body
148+
end,
149+
)
150+
151+
external_syms = body_symbols.assignments body_symbols.references
152+
filter!((SPECIAL_SYMBOLS), external_syms)
153+
154+
all_syms = Expr(:tuple, counter, external_syms...)
155+
args_init = Expr(
156+
:tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
157+
)
158+
159+
reactant_code_block = quote
160+
let args = $(args_init)
161+
cond_fn =
162+
$(all_syms) -> begin
163+
local num_iters = div($limit - $start, $step, RoundDown)
164+
local num_iters = Reactant.promote_to(
165+
Reactant.TracedRNumber{Int64}, num_iters
166+
)
167+
$counter < num_iters + 1
168+
end
169+
body_fn =
170+
$(all_syms) -> begin
171+
local step_ = $step
172+
local start_ = $start
173+
local $induction = start_ + $counter * step_
174+
$body
175+
($counter + 1, $(all_syms.args[(begin + 1):end]...))
176+
end
177+
178+
$(ReactantCore).traced_while(cond_fn, body_fn, args)
179+
end
180+
end
181+
182+
return quote
183+
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
184+
$(reactant_code_block)
185+
else
186+
$(expr)
187+
end
188+
end
189+
end
190+
112191
# ... = if ... style expressions
113192
function trace_if_with_returns(mod, expr)
114193
new_expr, _, all_check_vars = trace_if(
@@ -128,7 +207,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
128207
original_expr = expr
129208

130209
if depth == 0
131-
error_if_return(expr)
210+
error_if_any_control_flow(expr)
132211

133212
counter = 0
134213
expr = MacroTools.prewalk(expr) do x
@@ -285,6 +364,13 @@ function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
285364
return cond ? true_fn(args) : false_fn(args)
286365
end
287366

367+
function traced_while(cond_fn, body_fn, args) where {CFn,BFn}
368+
while cond_fn(args...)
369+
args = body_fn(args...)
370+
end
371+
return args
372+
end
373+
288374
function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
289375
return MacroTools.postwalk(expr) do x
290376
if x isa Symbol && x all_vars
@@ -294,10 +380,14 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
294380
end
295381
end
296382

297-
function error_if_return(expr)
383+
const CONTROL_FLOW_EXPRS = [:return, :break, :continue, :symbolicgoto]
384+
385+
function error_if_any_control_flow(expr)
298386
return MacroTools.postwalk(expr) do x
299-
if x isa Expr && x.head == :return
300-
error("Cannot use @trace on a block that contains a return statement")
387+
for head in CONTROL_FLOW_EXPRS
388+
if Meta.isexpr(x, head)
389+
error("Cannot use @trace on a block that contains a $head statement")
390+
end
301391
end
302392
return x
303393
end

src/Compiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,15 +417,15 @@ macro code_hlo(options, maybe_call=nothing)
417417
f = $(fname)
418418
args = $(Expr(:vect, call.args[2:end]...))
419419
mode = first($(compile_mlir)(f, args; optimize=options.optimize))
420-
return mode
420+
mode
421421
end
422422
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
423423
quote
424424
options = $(options)
425425
f = Base.Broadcast.BroadcastFunction($(call.args[1]))
426426
args = $(call.args[2:end]...)
427427
mode = first($(compile_mlir)(f, args; optimize=options.optimize))
428-
return mode
428+
mode
429429
end
430430
else
431431
error("Invalid function call: $(call)")

src/ControlFlow.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,70 @@ function ReactantCore.traced_if(
7474
end
7575
end
7676

77-
function get_region_removing_missing_values(compiled_fn, insertions)
77+
function ReactantCore.traced_while(
78+
cond_fn::CFn, body_fn::BFn, args
79+
) where {CFn<:Function,BFn<:Function}
80+
# TODO: detect and prevent mutation within the condition
81+
82+
# We promote all incoming args (is there a better way to do this?)
83+
traced_args = [
84+
if v isa Number && !(v isa TracedType)
85+
Reactant.promote_to(TracedRNumber{typeof(v)}, v)
86+
else
87+
v
88+
end for v in args
89+
]
90+
91+
(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn(
92+
cond_fn,
93+
traced_args,
94+
(),
95+
string(gensym("cond_fn")),
96+
false;
97+
no_args_in_result=true,
98+
return_dialect=:stablehlo,
99+
do_transpose=false,
100+
)
101+
102+
(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn(
103+
body_fn,
104+
traced_args,
105+
(),
106+
string(gensym("body_fn")),
107+
false;
108+
no_args_in_result=true,
109+
return_dialect=:stablehlo,
110+
do_transpose=false,
111+
)
112+
113+
cond_reg = take_region(cond_fn_compiled)
114+
body_reg = take_region(body_fn_compiled)
115+
116+
MLIR.IR.rmfromparent!(cond_fn_compiled)
117+
MLIR.IR.rmfromparent!(body_fn_compiled)
118+
119+
result_0 = in_tys
120+
121+
operands = MLIR.IR.Value[v.mlir_data for v in traced_args]
122+
123+
while_compiled = MLIR.Dialects.stablehlo.while_(
124+
operands; result_0, cond=cond_reg, body=body_reg
125+
)
126+
127+
return map(enumerate(traced_args)) do (i, res)
128+
res.mlir_data = MLIR.IR.result(while_compiled, i)
129+
return res
130+
end
131+
end
132+
133+
function take_region(compiled_fn)
78134
region = MLIR.IR.Region()
79135
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
136+
return region
137+
end
138+
139+
function get_region_removing_missing_values(compiled_fn, insertions)
140+
region = take_region(compiled_fn)
80141
block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false)
81142
return_op = MLIR.IR.terminator(block)
82143
for (i, rt) in insertions

src/TracedRNumber.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ for (jlop, hloop) in (
124124
end
125125
end
126126

127+
function Base.div(
128+
@nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown)
129+
) where {T<:Integer}
130+
return TracedRNumber{T}(
131+
(),
132+
MLIR.IR.result(
133+
MLIR.Dialects.stablehlo.divide(
134+
lhs.mlir_data, promote_to(TracedRNumber{T}, rhs).mlir_data
135+
),
136+
1,
137+
),
138+
)
139+
end
140+
127141
for (jlop, hloop, hlocomp) in (
128142
(:(Base.:(==)), :compare, "EQ"),
129143
(:(Base.:(!=)), :compare, "NE"),

src/utils.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ function make_mlir_fn(
4343
return_dialect=:func,
4444
no_args_in_result::Bool=false,
4545
construct_function_without_args::Bool=false,
46+
do_transpose=true,
4647
)
4748
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
4849
return (
@@ -57,6 +58,7 @@ function make_mlir_fn(
5758
return_dialect,
5859
no_args_in_result,
5960
construct_function_without_args,
61+
do_transpose,
6062
)[2:end]...,
6163
)
6264
end
@@ -82,8 +84,10 @@ function make_mlir_fn(
8284

8385
in_tys = if toscalar
8486
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
85-
else
87+
elseif do_transpose
8688
[transpose_ty(mlir_type(arg)) for arg in linear_args]
89+
else
90+
[mlir_type(arg) for arg in linear_args]
8791
end
8892

8993
sym_visibility = nothing
@@ -115,7 +119,7 @@ function make_mlir_fn(
115119
arg.mlir_data = args[i].mlir_data
116120
else
117121
raw_arg = MLIR.IR.argument(fnbody, i)
118-
row_maj_arg = transpose_val(raw_arg)
122+
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
119123
arg.mlir_data = row_maj_arg
120124
end
121125
end
@@ -180,12 +184,12 @@ function make_mlir_fn(
180184
ret = MLIR.IR.block!(fnbody) do
181185
vals = MLIR.IR.Value[]
182186
for res in linear_results
183-
if res isa MissingTracedValue
184-
col_maj = broadcast_to_size(false, ()).mlir_data
185-
elseif construct_function_without_args
186-
col_maj = res.mlir_data
187-
else
188-
col_maj = transpose_val(res.mlir_data)
187+
col_maj = if res isa MissingTracedValue
188+
broadcast_to_size(false, ()).mlir_data
189+
elseif construct_function_without_args || !do_transpose
190+
res.mlir_data
191+
elseif do_transpose
192+
transpose_val(res.mlir_data)
189193
end
190194
push!(vals, col_maj)
191195
end

0 commit comments

Comments
 (0)