Skip to content

Commit a1fa03f

Browse files
authored
Update Project.toml (#705)
* we can always vendor more things * fix * fix
1 parent c20b142 commit a1fa03f

File tree

2 files changed

+221
-8
lines changed

2 files changed

+221
-8
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...)
297297

298298
# figure out the optimal workgroupsize automatically
299299
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
300-
if !Reactant.Compiler.PartitionKA[]
300+
if !Reactant.Compiler.PartitionKA[] || Reactant.Compiler.Raise[]
301301
threads = prod(ndrange)
302302
else
303303
config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
@@ -459,6 +459,145 @@ function vendored_optimize_module!(
459459
end
460460
end
461461

462+
"""
    vendored_buildEarlyOptimizerPipeline(mpm, job, opt_level; instcombine=false)

Append the early optimization passes to the module pass manager `mpm`.

The passes added are, in order:

1. a CGSCC pass manager wrapping a function pass manager that runs `AllocOptPass`,
   `Float2IntPass` and `LowerConstantIntrinsicsPass`;
2. `GPULowerCPUFeaturesPass`;
3. when `opt_level >= 1`, a function pass manager with a simplification sequence
   (a longer one when `opt_level >= 2`).

# Arguments
- `mpm`: module pass manager that receives the passes.
- `job`: compiler job; accepted for signature parity but not read by this function.
- `opt_level`: optimization level; compared against `1` and `2`.

# Keywords
- `instcombine`: when `true` the simplification step uses `LLVM.InstCombinePass`,
  otherwise `LLVM.InstSimplifyPass`.

NOTE(review): presumably a vendored copy of GPUCompiler's early optimizer pipeline with the
simplifier made selectable -- confirm against upstream.
"""
function vendored_buildEarlyOptimizerPipeline(
    mpm, @nospecialize(job), opt_level; instcombine=false
)
    # Single place that decides which peephole simplifier gets scheduled.
    simplifier() = instcombine ? LLVM.InstCombinePass() : LLVM.InstSimplifyPass()

    LLVM.add!(mpm, LLVM.NewPMCGSCCPassManager()) do cgscc
        # TODO invokeCGSCCCallbacks
        LLVM.add!(cgscc, LLVM.NewPMFunctionPassManager()) do inner
            LLVM.add!(inner, LLVM.Interop.AllocOptPass())
            LLVM.add!(inner, LLVM.Float2IntPass())
            LLVM.add!(inner, LLVM.LowerConstantIntrinsicsPass())
        end
    end
    LLVM.add!(mpm, GPULowerCPUFeaturesPass())

    # Nothing further is scheduled below optimization level 1.
    opt_level >= 1 || return nothing

    LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
        # Level 2+ wraps the basic "simplify, then CSE" core with extra passes.
        aggressive = opt_level >= 2
        aggressive && LLVM.add!(fpm, LLVM.SROAPass())
        LLVM.add!(fpm, simplifier())
        if aggressive
            LLVM.add!(fpm, LLVM.JumpThreadingPass())
            LLVM.add!(fpm, LLVM.CorrelatedValuePropagationPass())
            LLVM.add!(fpm, LLVM.ReassociatePass())
        end
        LLVM.add!(fpm, LLVM.EarlyCSEPass())
        aggressive && LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
    end
    # TODO invokePeepholeCallbacks
end
498+
499+
"""
    vendored_buildIntrinsicLoweringPipeline(mpm, job, opt_level; instcombine::Bool=false)

Append the intrinsic-lowering passes to the module pass manager `mpm`.

The order of the passes matters (see the inline notes); in sequence this schedules:
non-integral address space removal, GC frame lowering, kernel-state lowering, PTLS
cleanup, exception-handler/GC lowering, an optional GVN/SCCP/DCE round, PTLS lowering,
an optional simplification round, Julia address-space removal and a final always-inline.

# Arguments
- `mpm`: module pass manager that receives the passes.
- `job`: compiler job; queried through `GPUCompiler.uses_julia_runtime(job)` and
  `job.config.kernel`.
- `opt_level`: optimization level; compared against `1` and `2`.

# Keywords
- `instcombine`: when `true` the late simplification step uses `LLVM.InstCombinePass`,
  otherwise `LLVM.InstSimplifyPass`.
"""
function vendored_buildIntrinsicLoweringPipeline(
    mpm, @nospecialize(job), opt_level; instcombine::Bool=false
)
    # Queried once; every branch below that depends on the runtime reuses this answer.
    uses_runtime = GPUCompiler.uses_julia_runtime(job)

    # Use `LLVM.add!` like every other call site instead of reaching for the same
    # function through GPUCompiler's namespace.
    LLVM.add!(mpm, LLVM.Interop.RemoveNIPass())

    # lower GC intrinsics
    if !uses_runtime
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, GPULowerGCFramePass())
        end
    end

    # lower kernel state intrinsics
    # NOTE: we can only do so here, as GC lowering can introduce calls to the runtime,
    #       and thus additional uses of the kernel state intrinsics.
    if job.config.kernel
        # TODO: now that all kernel state-related passes are being run here, merge some?
        LLVM.add!(mpm, AddKernelStatePass())
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LowerKernelStatePass())
        end
        LLVM.add!(mpm, CleanupKernelStatePass())
    end

    if !uses_runtime
        # remove dead uses of ptls
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LLVM.ADCEPass())
        end
        LLVM.add!(mpm, GPULowerPTLSPass())
    end

    LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
        # lower exception handling
        if uses_runtime
            LLVM.add!(fpm, LLVM.Interop.LowerExcHandlersPass())
        end
        LLVM.add!(fpm, GPUCompiler.GCInvariantVerifierPass())
        LLVM.add!(fpm, LLVM.Interop.LateLowerGCPass())
        # On Julia >= 1.11.0-DEV.208 FinalLowerGC is scheduled as a function pass ...
        if uses_runtime && VERSION >= v"1.11.0-DEV.208"
            LLVM.add!(fpm, LLVM.Interop.FinalLowerGCPass())
        end
    end
    # ... while on older versions it is added to the module pass manager instead.
    if uses_runtime && VERSION < v"1.11.0-DEV.208"
        LLVM.add!(mpm, LLVM.Interop.FinalLowerGCPass())
    end

    if opt_level >= 2
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LLVM.GVNPass())
            LLVM.add!(fpm, LLVM.SCCPPass())
            LLVM.add!(fpm, LLVM.DCEPass())
        end
    end

    # lower PTLS intrinsics
    if uses_runtime
        LLVM.add!(mpm, LLVM.Interop.LowerPTLSPass())
    end

    if opt_level >= 1
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            # The caller chooses between the heavier InstCombine and InstSimplify.
            if instcombine
                LLVM.add!(fpm, LLVM.InstCombinePass())
            else
                LLVM.add!(fpm, LLVM.InstSimplifyPass())
            end
            LLVM.add!(
                fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...)
            )
        end
    end

    # remove Julia address spaces
    LLVM.add!(mpm, LLVM.Interop.RemoveJuliaAddrspacesPass())

    # Julia's operand bundles confuse the inliner, so repeat here now they are gone.
    # FIXME: we should fix the inliner so that inlined code gets optimized early-on
    LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
end
575+
576+
"""
    vendored_buildNewPMPipeline!(mpm, job, opt_level)

Populate the module pass manager `mpm` with the complete optimization pipeline.

Stages, in order: GPUCompiler's early simplification pipeline, an always-inline pass,
`vendored_buildEarlyOptimizerPipeline`, a function-level stage (GPUCompiler's loop and
scalar optimizer pipelines, plus its vector pipeline under the condition noted below),
`vendored_buildIntrinsicLoweringPipeline`, and GPUCompiler's cleanup pipeline.

Both vendored stages are invoked without the `instcombine` keyword, so they run with
their default (`false`). According to the original author's notes, the GPUCompiler stages
used here do not call instcombine either.

# Arguments
- `mpm`: module pass manager that receives the passes.
- `job`: compiler job forwarded to every stage.
- `opt_level`: optimization level forwarded to every stage.
"""
function vendored_buildNewPMPipeline!(mpm, @nospecialize(job), opt_level)
    # Function-level stage, run inside its own function pass manager.
    function add_function_stage!(fpm)
        # Neither of these calls instcombine (per the original notes).
        GPUCompiler.buildLoopOptimizerPipeline(fpm, job, opt_level)
        GPUCompiler.buildScalarOptimizerPipeline(fpm, job, opt_level)
        # XXX: vectorization is disabled otherwise, as it generally isn't useful for GPU
        #      targets and actually causes issues with some back-end compilers (like
        #      Metal).
        # TODO: make this not dependent on `uses_julia_runtime` (likely CPU), but give it
        #       its own control.
        if GPUCompiler.uses_julia_runtime(job) && opt_level >= 2
            # Doesn't call instcombine either.
            GPUCompiler.buildVectorPipeline(fpm, job, opt_level)
        end
        # NOTE: the `isdebug(:optim)`-guarded `WarnMissedTransformationsPass` step is
        #       left disabled here.
    end

    # Doesn't call instcombine.
    GPUCompiler.buildEarlySimplificationPipeline(mpm, job, opt_level)
    LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
    vendored_buildEarlyOptimizerPipeline(mpm, job, opt_level)
    LLVM.add!(add_function_stage!, mpm, LLVM.NewPMFunctionPassManager())
    vendored_buildIntrinsicLoweringPipeline(mpm, job, opt_level)
    return GPUCompiler.buildCleanupPipeline(mpm, job, opt_level)
end
600+
462601
# compile to executable machine code
463602
function compile(job)
464603
# lower to PTX
@@ -495,11 +634,17 @@ function compile(job)
495634
LLVM.register!(pb, CleanupKernelStatePass())
496635

497636
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
498-
GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
637+
vendored_buildNewPMPipeline!(mpm, job, opt_level)
499638
end
500639
LLVM.run!(pb, mod, tm)
501640
end
641+
if Reactant.Compiler.DUMP_LLVMIR[]
642+
println("cuda.jl pre vendor IR\n", string(mod))
643+
end
502644
vendored_optimize_module!(job, mod)
645+
if Reactant.Compiler.DUMP_LLVMIR[]
646+
println("cuda.jl post vendor IR\n", string(mod))
647+
end
503648
LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)
504649

505650
for fname in ("gpu_report_exception", "gpu_signal_exception")

src/Compiler.jl

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,12 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
416416
if sroa
417417
push!(passes, "propagate-constant-bounds")
418418
if DUMP_LLVMIR[]
419-
push!(passes, "sroa-wrappers{dump_prellvm=true dump_postllvm=true}")
419+
push!(passes, "sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}")
420420
else
421-
push!(passes, "sroa-wrappers")
421+
push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}")
422422
end
423+
push!(passes, "canonicalize")
424+
push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}")
423425
push!(passes, "libdevice-funcs-raise")
424426
push!(passes, "canonicalize")
425427
push!(passes, "remove-duplicate-func-def")
@@ -556,6 +558,9 @@ end
556558
# Global switch read by `compile_mlir!` for non-CPU backends: when set, the `lower-jit`
# pass is configured with `debug=true` and a `cuResultHandlerPtr` obtained from the
# `ReactantHandleCuResult` symbol.
const DEBUG_KERNEL = Ref{Bool}(false)
557559
# Global switch for dumping LLVM IR: when set, `optimization_passes` configures the first
# `sroa-wrappers` pass with `dump_prellvm=true dump_postllvm=true`, and the CUDA
# extension's `compile` prints the module before and after `vendored_optimize_module!`.
const DUMP_LLVMIR = Ref{Bool}(false)
558560

561+
562+
# Global switch for kernel raising. When set, `compile_mlir!` lowers kernels with
# `lower-kernel{backend=cpu}` even on non-CPU backends and inserts the raising passes
# (`convert-llvm-to-cf`, `enzyme-lift-cf-to-scf`, `llvm-to-affine-access`) between kernel
# lowering and `lower-jit`; otherwise only `canonicalize` runs there. The CUDA extension
# also reads it: with a dynamic workgroup size it then uses `prod(ndrange)` threads
# instead of querying `CUDA.launch_configuration`.
const Raise = Ref{Bool}(false)
563+
559564
function compile_mlir!(
560565
mod,
561566
f,
@@ -605,16 +610,33 @@ function compile_mlir!(
605610
end
606611

607612
if backend == "cpu"
608-
kern = "lower-kernel{backend=cpu},canonicalize,lower-jit{openmp=true backend=cpu},symbol-dce"
613+
kern = "lower-kernel{backend=cpu},canonicalize"
614+
jit = "lower-jit{openmp=true backend=cpu},symbol-dce"
609615
elseif DEBUG_KERNEL[]
610616
curesulthandler = dlsym(
611617
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
612618
)
613619
@assert curesulthandler !== nothing
614620
curesulthandler = Base.reinterpret(UInt, curesulthandler)
615-
kern = "lower-kernel,canonicalize,lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
621+
kern = if Raise[]
622+
"lower-kernel{backend=cpu},canonicalize"
623+
else
624+
"lower-kernel,canonicalize"
625+
end
626+
jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
616627
else
617-
kern = "lower-kernel,canonicalize,lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
628+
kern = if Raise[]
629+
"lower-kernel{backend=cpu},canonicalize"
630+
else
631+
"lower-kernel,canonicalize"
632+
end
633+
jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
634+
end
635+
636+
raise = if Raise[]
637+
"convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,llvm-to-affine-access,canonicalize"
638+
else
639+
"canonicalize"
618640
end
619641

620642
opt_passes = optimization_passes(; no_nan, sroa=true)
@@ -634,6 +656,8 @@ function compile_mlir!(
634656
"enzyme-simplify-math",
635657
opt_passes2,
636658
kern,
659+
raise,
660+
jit
637661
],
638662
',',
639663
),
@@ -655,6 +679,43 @@ function compile_mlir!(
655679
',',
656680
),
657681
)
682+
elseif optimize === :before_jit
683+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
684+
run_pass_pipeline!(
685+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
686+
)
687+
run_pass_pipeline!(
688+
mod,
689+
join(
690+
[
691+
"canonicalize",
692+
"remove-unnecessary-enzyme-ops",
693+
"enzyme-simplify-math",
694+
opt_passes2,
695+
kern,
696+
raise,
697+
],
698+
',',
699+
),
700+
)
701+
elseif optimize === :before_raise
702+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
703+
run_pass_pipeline!(
704+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
705+
)
706+
run_pass_pipeline!(
707+
mod,
708+
join(
709+
[
710+
"canonicalize",
711+
"remove-unnecessary-enzyme-ops",
712+
"enzyme-simplify-math",
713+
opt_passes2,
714+
kern
715+
],
716+
',',
717+
),
718+
)
658719
elseif optimize === :no_enzyme
659720
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
660721
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
@@ -696,6 +757,8 @@ function compile_mlir!(
696757
"enzyme-simplify-math",
697758
opt_passes2,
698759
kern,
760+
raise,
761+
jit
699762
],
700763
',',
701764
),
@@ -706,7 +769,12 @@ function compile_mlir!(
706769
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
707770
)
708771
run_pass_pipeline!(
709-
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
772+
mod, join([
773+
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
774+
kern,
775+
raise,
776+
jit
777+
], ',')
710778
)
711779
elseif optimize === :canonicalize
712780
run_pass_pipeline!(

0 commit comments

Comments
 (0)