11export CuConstantMemory
22
3- # Map a constant memory name to its array value
4- const constant_memory_initializer = Dict {Symbol,WeakRef} ()
5-
63"""
74 CuConstantMemory{T,N}(value::Array{T,N})
85 CuConstantMemory{T}(::UndefInitializer, dims::Integer...)
@@ -30,35 +27,44 @@ In cases where the same kernel object gets called multiple times, and it is desir
3027the value of a `CuConstantMemory` variable in this kernel between calls, please refer
3128to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel)`](@ref)
3229"""
33- struct CuConstantMemory{T,N} <: AbstractArray{T,N}
34- name:: Symbol
35- value:: Array{T,N}
30+ mutable struct CuConstantMemory{T,N} <: AbstractArray{T,N}
31+ name:: String
32+ size:: Dims{N}
33+ value:: Union{Nothing,Array{T,N}}
34+
35+ function CuConstantMemory (value:: Array{T,N} ; name:: String ) where {T,N}
36+ Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
37+ return new {T,N} (GPUCompiler. safe_name (" constant_" * name), size (value), deepcopy (value))
38+ end
3639
37- function CuConstantMemory (value:: Array{T,N} ) where {T,N}
38- # TODO : add finalizer that removes the relevant entry from constant_memory_initializer?
40+ function CuConstantMemory (:: UndefInitializer , dims:: Dims{N} ; name:: String ) where {T,N}
3941 Base. isbitstype (T) || throw (ArgumentError (" CuConstantMemory only supports bits types" ))
40- name = gensym (" constant_memory" )
41- name = GPUCompiler. safe_name (string (name))
42- name = Symbol (name)
43- val = deepcopy (value)
44- constant_memory_initializer[name] = WeakRef (val)
45- return new {T,N} (name, val)
42+ return new {T,N} (GPUCompiler. safe_name (" constant_" * name), dims, nothing )
4643 end
4744end
4845
49- CuConstantMemory {T} (:: UndefInitializer , dims:: Integer... ) where {T} =
50- CuConstantMemory (Array {T} (undef, dims))
51- CuConstantMemory {T} (:: UndefInitializer , dims:: Dims{N} ) where {T,N} =
52- CuConstantMemory ( Array {T,N} (undef, dims) )
46+ CuConstantMemory {T} (:: UndefInitializer , dims:: Integer... ; kwargs ... ) where {T} =
47+ CuConstantMemory (Array {T} (undef, dims); kwargs ... )
48+ CuConstantMemory {T} (:: UndefInitializer , dims:: Dims{N} ; kwargs ... ) where {T,N} =
49+ CuConstantMemory {T,N} (undef, dims; kwargs ... )
5350
54- Base. size (A:: CuConstantMemory ) = size (A . value)
51+ Base. size (A:: CuConstantMemory ) = A . size
5552
5653Base. getindex (A:: CuConstantMemory , i:: Integer ) = Base. getindex (A. value, i)
5754Base. setindex! (A:: CuConstantMemory , v, i:: Integer ) = Base. setindex! (A. value, v, i)
5855Base. IndexStyle (:: Type{<:CuConstantMemory} ) = Base. IndexLinear ()
5956
60- Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N} =
61- CuDeviceConstantMemory {T,N,A.name,size(A.value)} ()
57+ function Adapt. adapt_storage (:: Adaptor , A:: CuConstantMemory{T,N} ) where {T,N}
58+ # convert the values to the type domain
59+ # XXX : this is tough on the compiler when dealing with large initializers.
60+ typevals = if A. value != = nothing
61+ Tuple (reshape (A. value, prod (A. size)))
62+ else
63+ nothing
64+ end
65+
66+ CuDeviceConstantMemory {T,N,Symbol(A.name),A.size,typevals} ()
67+ end
6268
6369
6470"""
@@ -74,64 +80,3 @@ function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::H
7480 global_array = CuGlobalArray {T} (kernel. mod, string (const_mem. name), length (const_mem))
7581 copyto! (global_array, value)
7682end
77-
78-
79- function emit_constant_memory_initializer! (mod:: LLVM.Module )
80- for global_var in globals (mod)
81- T_global = llvmtype (global_var)
82- if addrspace (T_global) == AS. Constant
83- constant_memory_name = Symbol (LLVM. name (global_var))
84- if ! haskey (constant_memory_initializer, constant_memory_name)
85- continue # non user defined constant memory, most likely from the CUDA runtime
86- end
87-
88- arr = constant_memory_initializer[constant_memory_name]. value
89- @assert ! isnothing (arr) " calling kernel containing garbage collected constant memory"
90-
91- flattened_arr = reduce (vcat, arr)
92- ctx = LLVM. context (mod)
93- typ = eltype (eltype (T_global))
94-
95- # TODO : have a look at how julia converts structs to llvm:
96- # https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
97- # this only seems to emit a type though
98- if isa (typ, LLVM. IntegerType) || isa (typ, LLVM. FloatingPointType)
99- init = ConstantArray (flattened_arr, ctx)
100- elseif isa (typ, LLVM. ArrayType) # a struct with every field of the same type gets optimized to an array
101- constant_arrays = LLVM. Constant[]
102- for x in flattened_arr
103- fields = collect (map (name-> getfield (x, name), fieldnames (typeof (x))))
104- constant_array = ConstantArray (fields, ctx)
105- push! (constant_arrays, constant_array)
106- end
107- init = ConstantArray (typ, constant_arrays)
108- elseif isa (typ, LLVM. StructType)
109- constant_structs = LLVM. Constant[]
110- for x in flattened_arr
111- constants = LLVM. Constant[]
112- for fieldname in fieldnames (typeof (x))
113- field = getfield (x, fieldname)
114- if isa (field, Bool)
115- # NOTE: Bools get compiled to i8 instead of the more "correct" type i1
116- push! (constants, ConstantInt (LLVM. Int8Type (ctx), field))
117- elseif isa (field, Integer)
118- push! (constants, ConstantInt (field, ctx))
119- elseif isa (field, AbstractFloat)
120- push! (constants, ConstantFP (field, ctx))
121- else
122- throw (error (" constant memory does not currently support structs with non-primitive fields ($(typeof (x)) .$fieldname ::$(typeof (field)) )" ))
123- end
124- end
125- const_struct = ConstantStruct (typ, constants)
126- push! (constant_structs, const_struct)
127- end
128- init = ConstantArray (typ, constant_structs)
129- else
130- # unreachable, but let's be safe and throw a nice error message just in case
131- throw (error (" could not emit initializer for constant memory of type $typ " ))
132- end
133-
134- initializer! (global_var, init)
135- end
136- end
137- end
0 commit comments