Skip to content

Commit 2946722

Browse files
authored
Use gensym-ed variable in @compile and @jit (#274)
* use gensym-ed variable in `@compile` and `@jit` * add test * formatting
1 parent acd7469 commit 2946722

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

src/Compiler.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ end
430430
@compile f(args...)
431431
"""
432432
macro compile(args...)
433-
return esc(compile_call_expr(__module__, args...))
433+
return esc(first(compile_call_expr(__module__, args...)))
434434
end
435435

436436
"""
@@ -439,12 +439,12 @@ end
439439
Run @compile f(args..) then immediately execute it
440440
"""
441441
macro jit(args...)
442-
compile_expr = compile_call_expr(__module__, args...)
442+
compile_expr, (; f, args) = compile_call_expr(__module__, args...)
443443
#! format: off
444444
return esc(
445445
:(
446446
$(compile_expr);
447-
fn(args...)
447+
$(f)($(args)...)
448448
)
449449
)
450450
#! format: on
@@ -463,6 +463,9 @@ function compile_call_expr(mod, args...)
463463
end
464464
end
465465
call = only(args)
466+
f_symbol = gensym(:f)
467+
args_symbol = gensym(:args)
468+
f_compiled_symbol = gensym(:f_compiled)
466469
if Meta.isexpr(call, :call)
467470
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
468471
fname = if bcast
@@ -477,18 +480,28 @@ function compile_call_expr(mod, args...)
477480
:($(fname))
478481
end
479482
return quote
480-
options = (; optimize=$(options[:optimize]), sync=$(options[:sync]))
481-
f = $(fname)
482-
args = $(Expr(:tuple, call.args[2:end]...))
483-
fn = $(compile)(f, args; options.optimize, options.sync)
484-
end
483+
$(f_symbol) = $(fname)
484+
$(args_symbol) = $(Expr(:tuple, call.args[2:end]...))
485+
$(f_compiled_symbol) = $(compile)(
486+
$(f_symbol),
487+
$(args_symbol);
488+
optimize=$(options[:optimize]),
489+
sync=$(options[:sync]),
490+
)
491+
end,
492+
(; f=f_compiled_symbol, args=args_symbol)
485493
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
486494
return quote
487-
options = (; optimize=$(options[:optimize]), sync=$(options[:sync]))
488-
f = Base.Broadcast.BroadcastFunction($(call.args[1]))
489-
args = $(call.args[2:end]...)
490-
fn = $(compile)(f, args; options.optimize, options.sync)
491-
end
495+
$(f_symbol) = Base.Broadcast.BroadcastFunction($(call.args[1]))
496+
$(args_symbol) = $(call.args[2:end]...)
497+
$(f_compiled_symbol) = $(compile)(
498+
$(f_symbol),
499+
$(args_symbol);
500+
optimize=$(options[:optimize]),
501+
sync=$(options[:sync]),
502+
)
503+
end,
504+
(; f=f_compiled_symbol, args=args_symbol)
492505
else
493506
error("Invalid function call: $(call)")
494507
end

test/compile.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
4343
@test y2 Float32.(a)
4444
end
4545

46+
@testset "no variable name collisions in compile macros (#237)" begin
47+
f(x) = x
48+
g(x) = f(x)
49+
x = rand(2, 2)
50+
y = Reactant.to_rarray(x)
51+
@test (@jit g(y); true)
52+
end
53+
4654
# disabled due to long test time (core tests go from 2m to 7m just with this test)
4755
# @testset "resource exhaustation bug (#190)" begin
4856
# x = rand(2, 2)

0 commit comments

Comments
 (0)