Skip to content

Commit b6ec9a1

Browse files
avik-paljumerckx
andauthored
feat: @trace on function definition (#1685)
* feat: `@trace` on function definition * feat: support functors * feat: handle _ and no symbol functor * feat: support typed kwargs * test: add testcases Co-authored-by: jumerckx <[email protected]> * fix: single line fn --------- Co-authored-by: jumerckx <[email protected]>
1 parent e09c993 commit b6ec9a1

File tree

4 files changed

+211
-5
lines changed

4 files changed

+211
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.163"
4+
version = "0.2.164"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -102,7 +102,7 @@ Preferences = "1.4.3"
102102
PythonCall = "0.9.25"
103103
Random = "1.10"
104104
Random123 = "1.7"
105-
ReactantCore = "0.1.15"
105+
ReactantCore = "0.1.16"
106106
Reactant_jll = "0.0.240"
107107
ScopedValues = "1.3.0"
108108
Scratch = "1.2"

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.15"
4+
version = "0.1.16"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantCore
22

33
using ExpressionExplorer: ExpressionExplorer
4-
using MacroTools: MacroTools
4+
using MacroTools: MacroTools, @capture
55

66
export @trace, within_compile, MissingTracedValue, promote_to_traced
77

@@ -174,26 +174,123 @@ macro trace(args...)
174174
track_numbers = track_numbers ? Number : Union{}
175175
expr = macroexpand(__module__, expr)
176176

177+
#! format: off
178+
if @capture(
179+
expr,
180+
(
181+
fnname_(call_args__) where {Typs__} = fnbody_
182+
) | (
183+
fnname_(call_args__) = fnbody_
184+
) | (
185+
function fnname_(call_args__)
186+
fnbody_
187+
end
188+
) | (
189+
function fnname_(call_args__) where {Typs__}
190+
fnbody_
191+
end
192+
)
193+
)
194+
return esc(trace_function_definition(__module__, expr))
195+
end
196+
#! format: on
197+
177198
if Meta.isexpr(expr, :(=))
178199
if Meta.isexpr(expr.args[2], :if)
179200
return esc(trace_if_with_returns(expr; track_numbers))
180201
end
181202
end
203+
182204
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
205+
183206
if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple)
184207
fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1])))
185208
args = only(expr.args[2:end]).args
186209
call = Expr(:call, fname, args...)
187210
return esc(trace_call(__module__, call))
188211
end
212+
189213
Meta.isexpr(expr, :if) && return esc(trace_if(expr; track_numbers))
214+
190215
Meta.isexpr(expr, :for) &&
191216
return (esc(trace_for(expr; track_numbers, checkpointing, mincut)))
217+
192218
Meta.isexpr(expr, :while) &&
193219
return (esc(trace_while(expr; track_numbers, checkpointing, mincut)))
220+
194221
return error(
195-
"Only `if-elseif-else` blocks, `for` and `while` loops are currently supported by `@trace`",
222+
"Only `if-elseif-else` blocks, function definitions, `function calls`, `for` and \
223+
`while` loops are currently supported by `@trace`"
224+
)
225+
end
226+
227+
function get_argname(expr)
228+
if Meta.isexpr(expr, :(::))
229+
length(expr.args) == 2 && return expr.args[1], expr
230+
@assert length(expr.args) == 1
231+
var = gensym(:_)
232+
return var, Expr(:(::), var, expr.args[1])
233+
end
234+
Meta.isexpr(expr, :kw) && return get_argname(expr.args[1])[1], expr
235+
Meta.isexpr(expr, :(...)) && return expr, expr
236+
@assert expr isa Symbol
237+
if expr == :_
238+
var = gensym(:_)
239+
return var, var
240+
end
241+
return expr, expr
242+
end
243+
244+
function trace_function_definition(mod, expr)
245+
internal_fn = MacroTools.splitdef(expr)
246+
orig_fname = internal_fn[:name]
247+
248+
isfunctor = Meta.isexpr(orig_fname, :(::))
249+
fname = gensym(Symbol(orig_fname, :internal))
250+
internal_fn[:name] = fname
251+
252+
if isfunctor
253+
if length(orig_fname.args) == 1
254+
sym_name = gensym("functor")
255+
orig_fname = Expr(:(::), sym_name, orig_fname.args[1])
256+
end
257+
@assert length(orig_fname.args) == 2
258+
insert!(internal_fn[:args], 1, :($orig_fname))
259+
end
260+
261+
new_fn = MacroTools.splitdef(expr)
262+
263+
standardized_argnames = get_argname.(new_fn[:args])
264+
argnames = first.(standardized_argnames)
265+
new_fn[:args] = last.(standardized_argnames)
266+
267+
if isfunctor
268+
insert!(argnames, 1, orig_fname.args[1])
269+
end
270+
271+
if isempty(new_fn[:kwargs])
272+
traced_call_expr = :($(traced_call)($(fname), $(argnames...)))
273+
untraced_call_expr = :($(fname)($(argnames...)))
274+
else
275+
kws = first.(get_argname.(new_fn[:kwargs]))
276+
traced_call_expr =
277+
:($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...)))
278+
untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...)))
279+
end
280+
281+
new_fn[:name] = orig_fname
282+
new_fn[:body] = :(
283+
if $(within_compile)() && $(any)($(is_traced), ($(argnames...),))
284+
return $(traced_call_expr)
285+
else
286+
return $(untraced_call_expr)
287+
end
196288
)
289+
290+
return quote
291+
$(MacroTools.combinedef(new_fn))
292+
$(MacroTools.combinedef(internal_fn))
293+
end
197294
end
198295

199296
function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothing)

test/control_flow.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,115 @@ end
820820
@test a_ra == a
821821
end
822822

823+
@testset "trace function definitions" begin
824+
# Basic traced function definition
825+
@trace function traced_add(x, y)
826+
return x .+ y
827+
end
828+
829+
@testset "basic traced function" begin
830+
a = rand(2, 3)
831+
b = rand(2, 3)
832+
a_ra = Reactant.to_rarray(a)
833+
b_ra = Reactant.to_rarray(b)
834+
835+
# Should work the same as untraced when JIT compiled
836+
@test @jit(traced_add(a_ra, b_ra)) traced_add(a, b)
837+
ir = @code_hlo optimize = false traced_add(a_ra, b_ra)
838+
func_names = [
839+
String(Reactant.MLIR.IR.attr(op, "sym_name")) for
840+
op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))
841+
]
842+
@test any(contains("traced_add"), func_names)
843+
844+
# Should also work with regular arrays (outside compile context)
845+
@test traced_add(a, b) a .+ b
846+
end
847+
848+
# Traced function with typed arguments
849+
@trace function traced_multiply(x::AbstractArray, y::AbstractArray)
850+
return x .* y
851+
end
852+
853+
@testset "traced function with type annotations" begin
854+
a = rand(3, 3)
855+
b = rand(3, 3)
856+
a_ra = Reactant.to_rarray(a)
857+
b_ra = Reactant.to_rarray(b)
858+
859+
@test @jit(traced_multiply(a_ra, b_ra)) traced_multiply(a, b)
860+
ir = @code_hlo optimize = false traced_multiply(a_ra, b_ra)
861+
func_names = [
862+
String(Reactant.MLIR.IR.attr(op, "sym_name")) for
863+
op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))
864+
]
865+
@test any(contains("traced_multiply"), func_names)
866+
867+
@test traced_multiply(a, b) a .* b
868+
end
869+
870+
@trace singleline(x) = x .+ 1
871+
872+
@testset "single line function" begin
873+
a = rand(2, 3)
874+
a_ra = Reactant.to_rarray(a)
875+
876+
@test @jit(singleline(a_ra)) singleline(a)
877+
ir = @code_hlo optimize = false singleline(a_ra)
878+
func_names = [
879+
String(Reactant.MLIR.IR.attr(op, "sym_name")) for
880+
op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))
881+
]
882+
@test any(contains("singleline"), func_names)
883+
884+
@test singleline(a) a .+ 1
885+
end
886+
887+
struct FunctorTest1{X}
888+
x::X
889+
end
890+
891+
@trace function (f::FunctorTest1)(y)
892+
return f.x .+ y
893+
end
894+
895+
@testset "function with functor" begin
896+
a = rand(2, 3)
897+
a_ra = Reactant.to_rarray(a)
898+
899+
fn1 = FunctorTest1(2.0f0)
900+
901+
@test @jit(fn1(a_ra)) fn1(a)
902+
ir = @code_hlo optimize = false fn1(a_ra)
903+
func_names = [
904+
String(Reactant.MLIR.IR.attr(op, "sym_name")) for
905+
op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))
906+
]
907+
@test any(contains("FunctorTest1"), func_names)
908+
909+
@test fn1(a) a .+ 2.0f0
910+
end
911+
912+
@trace function func_with_kwargs(x; y=1)
913+
return x .+ y
914+
end
915+
916+
@testset "function with kwargs" begin
917+
a = rand(2, 3)
918+
a_ra = Reactant.to_rarray(a)
919+
920+
@test @jit(func_with_kwargs(a_ra; y=2.0f0)) func_with_kwargs(a; y=2.0f0)
921+
ir = @code_hlo optimize = false func_with_kwargs(a_ra; y=2.0f0)
922+
func_names = [
923+
String(Reactant.MLIR.IR.attr(op, "sym_name")) for
924+
op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))
925+
]
926+
@test any(contains("kwcall"), func_names)
927+
928+
@test func_with_kwargs(a; y=2.0f0) a .+ 2.0f0
929+
end
930+
end
931+
823932
mutable struct TestClock{I}
824933
iteration::I
825934
end

0 commit comments

Comments
 (0)