# Patch summary: adds `eager_deallocation_eval` methods (BasicField, BasicModel, Dense, Chain)
# that evaluate the model while calling `unsafe_free!` on intermediate arrays, and a
# `recursive_apply!` helper that frees each intermediate layer input but not the caller's input.
# The render callbacks in `main` and `render_benchmark` are switched to the new entry point,
# `GC.enable_logging(true)` in `benchmark` is commented out, and the ROC-only `GC.gc(false)`
# call in `trace` moves from the top of the step loop to just after the normals are computed.
# NOTE(review): in the BasicField method the `backbone[1, :]` slice appears to create a
# temporary that is never passed to `unsafe_free!` — confirm whether that is intended.
diff --git a/src/Nerf.jl b/src/Nerf.jl index 43d0daa..d6a4b41 100644 --- a/src/Nerf.jl +++ b/src/Nerf.jl @@ -104,7 +104,7 @@ function main() i % 1000 == 0 || continue render!(renderer, trainer.occupancy, trainer.bbox) do points, directions - model(points, directions) + eager_deallocation_eval(model, points, directions) end save("image-$i.png", RGB.(to_image(renderer.buffer))) end @@ -124,7 +124,7 @@ function render_benchmark(renderer::Renderer, trainer::Trainer, n::Int) for i in 1:n Core.println(i) render!(renderer, trainer.occupancy, trainer.bbox) do points, directions - trainer.model(points, directions) + eager_deallocation_eval(trainer.model, points, directions) end end end @@ -135,7 +135,7 @@ function benchmark() model = BasicModel(BasicField(DEVICE)) trainer = Trainer(model, dataset) - GC.enable_logging(true) + # GC.enable_logging(true) Core.println("Trainer benchmark") diff --git a/src/models/basic.jl b/src/models/basic.jl index 7857b07..7c535cb 100644 --- a/src/models/basic.jl +++ b/src/models/basic.jl @@ -57,6 +57,27 @@ function (b::BasicField)( vcat(rgb, reshape(backbone[1, :], 1, :)) end +# TODO once AMDGPU supports eager finalization for arrays, we can remove it. 
+function eager_deallocation_eval( + b::BasicField, points::P, directions::D, θ, +) where { + P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32}, +} + encoded_points = b.grid_encoding(points, θ.θge) + encoded_directions = spherical_harmonics(directions) + backbone = eager_deallocation_eval(b.density_mlp, encoded_points, θ.θdensity) + unsafe_free!(encoded_points) + + color_input = vcat(backbone, encoded_directions) + unsafe_free!(encoded_directions) + rgb = eager_deallocation_eval(b.color_mlp, color_input, θ.θcolor) + unsafe_free!(color_input) + + rgba = vcat(rgb, reshape(backbone[1, :], 1, :)) + unsafe_free!.((rgb, backbone)) + rgba +end + function density(b::BasicField, points::P, θ, mode = Val{:NOIG}()) where P <: AbstractMatrix{Float32} _check_mode(mode) if mode == Val{:NOIG}() @@ -111,6 +132,12 @@ function (m::BasicModel)(points::P, directions::D) where { m.field(points, directions, m.θ) end +function eager_deallocation_eval(m::BasicModel, points::P, directions::D) where { + P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32}, +} + eager_deallocation_eval(m.field, points, directions, m.θ) +end + function ∇normals(m::BasicModel, points::P) where P <: AbstractMatrix{Float32} Y, back = Zygote.pullback(points) do p density(m.field, p, m.θ, Val{:IG}()) diff --git a/src/nn/nn.jl b/src/nn/nn.jl index 172c1cc..fd8e993 100644 --- a/src/nn/nn.jl +++ b/src/nn/nn.jl @@ -27,6 +27,17 @@ function (d::Dense{T, F})(x, θ) where {T, F} d.activation.(θ * x) end +function eager_deallocation_eval(d::Dense{T, typeof(identity)}, x, θ) where T + θ * x +end + +function eager_deallocation_eval(d::Dense{T, F}, x, θ) where {T, F} + y1 = θ * x + y2 = d.activation.(y1) + unsafe_free!(y1) + return y2 +end + get_precision(::Dense{T, F}) where {T, F} = T function init(d::Dense) @@ -71,6 +82,31 @@ function recursive_apply(x, l, ::Tuple{}, θₗ, ::Tuple{}) l(x, θₗ) end +function eager_deallocation_eval(c::Chain, x, θ) + recursive_apply!( + x, first(c.layers), Base.tail(c.layers), 
first(θ), Base.tail(θ); + free_input=false) # Do not free input, since it might be needed elsewhere. +end + +function recursive_apply!(x, l, c::Tuple, θₗ, θ; free_input::Bool) + y = eager_deallocation_eval(l, x, θₗ) + if free_input + BACKEND == "ROC" && AMDGPU.wait!(y) + unsafe_free!(x) + end + recursive_apply!( + y, first(c), Base.tail(c), first(θ), Base.tail(θ); + free_input=true) # Free all intermediate inputs. +end +function recursive_apply!(x, l, ::Tuple{}, θₗ, ::Tuple{}; free_input::Bool) + y = eager_deallocation_eval(l, x, θₗ) + if free_input + BACKEND == "ROC" && AMDGPU.wait!(y) + unsafe_free!(x) + end + return y +end + function glorot_uniform(dims, ::Type{T} = Float32; gain = one(T)) where T scale::T = gain * √(24f0 / sum(dims)) (rand(T, dims) .- T(0.5f0)) .* scale diff --git a/src/renderer/renderer.jl b/src/renderer/renderer.jl index da2e2b7..fd112bd 100644 --- a/src/renderer/renderer.jl +++ b/src/renderer/renderer.jl @@ -167,8 +167,6 @@ function trace( n_alive = n_rays = length(rays) for step in 1:max_steps - BACKEND == "ROC" && GC.gc(false) # FIXME - n_alive = compact!(bundle; n_alive, min_transmittance) n_alive == 0 && break @@ -188,6 +186,7 @@ function trace( # TODO evaluate only density without color ∇n = reinterpret(SVector{3, Float32}, reshape(normals_consumer(raw_points), :)) + BACKEND == "ROC" && GC.gc(false) # FIXME else ∇n = nothing end