@@ -39,6 +39,17 @@ function JuliaContext(f; kwargs...)
3939end
4040
4141
42+ # # deferred compilation
43+
44+ """
45+ var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
46+
47+ Behaves as if calling `f(args...)`, but instead of performing the call it
48+ puts down a marker and returns a function pointer that can be called
49+ later.
50+ """
51+ function var"gpuc.deferred" end
52+
4253# # compiler entrypoint
4354
4455export compile
@@ -127,33 +138,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
127138 error (" Unknown compilation output $output " )
128139end
129140
130- # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
131- # this could both be generalized (e.g. supporting actual function calls, instead of
132- # returning a function pointer), and be integrated with the nonrecursive codegen.
133- const deferred_codegen_jobs = Dict {Int, Any} ()
134-
135- # We make this function explicitly callable so that we can drive OrcJIT's
136- # lazy compilation from it, while also enabling recursive compilation.
137- Base. @ccallable Ptr{Cvoid} function deferred_codegen (ptr:: Ptr{Cvoid} )
138- ptr
139- end
140-
141- @generated function deferred_codegen (:: Val{ft} , :: Val{tt} ) where {ft,tt}
142- id = length (deferred_codegen_jobs) + 1
143- deferred_codegen_jobs[id] = (; ft, tt)
144- # don't bother looking up the method instance, as we'll do so again during codegen
145- # using the world age of the parent.
146- #
147- # this also works around an issue on <1.10, where we don't know the world age of
148- # generated functions so use the current world counter, which may be too new
149- # for the world we're compiling for.
150-
151- quote
152- # TODO : add an edge to this method instance to support method redefinitions
153- ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Int,), $ id)
154- end
155- end
156-
157141const __llvm_initialized = Ref (false )
158142
159143@locked function emit_llvm (@nospecialize (job:: CompilerJob ); toplevel:: Bool ,
@@ -183,79 +167,76 @@ const __llvm_initialized = Ref(false)
183167 entry = finish_module! (job, ir, entry)
184168
185169 # deferred code generation
186- has_deferred_jobs = toplevel && ! only_entry && haskey ( functions (ir), " deferred_codegen " )
187- jobs = Dict {CompilerJob, String} (job => entry_fn )
188- if has_deferred_jobs
189- dyn_marker = functions (ir)[" deferred_codegen " ]
190-
191- # iterative compilation (non-recursive)
192- changed = true
193- while changed
194- changed = false
195-
196- # find deferred compiler
197- # TODO : recover this information earlier, from the Julia IR
198- worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ( )
199- for use in uses (dyn_marker)
200- # decode the call
201- call = user (use) :: LLVM.CallInst
202- id = convert (Int, first (operands (call) ))
203-
204- global deferred_codegen_jobs
205- dyn_val = deferred_codegen_jobs[id]
206-
207- # get a job in the appropriate world
208- dyn_job = if dyn_val isa CompilerJob
209- # trust that the user knows what they're doing
210- dyn_val
170+ run_optimization_for_deferred = false
171+ if haskey ( functions (ir), " gpuc.lookup " )
172+ run_optimization_for_deferred = true
173+ dyn_marker = functions (ir)[" gpuc.lookup " ]
174+
175+ # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
176+ # target method instance from the LLVM IR
177+ function find_base_object (val)
178+ while true
179+ if val isa ConstantExpr && ( opcode (val) == LLVM . API . LLVMIntToPtr ||
180+ opcode (val) == LLVM . API . LLVMBitCast ||
181+ opcode (val) == LLVM . API . LLVMAddrSpaceCast)
182+ val = first ( operands (val) )
183+ elseif val isa LLVM . IntToPtrInst ||
184+ val isa LLVM . BitCastInst ||
185+ val isa LLVM. AddrSpaceCastInst
186+ val = first (operands (val ))
187+ elseif val isa LLVM . LoadInst
188+ # In 1.11+ we no longer embed integer constants directly; when the value is loaded from a global, look through the load to the global's initializer.
189+ gv = first ( operands (val))
190+ if gv isa LLVM . GlobalValue
191+ val = LLVM . initializer (gv)
192+ continue
193+ end
194+ break
211195 else
212- ft, tt = dyn_val
213- dyn_src = methodinstance (ft, tt, tls_world_age ())
214- CompilerJob (dyn_src, job. config)
196+ break
215197 end
216-
217- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
218198 end
199+ return val
200+ end
219201
220- # compile and link
221- for dyn_job in keys (worklist)
222- # cached compilation
223- dyn_entry_fn = get! (jobs, dyn_job) do
224- dyn_ir, dyn_meta = codegen (:llvm , dyn_job; toplevel= false ,
225- parent_job= job)
226- dyn_entry_fn = LLVM. name (dyn_meta. entry)
227- merge! (compiled, dyn_meta. compiled)
228- @assert context (dyn_ir) == context (ir)
229- link! (ir, dyn_ir)
230- changed = true
231- dyn_entry_fn
232- end
233- dyn_entry = functions (ir)[dyn_entry_fn]
234-
235- # insert a pointer to the function everywhere the entry is used
236- T_ptr = convert (LLVMType, Ptr{Cvoid})
237- for call in worklist[dyn_job]
238- @dispose builder= IRBuilder () begin
239- position! (builder, call)
240- fptr = if LLVM. version () >= v " 17"
241- T_ptr = LLVM. PointerType ()
242- bitcast! (builder, dyn_entry, T_ptr)
243- elseif VERSION >= v " 1.12.0-DEV.225"
244- T_ptr = LLVM. PointerType (LLVM. Int8Type ())
245- bitcast! (builder, dyn_entry, T_ptr)
246- else
247- ptrtoint! (builder, dyn_entry, T_ptr)
248- end
249- replace_uses! (call, fptr)
202+ worklist = Dict {Any, Vector{LLVM.CallInst}} ()
203+ for use in uses (dyn_marker)
204+ # decode the call
205+ call = user (use):: LLVM.CallInst
206+ dyn_mi_inst = find_base_object (operands (call)[1 ])
207+ @compiler_assert isa (dyn_mi_inst, LLVM. ConstantInt) job
208+ dyn_mi = Base. unsafe_pointer_to_objref (
209+ convert (Ptr{Cvoid}, convert (Int, dyn_mi_inst)))
210+ push! (get! (worklist, dyn_mi, LLVM. CallInst[]), call)
211+ end
212+
213+ for dyn_mi in keys (worklist)
214+ dyn_fn_name = compiled[dyn_mi]. specfunc
215+ dyn_fn = functions (ir)[dyn_fn_name]
216+
217+ # replace every marker call for this method instance with a pointer to its compiled function
218+ T_ptr = convert (LLVMType, Ptr{Cvoid})
219+ for call in worklist[dyn_mi]
220+ @dispose builder= IRBuilder () begin
221+ position! (builder, call)
222+ fptr = if LLVM. version () >= v " 17"
223+ T_ptr = LLVM. PointerType ()
224+ bitcast! (builder, dyn_fn, T_ptr)
225+ elseif VERSION >= v " 1.12.0-DEV.225"
226+ T_ptr = LLVM. PointerType (LLVM. Int8Type ())
227+ bitcast! (builder, dyn_fn, T_ptr)
228+ else
229+ ptrtoint! (builder, dyn_fn, T_ptr)
250230 end
251- erase ! (call)
231+ replace_uses ! (call, fptr )
252232 end
233+ unsafe_delete! (LLVM. parent (call), call)
253234 end
254235 end
255236
256237 # all deferred compilations should have been resolved
257238 @compiler_assert isempty (uses (dyn_marker)) job
258- erase! ( dyn_marker)
239+ unsafe_delete! (ir, dyn_marker)
259240 end
260241
261242 if libraries
@@ -285,7 +266,7 @@ const __llvm_initialized = Ref(false)
285266 # global variables. this makes sure that the optimizer can, e.g.,
286267 # rewrite function signatures.
287268 if toplevel
288- preserved_gvs = collect ( values (jobs))
269+ preserved_gvs = [entry_fn]
289270 for gvar in globals (ir)
290271 if linkage (gvar) == LLVM. API. LLVMExternalLinkage
291272 push! (preserved_gvs, LLVM. name (gvar))
@@ -317,7 +298,7 @@ const __llvm_initialized = Ref(false)
317298 # deferred codegen has some special optimization requirements,
318299 # which also need to happen _after_ regular optimization.
319300 # XXX: make these part of the optimizer pipeline?
320- if has_deferred_jobs
301+ if run_optimization_for_deferred
321302 @dispose pb= NewPMPassBuilder () begin
322303 add! (pb, NewPMFunctionPassManager ()) do fpm
323304 add! (fpm, InstCombinePass ())
@@ -353,15 +334,15 @@ const __llvm_initialized = Ref(false)
353334 # finish the module
354335 #
355336 # we want to finish the module after optimization, so we cannot do so
356- # during deferred code generation. instead , process the deferred jobs
357- # here.
337+ # during deferred code generation. Instead, process the merged module
338+ # from all the jobs here.
358339 if toplevel
359340 entry = finish_ir! (job, ir, entry)
360341
361- for (job′, fn′) in jobs
362- job′ == job && continue
363- finish_ir! (job′, ir, functions (ir)[fn′])
364- end
342+ # for (job′, fn′) in jobs
343+ # job′ == job && continue
344+ # finish_ir!(job′, ir, functions(ir)[fn′])
345+ # end
365346 end
366347
367348 # replace non-entry function definitions with a declaration
0 commit comments