Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Nerf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function main()
i % 1000 == 0 || continue

render!(renderer, trainer.occupancy, trainer.bbox) do points, directions
model(points, directions)
eager_deallocation_eval(model, points, directions)
end
save("image-$i.png", RGB.(to_image(renderer.buffer)))
end
Expand All @@ -124,7 +124,7 @@ function render_benchmark(renderer::Renderer, trainer::Trainer, n::Int)
for i in 1:n
Core.println(i)
render!(renderer, trainer.occupancy, trainer.bbox) do points, directions
trainer.model(points, directions)
eager_deallocation_eval(trainer.model, points, directions)
end
end
end
Expand All @@ -135,7 +135,7 @@ function benchmark()
model = BasicModel(BasicField(DEVICE))
trainer = Trainer(model, dataset)

GC.enable_logging(true)
# GC.enable_logging(true)

Core.println("Trainer benchmark")

Expand Down
27 changes: 27 additions & 0 deletions src/models/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ function (b::BasicField)(
vcat(rgb, reshape(backbone[1, :], 1, :))
end

# TODO once AMDGPU supports eager finalization for arrays, we can remove it.
function eager_deallocation_eval(
b::BasicField, points::P, directions::D, θ,
) where {
P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32},
}
encoded_points = b.grid_encoding(points, θ.θge)
encoded_directions = spherical_harmonics(directions)
backbone = eager_deallocation_eval(b.density_mlp, encoded_points, θ.θdensity)
unsafe_free!(encoded_points)

color_input = vcat(backbone, encoded_directions)
unsafe_free!(encoded_directions)
rgb = eager_deallocation_eval(b.color_mlp, color_input, θ.θcolor)
unsafe_free!(color_input)

rgba = vcat(rgb, reshape(backbone[1, :], 1, :))
unsafe_free!.((rgb, backbone))
rgba
end

function density(b::BasicField, points::P, θ, mode = Val{:NOIG}()) where P <: AbstractMatrix{Float32}
_check_mode(mode)
if mode == Val{:NOIG}()
Expand Down Expand Up @@ -111,6 +132,12 @@ function (m::BasicModel)(points::P, directions::D) where {
m.field(points, directions, m.θ)
end

function eager_deallocation_eval(m::BasicModel, points::P, directions::D) where {
P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32},
}
eager_deallocation_eval(m.field, points, directions, m.θ)
end

function ∇normals(m::BasicModel, points::P) where P <: AbstractMatrix{Float32}
Y, back = Zygote.pullback(points) do p
density(m.field, p, m.θ, Val{:IG}())
Expand Down
36 changes: 36 additions & 0 deletions src/nn/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ function (d::Dense{T, F})(x, θ) where {T, F}
d.activation.(θ * x)
end

function eager_deallocation_eval(d::Dense{T, typeof(identity)}, x, θ) where T
θ * x
end

function eager_deallocation_eval(d::Dense{T, F}, x, θ) where {T, F}
y1 = θ * x
y2 = d.activation.(y1)
unsafe_free!(y1)
return y2
end

get_precision(::Dense{T, F}) where {T, F} = T

function init(d::Dense)
Expand Down Expand Up @@ -71,6 +82,31 @@ function recursive_apply(x, l, ::Tuple{}, θₗ, ::Tuple{})
l(x, θₗ)
end

function eager_deallocation_eval(c::Chain, x, θ)
recursive_apply!(
x, first(c.layers), Base.tail(c.layers), first(θ), Base.tail(θ);
free_input=false) # Do not free input, since it might be neede elsewhere.
end

function recursive_apply!(x, l, c::Tuple, θₗ, θ; free_input::Bool)
y = eager_deallocation_eval(l, x, θₗ)
if free_input
BACKEND == "ROC" && AMDGPU.wait!(y)
unsafe_free!(x)
end
recursive_apply!(
y, first(c), Base.tail(c), first(θ), Base.tail(θ);
free_input=true) # Free all intermediate inputs.
end
function recursive_apply!(x, l, ::Tuple{}, θₗ, ::Tuple{}; free_input::Bool)
y = eager_deallocation_eval(l, x, θₗ)
if free_input
BACKEND == "ROC" && AMDGPU.wait!(y)
unsafe_free!(x)
end
return y
end

function glorot_uniform(dims, ::Type{T} = Float32; gain = one(T)) where T
scale::T = gain * √(24f0 / sum(dims))
(rand(T, dims) .- T(0.5f0)) .* scale
Expand Down
3 changes: 1 addition & 2 deletions src/renderer/renderer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ function trace(
n_alive = n_rays = length(rays)

for step in 1:max_steps
BACKEND == "ROC" && GC.gc(false) # FIXME

n_alive = compact!(bundle; n_alive, min_transmittance)
n_alive == 0 && break

Expand All @@ -188,6 +186,7 @@ function trace(
# TODO evaluate only density without color
∇n = reinterpret(SVector{3, Float32},
reshape(normals_consumer(raw_points), :))
BACKEND == "ROC" && GC.gc(false) # FIXME
else
∇n = nothing
end
Expand Down