@@ -32,33 +32,43 @@ to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostK
3232"""
3333struct CuConstantMemory{T,N} <: AbstractArray{T,N}
3434 name:: Symbol
35- value:: Array{T,N}
35+ size:: Dims{N}
36+ value:: Union{Nothing,Array{T,N}}
3637
3738 function CuConstantMemory (value:: Array{T,N} ) where {T,N}
38- # TODO : add finalizer that removes the relevant entry from constant_memory_initializer?
3939 Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
4040 name = gensym (" constant_memory" )
4141 name = GPUCompiler. safe_name (string (name))
4242 name = Symbol (name)
4343 val = deepcopy (value)
44- constant_memory_initializer[name] = WeakRef (val)
45- return new {T,N} (name, val)
44+ # NOTE: we could avoid this mapping by putting values in the CuDeviceConstantMemory
45+ # type, but doing so for 64K (the upper limit) of data breaks the compiler.
46+ constant_memory_initializer[name] = WeakRef (value)
47+ return new {T,N} (name, size (val), val)
48+ end
49+
50+ function CuConstantMemory (:: UndefInitializer , dims:: Dims{N} ) where {T,N}
51+ Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
52+ name = gensym (" constant_memory" )
53+ name = GPUCompiler. safe_name (string (name))
54+ name = Symbol (name)
55+ return new {T,N} (name, dims, nothing )
4656 end
4757end
4858
4959CuConstantMemory {T} (:: UndefInitializer , dims:: Integer... ) where {T} =
5060 CuConstantMemory (Array {T} (undef, dims))
5161CuConstantMemory {T} (:: UndefInitializer , dims:: Dims{N} ) where {T,N} =
52- CuConstantMemory ( Array {T,N} (undef, dims) )
62+ CuConstantMemory {T,N} (undef, dims)
5363
54- Base. size (A:: CuConstantMemory ) = size (A . value)
64+ Base. size (A:: CuConstantMemory ) = A . size
5565
5666Base. getindex (A:: CuConstantMemory , i:: Integer ) = Base. getindex (A. value, i)
5767Base. setindex! (A:: CuConstantMemory , v, i:: Integer ) = Base. setindex! (A. value, v, i)
5868Base. IndexStyle (:: Type{<:CuConstantMemory} ) = Base. IndexLinear ()
5969
60- Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N} =
61- CuDeviceConstantMemory {T,N,A.name,size(A.value) } ()
70+ Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N} =
71+ CuDeviceConstantMemory {T,N,A.name,A.size } ()
6272
6373
6474"""
@@ -74,64 +84,3 @@ function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::H
7484 global_array = CuGlobalArray {T} (kernel. mod, string (const_mem. name), length (const_mem))
7585 copyto! (global_array, value)
7686end
77-
78-
79- function emit_constant_memory_initializer! (mod:: LLVM.Module )
80- for global_var in globals (mod)
81- T_global = llvmtype (global_var)
82- if addrspace (T_global) == AS. Constant
83- constant_memory_name = Symbol (LLVM. name (global_var))
84- if ! haskey (constant_memory_initializer, constant_memory_name)
85- continue # non user defined constant memory, most likely from the CUDA runtime
86- end
87-
88- arr = constant_memory_initializer[constant_memory_name]. value
89- @assert ! isnothing (arr) " calling kernel containing garbage collected constant memory"
90-
91- flattened_arr = reduce (vcat, arr)
92- ctx = LLVM. context (mod)
93- typ = eltype (eltype (T_global))
94-
95- # TODO : have a look at how julia converts structs to llvm:
96- # https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
97- # this only seems to emit a type though
98- if isa (typ, LLVM. IntegerType) || isa (typ, LLVM. FloatingPointType)
99- init = ConstantArray (flattened_arr, ctx)
100- elseif isa (typ, LLVM. ArrayType) # a struct with every field of the same type gets optimized to an array
101- constant_arrays = LLVM. Constant[]
102- for x in flattened_arr
103- fields = collect (map (name-> getfield (x, name), fieldnames (typeof (x))))
104- constant_array = ConstantArray (fields, ctx)
105- push! (constant_arrays, constant_array)
106- end
107- init = ConstantArray (typ, constant_arrays)
108- elseif isa (typ, LLVM. StructType)
109- constant_structs = LLVM. Constant[]
110- for x in flattened_arr
111- constants = LLVM. Constant[]
112- for fieldname in fieldnames (typeof (x))
113- field = getfield (x, fieldname)
114- if isa (field, Bool)
115- # NOTE: Bools get compiled to i8 instead of the more "correct" type i1
116- push! (constants, ConstantInt (LLVM. Int8Type (ctx), field))
117- elseif isa (field, Integer)
118- push! (constants, ConstantInt (field, ctx))
119- elseif isa (field, AbstractFloat)
120- push! (constants, ConstantFP (field, ctx))
121- else
122- throw (error (" constant memory does not currently support structs with non-primitive fields ($(typeof (x)) .$fieldname ::$(typeof (field)) )" ))
123- end
124- end
125- const_struct = ConstantStruct (typ, constants)
126- push! (constant_structs, const_struct)
127- end
128- init = ConstantArray (typ, constant_structs)
129- else
130- # unreachable, but let's be safe and throw a nice error message just in case
131- throw (error (" could not emit initializer for constant memory of type $typ " ))
132- end
133-
134- initializer! (global_var, init)
135- end
136- end
137- end
0 commit comments