Skip to content

Commit e74186a

Browse files
committed
WIP: Fix llvmcall recursion
1 parent e81fb86 commit e74186a

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/overdub.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,24 @@ function overdub_pass!(reflection::Reflection,
316316
end)
317317
end
318318

319+
#=== mark all `llvmcall`s as nooverdub ===#
320+
# TODO: this only works for: `Intrinsics.llvmcall` and not `Core.Intrinsics.llvmcall`
321+
# since there is a getproperty call in the way.
322+
# TODO: Need to fix for `istaggingenabled == true`
323+
if !iskwfunc && !istaggingenabled
324+
insert_statements!(overdubbed_code, overdubbed_codelocs,
325+
(x, i) -> begin
326+
if Base.Meta.isexpr(x, :call) &&
327+
is_ir_element(x.args[1], GlobalRef(Core.Intrinsics, :llvmcall), overdubbed_code)
328+
return 1
329+
end
330+
return nothing
331+
end,
332+
(x, i) -> begin
333+
[Expr(:call, Expr(:nooverdub, GlobalRef(Core.Intrinsics, :llvmcall)), x.args[2:end]...)]
334+
end)
335+
end
336+
319337
#=== untag all `foreigncall` SSAValue/SlotNumber arguments if tagging is enabled ===#
320338

321339
if istaggingenabled && !iskwfunc

test/misctests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,12 @@ callback()
633633

634634
println("done (took ", time() - before_time, " seconds)")
635635

636+
#############################################################################################
636637
# Test overdubbing of a call overload invoke
637638

639+
print(" running CtxCallOverload test...")
640+
before_time = time()
641+
638642
using LinearAlgebra
639643

640644
struct Dense{F,S,T}
@@ -663,3 +667,38 @@ let d = Dense(3,3)
663667
data = rand(3)
664668
Cassette.overdub(CtxCallOverload(), d, data)
665669
end
670+
671+
println("done (took ", time() - before_time, " seconds)")
672+
673+
#############################################################################################
674+
675+
print(" running LLVMCallCtx test...")
676+
before_time = time()
677+
using Cassette
678+
Cassette.@context LLVMCallCtx
679+
680+
# This overdub does nothing
681+
@noinline function Cassette.overdub(ctx::LLVMCallCtx, f, args...)
682+
if Cassette.canrecurse(ctx, f, args...)
683+
Cassette.recurse(ctx, f, args...)
684+
else
685+
Cassette.fallback(ctx, f, args...)
686+
end
687+
end
688+
689+
import Core.Intrinsics
690+
function llvm_sin(x::Float64)
691+
# Needs fix for Core.Intrinsics.llvmcall
692+
Intrinsics.llvmcall(
693+
(
694+
"""declare double @llvm.sin.f64(double)""",
695+
"""%2 = call double @llvm.sin.f64(double %0)
696+
ret double %2"""
697+
),
698+
Float64, Tuple{Float64}, x
699+
)
700+
end
701+
702+
Cassette.@overdub LLVMCallCtx() llvm_sin(4.0)
703+
704+
println("done (took ", time() - before_time, " seconds)")

0 commit comments

Comments
 (0)