Skip to content

Commit 0ddeab2

Browse files
CUDA: fix gc issues (#685)
* CUDA: fix gc issues * Update ext/ReactantCUDAExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 18400b6 commit 0ddeab2

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
1818
ptr::Core.LLVMPtr{T,A}
1919

2020
function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
21-
push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
21+
gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
22+
push!(gc_vec, xs)
23+
@assert gc_vec[end] === xs
2224
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
2325
return new(ptr)
2426
end

src/TracedRArray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)
2727

2828
Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear()
2929

30+
# This is required otherwise we will copy a tracedrarray each time
31+
# we use it
32+
function Base.convert(::Type{TracedRArray}, x::TracedRArray)
33+
return x
34+
end
35+
3036
function Base.convert(::Type{TracedRArray}, x::AnyTracedRArray)
3137
return Base.convert(TracedRArray{unwrapped_eltype(x),ndims(x)}, x)
3238
end

0 commit comments

Comments
 (0)