|  | 
|  | 1 | +export CuConstantMemory | 
|  | 2 | + | 
|  | 3 | +# Map a constant memory name to a `WeakRef` wrapping its (host-side) array value | 
|  | 4 | +const constant_memory_initializer = Dict{Symbol,WeakRef}() | 
|  | 5 | + | 
"""
    CuConstantMemory(value::Array{T,N})
    CuConstantMemory{T}(::UndefInitializer, dims::Integer...)
    CuConstantMemory{T}(::UndefInitializer, dims::Dims{N})

Construct an `N`-dimensional constant memory array with element type `T`, where
`isbitstype(T)` must hold. An `ArgumentError` is thrown otherwise.

Note that the `value` constructor argument is copied (which, for bits types, is equivalent
to a `deepcopy`), meaning that mutations to the original `value` or its elements after
construction will not be reflected in the value of `CuConstantMemory`.

The `UndefInitializer` constructors behave exactly like the regular `Array` version,
i.e. the value of `CuConstantMemory` is uninitialized (arbitrary) when using them.

Unlike in CUDA C, structs cannot be put directly into constant memory. This feature can
be emulated however by wrapping the struct inside of a 1-element array.

When using `CuConstantMemory` as a global variable it is required to pass it as an argument
to a kernel, where the argument is of type [`CuDeviceConstantMemory{T,N}`](@ref).
When using `CuConstantMemory` as a local variable that is captured by a kernel closure
this is not required, and it can be used directly like any other captured variable
without passing it as an argument.

In cases where the same kernel object gets called multiple times, and it is desired to
mutate the value of a `CuConstantMemory` variable in this kernel between calls, please refer
to [`Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel)`](@ref)
"""
struct CuConstantMemory{T,N} <: AbstractArray{T,N}
    # unique name under which the array is registered in `constant_memory_initializer`
    name::Symbol
    # private host-side copy of the array passed to the constructor
    value::Array{T,N}

    function CuConstantMemory(value::Array{T,N}) where {T,N}
        # TODO: add finalizer that removes the relevant entry from constant_memory_initializer?
        Base.isbitstype(T) || throw(ArgumentError("CuConstantMemory only supports bits types"))

        # Generate a fresh name and pass it through `GPUCompiler.safe_name`
        # (presumably to sanitize the gensym'd identifier -- confirm against GPUCompiler).
        name = Symbol(GPUCompiler.safe_name(string(gensym("constant_memory"))))

        # `T` is a bits type at this point, so a shallow `copy` already fully detaches the
        # stored array from the caller's `value`.
        val = copy(value)

        # Only a weak reference is registered, so the registry does not keep `val` alive.
        constant_memory_initializer[name] = WeakRef(val)
        return new{T,N}(name, val)
    end
end
|  | 48 | + | 
# Uninitialized constructors: allocate an `undef` host array of the requested shape and
# hand it to the `CuConstantMemory(value)` constructor.
function CuConstantMemory{T}(::UndefInitializer, dims::Integer...) where {T}
    storage = Array{T}(undef, dims)
    return CuConstantMemory(storage)
end

function CuConstantMemory{T}(::UndefInitializer, dims::Dims{N}) where {T,N}
    storage = Array{T,N}(undef, dims)
    return CuConstantMemory(storage)
end
|  | 53 | + | 
# AbstractArray interface: every host-side access is forwarded to the wrapped `value` array,
# using linear indexing.
function Base.size(A::CuConstantMemory)
    return size(A.value)
end

function Base.getindex(A::CuConstantMemory, i::Integer)
    return A.value[i]
end

function Base.setindex!(A::CuConstantMemory, v, i::Integer)
    # call `setindex!` explicitly (instead of `A.value[i] = v`) to keep its return value
    return Base.setindex!(A.value, v, i)
end

function Base.IndexStyle(::Type{<:CuConstantMemory})
    return Base.IndexLinear()
end
|  | 59 | + | 
# Replace a `CuConstantMemory` by a field-less `CuDeviceConstantMemory` whose type
# parameters carry the element type, rank, registered name and size of the array.
function Adapt.adapt_storage(::Adaptor, A::CuConstantMemory{T,N}) where {T,N}
    dims = size(A.value)
    return CuDeviceConstantMemory{T,N,A.name,dims}()
end
|  | 62 | + | 
|  | 63 | + | 
"""
    copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel)

Given a `kernel` returned by `@cuda`, copy `value` into `const_mem` for subsequent calls to this `kernel`.
If `const_mem` is not used within `kernel`, an error will be thrown.

A `DimensionMismatch` is thrown when `size(value)` differs from `size(const_mem)`.
Only the global looked up in `kernel.mod` is written; the `value` field of `const_mem`
is left untouched.
"""
function Base.copyto!(const_mem::CuConstantMemory{T}, value::Array{T}, kernel::HostKernel) where T
    # TODO: add bool argument to also change the value field of const_mem?
    size(const_mem) == size(value) ||
        throw(DimensionMismatch("size of `value` does not match size of constant memory"))

    # look up the global named after `const_mem` in the kernel's module and overwrite it
    symbol_name = string(const_mem.name)
    device_global = CuGlobalArray{T}(kernel.mod, symbol_name, length(const_mem))
    return copyto!(device_global, value)
end
|  | 77 | + | 
|  | 78 | + | 
"""
    emit_constant_memory_initializer!(mod::LLVM.Module)

Set the initializer of every global variable in `mod` that lives in the constant address
space and whose name is registered in `constant_memory_initializer`, using the array value
registered under that name.

Globals in the constant address space that are not registered are skipped.

Supported element representations are LLVM integer and floating-point types, LLVM array
types (structs whose fields all share one type) and LLVM struct types whose fields are
`Bool`, `Integer` or `AbstractFloat` values.

# Throws
- `AssertionError` if the registered array has already been garbage collected.
- `ErrorException` if a struct field or the element type is not supported.
"""
function emit_constant_memory_initializer!(mod::LLVM.Module)
    ctx = LLVM.context(mod)

    for global_var in globals(mod)
        T_global = llvmtype(global_var)
        addrspace(T_global) == AS.Constant || continue

        # non user defined constant memory, most likely from the CUDA runtime
        ref = get(constant_memory_initializer, Symbol(LLVM.name(global_var)), nothing)
        ref === nothing && continue

        # Thrown explicitly instead of via `@assert`: this is a runtime condition, and
        # assertions may be disabled.
        arr = ref.value
        arr === nothing &&
            throw(AssertionError("calling kernel containing garbage collected constant memory"))

        # `vec` flattens in linear-index order without copying. Unlike `reduce(vcat, arr)`
        # it also yields a vector for 1-element (and empty) arrays.
        flattened_arr = vec(arr)

        # peel two `eltype` layers off the global's type to get the element type
        typ = eltype(eltype(T_global))

        # TODO: have a look at how julia converts structs to llvm:
        #       https://github.com/JuliaLang/julia/blob/80ace52b03d9476f3d3e6ff6da42f04a8df1cf7b/src/cgutils.cpp#L572
        #       this only seems to emit a type though
        if isa(typ, LLVM.IntegerType) || isa(typ, LLVM.FloatingPointType)
            init = ConstantArray(flattened_arr, ctx)
        elseif isa(typ, LLVM.ArrayType)
            # a struct with every field of the same type gets optimized to an array
            constant_arrays = LLVM.Constant[]
            for x in flattened_arr
                fields = collect(map(name -> getfield(x, name), fieldnames(typeof(x))))
                push!(constant_arrays, ConstantArray(fields, ctx))
            end
            init = ConstantArray(typ, constant_arrays)
        elseif isa(typ, LLVM.StructType)
            constant_structs = LLVM.Constant[]
            for x in flattened_arr
                constants = LLVM.Constant[]
                for fieldname in fieldnames(typeof(x))
                    field = getfield(x, fieldname)
                    if isa(field, Bool)
                        # NOTE: Bools get compiled to i8 instead of the more "correct" type i1
                        push!(constants, ConstantInt(LLVM.Int8Type(ctx), field))
                    elseif isa(field, Integer)
                        push!(constants, ConstantInt(field, ctx))
                    elseif isa(field, AbstractFloat)
                        push!(constants, ConstantFP(field, ctx))
                    else
                        error("constant memory does not currently support structs with non-primitive fields ($(typeof(x)).$fieldname::$(typeof(field)))")
                    end
                end
                push!(constant_structs, ConstantStruct(typ, constants))
            end
            init = ConstantArray(typ, constant_structs)
        else
            # unreachable, but let's be safe and throw a nice error message just in case
            error("could not emit initializer for constant memory of type $typ")
        end

        initializer!(global_var, init)
    end

    return nothing
end
0 commit comments