Skip to content

Commit f690590

Browse files
committed
WIP: perform eager initialization, retain undefined-ness.
1 parent 0dbaf72 commit f690590

File tree

3 files changed

+75
-72
lines changed

3 files changed

+75
-72
lines changed

src/compiler/gpucompiler.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ function GPUCompiler.finish_module!(job::CUDACompilerJob, mod::LLVM.Module)
3535
Tuple{CompilerJob{PTXCompilerTarget}, typeof(mod)},
3636
job, mod)
3737
emit_exception_flag!(mod)
38-
emit_constant_memory_initializer!(mod)
3938
end
4039

4140
function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module,

src/device/intrinsics/memory_constant.jl

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,66 @@ end
4747
mod = LLVM.parent(llvm_f)
4848

4949
# create a constant memory global variable
50+
# TODO: global_var alignment?
5051
T_global = LLVM.ArrayType(T_result, len)
5152
global_var = GlobalVariable(mod, T_global, string(global_name), AS.Constant)
52-
linkage!(global_var, LLVM.API.LLVMExternalLinkage) # NOTE: external linkage is the default
53+
linkage!(global_var, LLVM.API.LLVMWeakAnyLinkage) # merge, but make sure symbols aren't discarded
5354
extinit!(global_var, true)
54-
# TODO: global_var alignment?
55+
# XXX: if we don't extinit, LLVM can inline the constant memory if it's predefined.
56+
# that means we wouldn't be able to re-set it afterwards. do we want that?
57+
58+
# initialize the constant memory
59+
if haskey(constant_memory_initializer, global_name)
60+
arr = constant_memory_initializer[global_name].value
61+
if isnothing(arr)
62+
GPUCompiler.@safe_error "calling kernel containing garbage collected constant memory"
63+
end
64+
65+
flattened_arr = reduce(vcat, arr)
66+
typ = eltype(T_global)
67+
68+
# TODO: have a look at how Julia converts structs to LLVM:
69+
# https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
70+
# this only seems to emit a type though
71+
init = if isa(typ, LLVM.IntegerType) || isa(typ, LLVM.FloatingPointType)
72+
ConstantArray(flattened_arr, ctx)
73+
elseif isa(typ, LLVM.ArrayType) # a struct with every field of the same type gets optimized to an array
74+
constant_arrays = LLVM.Constant[]
75+
for x in flattened_arr
76+
fields = collect(map(name->getfield(x, name), fieldnames(typeof(x))))
77+
constant_array = ConstantArray(fields, ctx)
78+
push!(constant_arrays, constant_array)
79+
end
80+
ConstantArray(typ, constant_arrays)
81+
elseif isa(typ, LLVM.StructType)
82+
constant_structs = LLVM.Constant[]
83+
for x in flattened_arr
84+
constants = LLVM.Constant[]
85+
for fieldname in fieldnames(typeof(x))
86+
field = getfield(x, fieldname)
87+
if isa(field, Bool)
88+
# NOTE: Bools get compiled to i8 instead of the more "correct" type i1
89+
push!(constants, ConstantInt(LLVM.Int8Type(ctx), field))
90+
elseif isa(field, Integer)
91+
push!(constants, ConstantInt(field, ctx))
92+
elseif isa(field, AbstractFloat)
93+
push!(constants, ConstantFP(field, ctx))
94+
else
95+
GPUCompiler.@safe_error "constant memory does not currently support structs with non-primitive fields ($(typeof(x)).$fieldname::$(typeof(field)))"
96+
end
97+
end
98+
const_struct = ConstantStruct(typ, constants)
99+
push!(constant_structs, const_struct)
100+
end
101+
ConstantArray(typ, constant_structs)
102+
else
103+
# unreachable, but let's be safe and throw a nice error message just in case
104+
GPUCompiler.@safe_error "Could not emit initializer for constant memory of type $typ"
105+
nothing
106+
end
107+
108+
init !== nothing && initializer!(global_var, init)
109+
end
55110

56111
# generate IR
57112
Builder(ctx) do builder

src/memory_constant.jl

Lines changed: 18 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -32,33 +32,43 @@ to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostK
3232
"""
3333
struct CuConstantMemory{T,N} <: AbstractArray{T,N}
3434
name::Symbol
35-
value::Array{T,N}
35+
size::Dims{N}
36+
value::Union{Nothing,Array{T,N}}
3637

3738
function CuConstantMemory(value::Array{T,N}) where {T,N}
38-
# TODO: add finalizer that removes the relevant entry from constant_memory_initializer?
3939
Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))
4040
name = gensym("constant_memory")
4141
name = GPUCompiler.safe_name(string(name))
4242
name = Symbol(name)
4343
val = deepcopy(value)
44-
constant_memory_initializer[name] = WeakRef(val)
45-
return new{T,N}(name, val)
44+
# NOTE: we could avoid this mapping by putting values in the CuDeviceConstantMemory
45+
# type, but doing so for 64K (the upper limit) of data breaks the compiler.
46+
constant_memory_initializer[name] = WeakRef(value)
47+
return new{T,N}(name, size(val), val)
48+
end
49+
50+
function CuConstantMemory(::UndefInitializer, dims::Dims{N}) where {T,N}
51+
Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))
52+
name = gensym("constant_memory")
53+
name = GPUCompiler.safe_name(string(name))
54+
name = Symbol(name)
55+
return new{T,N}(name, dims, nothing)
4656
end
4757
end
4858

4959
CuConstantMemory{T}(::UndefInitializer, dims::Integer...) where {T} =
5060
CuConstantMemory(Array{T}(undef, dims))
5161
CuConstantMemory{T}(::UndefInitializer, dims::Dims{N}) where {T,N} =
52-
CuConstantMemory(Array{T,N}(undef, dims))
62+
CuConstantMemory{T,N}(undef, dims)
5363

54-
Base.size(A::CuConstantMemory) = size(A.value)
64+
Base.size(A::CuConstantMemory) = A.size
5565

5666
Base.getindex(A::CuConstantMemory, i::Integer) = Base.getindex(A.value, i)
5767
Base.setindex!(A::CuConstantMemory, v, i::Integer) = Base.setindex!(A.value, v, i)
5868
Base.IndexStyle(::Type{<:CuConstantMemory}) = Base.IndexLinear()
5969

60-
Adapt.adapt_storage(::Adaptor, A::CuConstantMemory{T,N}) where {T,N} =
61-
CuDeviceConstantMemory{T,N,A.name,size(A.value)}()
70+
Adapt.adapt_storage(::Adaptor, A::CuConstantMemory{T,N}) where {T,N} =
71+
CuDeviceConstantMemory{T,N,A.name,A.size}()
6272

6373

6474
"""
@@ -74,64 +84,3 @@ function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::H
7484
global_array = CuGlobalArray{T}(kernel.mod, string(const_mem.name), length(const_mem))
7585
copyto!(global_array, value)
7686
end
77-
78-
79-
function emit_constant_memory_initializer!(mod::LLVM.Module)
80-
for global_var in globals(mod)
81-
T_global = llvmtype(global_var)
82-
if addrspace(T_global) == AS.Constant
83-
constant_memory_name = Symbol(LLVM.name(global_var))
84-
if !haskey(constant_memory_initializer, constant_memory_name)
85-
continue # non user defined constant memory, most likely from the CUDA runtime
86-
end
87-
88-
arr = constant_memory_initializer[constant_memory_name].value
89-
@assert !isnothing(arr) "calling kernel containing garbage collected constant memory"
90-
91-
flattened_arr = reduce(vcat, arr)
92-
ctx = LLVM.context(mod)
93-
typ = eltype(eltype(T_global))
94-
95-
# TODO: have a look at how Julia converts structs to LLVM:
96-
# https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
97-
# this only seems to emit a type though
98-
if isa(typ, LLVM.IntegerType) || isa(typ, LLVM.FloatingPointType)
99-
init = ConstantArray(flattened_arr, ctx)
100-
elseif isa(typ, LLVM.ArrayType) # a struct with every field of the same type gets optimized to an array
101-
constant_arrays = LLVM.Constant[]
102-
for x in flattened_arr
103-
fields = collect(map(name->getfield(x, name), fieldnames(typeof(x))))
104-
constant_array = ConstantArray(fields, ctx)
105-
push!(constant_arrays, constant_array)
106-
end
107-
init = ConstantArray(typ, constant_arrays)
108-
elseif isa(typ, LLVM.StructType)
109-
constant_structs = LLVM.Constant[]
110-
for x in flattened_arr
111-
constants = LLVM.Constant[]
112-
for fieldname in fieldnames(typeof(x))
113-
field = getfield(x, fieldname)
114-
if isa(field, Bool)
115-
# NOTE: Bools get compiled to i8 instead of the more "correct" type i1
116-
push!(constants, ConstantInt(LLVM.Int8Type(ctx), field))
117-
elseif isa(field, Integer)
118-
push!(constants, ConstantInt(field, ctx))
119-
elseif isa(field, AbstractFloat)
120-
push!(constants, ConstantFP(field, ctx))
121-
else
122-
throw(error("constant memory does not currently support structs with non-primitive fields ($(typeof(x)).$fieldname::$(typeof(field)))"))
123-
end
124-
end
125-
const_struct = ConstantStruct(typ, constants)
126-
push!(constant_structs, const_struct)
127-
end
128-
init = ConstantArray(typ, constant_structs)
129-
else
130-
# unreachable, but let's be safe and throw a nice error message just in case
131-
throw(error("could not emit initializer for constant memory of type $typ"))
132-
end
133-
134-
initializer!(global_var, init)
135-
end
136-
end
137-
end

0 commit comments

Comments
 (0)