Skip to content

Commit 00f7a5a

Browse files
committed
Handle llvmcall in overdub_pass
1 parent b318b15 commit 00f7a5a

File tree

4 files changed

+191
-1
lines changed

4 files changed

+191
-1
lines changed

src/overdub.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,97 @@ function overdub_pass!(reflection::Reflection,
223223
append!(overdubbed_code, code_info.code)
224224
append!(overdubbed_codelocs, code_info.codelocs)
225225

226+
#=== mark all `llvmcall`s as nooverdub, optionally mark all `Intrinsics`/`Builtins` nooverdub ===#
227+
228+
function unravel_intrinsics(x)
229+
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
230+
if Base.Meta.isexpr(stmt, :call)
231+
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
232+
f = applycall ? stmt.args[2] : stmt.args[1]
233+
f = ir_element(f, overdubbed_code)
234+
if f isa Expr && Base.Meta.isexpr(f, :call) &&
235+
is_ir_element(f.args[1], GlobalRef(Base, :getproperty), overdubbed_code)
236+
237+
# resolve getproperty here
238+
# this is formed by Core.Intrinsics.llvmcall
239+
# %1 = Base.getproperty(Core, :Intrinsics)
240+
# %2 = GlobalRef(%1, :llvmcall)
241+
mod = ir_element(f.args[2], overdubbed_code)
242+
if mod isa GlobalRef
243+
mod = resolve_early(mod) # returns nothing if fails
244+
end
245+
if !(mod isa Module)
246+
# might be nothing or a Slot
247+
return nothing
248+
end
249+
fname = ir_element(f.args[3], overdubbed_code)
250+
if fname isa QuoteNode
251+
fname = fname.value
252+
end
253+
f = GlobalRef(mod, fname)
254+
end
255+
if f isa GlobalRef
256+
f = resolve_early(f)
257+
end
258+
return f
259+
end
260+
return nothing
261+
end
262+
263+
# TODO: add user-facing flag to do this for all intrinsics
264+
if !iskwfunc
265+
insert_statements!(overdubbed_code, overdubbed_codelocs,
266+
(x, i) -> begin
267+
intrinsic = unravel_intrinsics(x)
268+
if intrinsic === nothing
269+
return nothing
270+
end
271+
if intrinsic === Core.Intrinsics.llvmcall
272+
if istaggingenabled
273+
count = 0
274+
for arg in stmt.args
275+
if isa(arg, SSAValue) || isa(arg, SlotNumber)
276+
count += 1
277+
end
278+
end
279+
return count + 1
280+
else
281+
return 1
282+
end
283+
end
284+
end,
285+
(x, i) -> begin
286+
stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
287+
applycall = is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
288+
intrinsic = unravel_intrinsics(x)
289+
items = Any[]
290+
args = nothing
291+
if istaggingenabled
292+
args = Any[]
293+
for arg in stmt.args
294+
if isa(arg, SSAValue) || isa(arg, SlotNumber)
295+
push!(args, SSAValue(i + length(items)))
296+
push!(items, Expr(:call, Expr(:nooverdub, GlobalRef(Cassette, :untag)), arg, overdub_ctx_slot))
297+
else
298+
push!(result.args, arg)
299+
end
300+
end
301+
end
302+
idx = 1
303+
if applycall
304+
idx = 2
305+
end
306+
# using stmt.args[idx] instead of `intrinsic` leads to a bug
307+
stmt.args[idx] = Expr(:nooverdub, intrinsic)
308+
if args !== nothing
309+
idx += 1
310+
stmt.args[idx:end] = args
311+
end
312+
push!(items, x)
313+
return items
314+
end)
315+
end
316+
226317
#=== perform tagged module transformation if tagging is enabled ===#
227318

228319
if istaggingenabled && !iskwfunc

src/pass.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,35 @@ function is_ir_element(x, y, code::Vector)
197197
end
198198
return result
199199
end
200+
201+
"""
202+
ir_element(x, code::Vector)
203+
204+
Follows the series of `SSAValue` that define `x`.
205+
206+
See also: [`is_ir_element`](@ref)
207+
"""
208+
function ir_element(x, code::Vector)
209+
while isa(x, Core.SSAValue)
210+
x = code[x.id]
211+
end
212+
return x
213+
end
214+
215+
"""
216+
resolve_early(ref::GlobalRef)
217+
218+
Resolves a `Core.Compiler.GlobalRef` during compilation, may
219+
return `nothing` if the binding is not resolved or defined yet.
220+
Only use this when you are certain that the result of the lookup
221+
will not change.
222+
"""
223+
function resolve_early(ref::GlobalRef)
224+
mod = ref.mod
225+
name = ref.name
226+
if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
227+
return getfield(mod, name)
228+
else
229+
return nothing
230+
end
231+
end

test/misctaggingtests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,34 @@ result = overdub(ctx, matrixliteral, tag(1, ctx, "hi"))
496496
@test metameta(result, ctx) == fill(Cassette.Meta("hi", Cassette.NoMetaMeta()), 2, 2)
497497

498498
println("done (took ", time() - before_time, " seconds)")
499+
500+
#############################################################################################
501+
502+
print(" running TaggedLLVMCallCtx test...")
503+
before_time = time()
504+
Cassette.@context TaggedLLVMCallCtx
505+
Cassette.metadatatype(::Type{<:ArrayIndexCtx}, ::Type{Float64}) = Float64
506+
507+
function Cassette.overdub(ctx::TaggedLLVMCallCtx, f, args...)
508+
if Cassette.canrecurse(ctx, f, args...)
509+
Cassette.recurse(ctx, f, args...)
510+
else
511+
Cassette.fallback(ctx, f, args...)
512+
end
513+
end
514+
515+
function llvm_sin(x::Float64)
516+
Core.Intrinsics.llvmcall(
517+
(
518+
"""declare double @llvm.sin.f64(double)""",
519+
"""%2 = call double @llvm.sin.f64(double %0)
520+
ret double %2"""
521+
),
522+
Float64, Tuple{Float64}, x
523+
)
524+
end
525+
526+
ctx = enabletagging(TaggedLLVMCallCtx(), llvm_sin)
527+
Cassette.@overdub ctx llvm_sin(tag(4.0, ctx, 1.0))
528+
529+
println("done (took ", time() - before_time, " seconds)")

test/misctests.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,12 @@ callback()
633633

634634
println("done (took ", time() - before_time, " seconds)")
635635

636+
#############################################################################################
636637
# Test overdubbing of a call overload invoke
637638

639+
print(" running CtxCallOverload test...")
640+
before_time = time()
641+
638642
using LinearAlgebra
639643

640644
struct Dense{F,S,T}
@@ -664,11 +668,43 @@ let d = Dense(3,3)
664668
Cassette.overdub(CtxCallOverload(), d, data)
665669
end
666670

671+
println("done (took ", time() - before_time, " seconds)")
672+
667673
#############################################################################################
668674

669-
print(" running OverdubOverdubCtx test...")
675+
println(" running OverdubOverdubCtx test...")
670676

671677
# Fixed in PR #148
672678
Cassette.@context OverdubOverdubCtx;
673679
overdub_overdub_me() = 2
674680
Cassette.overdub(OverdubOverdubCtx(), Cassette.overdub, OverdubOverdubCtx(), overdub_overdub_me)
681+
682+
#############################################################################################
683+
684+
print(" running LLVMCallCtx test...")
685+
before_time = time()
686+
Cassette.@context LLVMCallCtx
687+
688+
# This overdub does nothing, intentionally not marked `@inline`
689+
function Cassette.overdub(ctx::LLVMCallCtx, f, args...)
690+
if Cassette.canrecurse(ctx, f, args...)
691+
Cassette.recurse(ctx, f, args...)
692+
else
693+
Cassette.fallback(ctx, f, args...)
694+
end
695+
end
696+
697+
function llvm_sin(x::Float64)
698+
Core.Intrinsics.llvmcall(
699+
(
700+
"""declare double @llvm.sin.f64(double)""",
701+
"""%2 = call double @llvm.sin.f64(double %0)
702+
ret double %2"""
703+
),
704+
Float64, Tuple{Float64}, x
705+
)
706+
end
707+
708+
Cassette.@overdub LLVMCallCtx() llvm_sin(4.0)
709+
710+
println("done (took ", time() - before_time, " seconds)")

0 commit comments

Comments
 (0)