diff --git a/Project.toml b/Project.toml
index 1a87b40..88a0733 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,11 +4,12 @@ authors = ["Anton Smirnov "]
 version = "0.1.0"

 [deps]
-AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
@@ -25,6 +26,5 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-AMDGPU = "0.4"
 KernelAbstractions = "0.9"
 Zygote = "0.6.55"
diff --git a/src/Nerf.jl b/src/Nerf.jl
index d0d4541..c02cd8e 100644
--- a/src/Nerf.jl
+++ b/src/Nerf.jl
@@ -16,8 +16,7 @@ using Rotations
 using StaticArrays
 using Statistics
 using Zygote
-
-# TODO rand on device
+using Flux

 include("kautils.jl")
@@ -75,38 +74,33 @@ include("ray.jl")
 include("acceleration/occupancy.jl")
 include("encoding/grid.jl")
 include("encoding/spherical_harmonics.jl")
-include("nn/nn.jl")
 include("sampler.jl")
 include("loss.jl")
 include("trainer.jl")
 include("renderer/renderer.jl")
+include("models/common.jl")
 include("models/basic.jl")
 include("marching_cubes/marching_cubes.jl")
 include("marching_tetrahedra/marching_tetrahedra.jl")

-function sync_free!(Backend, args...)
-    unsafe_free!.(args)
-end
-
 @info "[Nerf.jl] Backend: $BACKEND_NAME"
 @info "[Nerf.jl] Device: $Backend"

 # TODO
-# - use Flux for models
+# - join Backend and Flux.gpu somehow
 # - non-allocating renderer (except NN part)
-# - get rid of sync_free

 function main()
     config_file = joinpath(pkgdir(Nerf), "data", "raccoon_sofa2", "transforms.json")
     dataset = Dataset(Backend; config_file)
-    model = BasicModel(BasicField(Backend))
+    model = BasicModel(BasicField()) |> Flux.gpu
     trainer = Trainer(model, dataset; n_rays=512)

     camera = Camera(MMatrix{3, 4, Float32}(I), dataset.intrinsics)
     renderer = Renderer(Backend, camera, trainer.bbox, trainer.cone)

-    for i in 1:20_000
+    for i in 1:100
         loss = step!(trainer)
         @show i, loss
@@ -143,9 +137,28 @@ end
 function benchmark()
     config_file = joinpath(pkgdir(Nerf), "data", "raccoon_sofa2", "transforms.json")
     dataset = Dataset(Backend; config_file)
-    model = BasicModel(BasicField(Backend))
+    model = BasicModel(BasicField()) |> Flux.gpu
     trainer = Trainer(model, dataset; n_rays=512)

+    positions = CUDA.rand(Float32, 3, 512 * 512)
+    directions = CUDA.rand(Float32, 3, 512 * 512)
+
+    @time begin
+        for i in 1:10
+            model(positions, directions)
+        end
+        CUDA.synchronize()
+    end
+
+    @time begin
+        for i in 1:1000
+            model(positions, directions)
+        end
+        CUDA.synchronize()
+    end
+    return
+
+    # GC.enable_logging(true)
     Core.println("Trainer benchmark")
diff --git a/src/acceleration/occupancy.jl b/src/acceleration/occupancy.jl
index 44840c4..64cb9a4 100644
--- a/src/acceleration/occupancy.jl
+++ b/src/acceleration/occupancy.jl
@@ -73,11 +73,11 @@ function update!(

     step ÷= update_frequency

-    Backend = get_backend(oc)
-    points = allocate(Backend, SVector{3, Float32}, (n_samples,))
-    indices = allocate(Backend, UInt32, (n_samples,))
+    kab = get_backend(oc)
+    points = allocate(kab, SVector{3, Float32}, (n_samples,))
+    indices = allocate(kab, UInt32, (n_samples,))

-    gp_kernel = generate_points!(Backend)
+    gp_kernel = generate_points!(kab)
     gp_kernel(
         points, indices, rng_state, density, bbox, -0.01f0, UInt32(step);
         ndrange=n_uniform)
@@ -92,28 +92,40 @@ function update!(
     raw_points = reshape(reinterpret(Float32, points), 3, :)
     log_densities = density_eval_fn(raw_points)
-    sync_free!(Backend, points)
+    unsafe_free!(points)
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+    end

-    tmp_density = KernelAbstractions.zeros(Backend, Float32, size(oc.density))
-    distribute_density!(Backend)(
+    tmp_density = KernelAbstractions.zeros(kab, Float32, size(oc.density))
+    distribute_density!(kab)(
         reinterpret(UInt32, tmp_density), log_densities, indices,
         cone.min_stepsize; ndrange=length(indices))
-    sync_free!(Backend, indices, log_densities)
+    unsafe_free!.((indices, log_densities))
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+    end

-    ema_update!(Backend)(
+    ema_update!(kab)(
         oc.density, tmp_density, decay; ndrange=length(oc.density))
-    sync_free!(Backend, tmp_density)
+    unsafe_free!(tmp_density)
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+    end

     update_binary!(oc; threshold)
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+    end
     return rng_state
 end

 function update_binary!(oc::OccupancyGrid; threshold::Float32 = 0.01f0)
-    Backend = get_backend(oc)
+    kab = get_backend(oc)
     oc.mean_density = mean(x -> max(0f0, x), @view(oc.density[:, :, :, 1]))
     threshold = min(threshold, oc.mean_density)
-    distribute_to_binary!(Backend)(
+    distribute_to_binary!(kab)(
         oc.binary, oc.density, threshold; ndrange=length(oc.binary))

     binary_level_length = offset_binary(oc, 1)
@@ -121,7 +133,7 @@ function update_binary!(oc::OccupancyGrid; threshold::Float32 = 0.01f0)
     ndrange = binary_level_length ÷ 8
     n_levels = size(oc.density, 4)

-    bmp_kernel = binary_max_pool!(Backend)
+    bmp_kernel = binary_max_pool!(kab)
     for l in 1:(n_levels - 1)
         s, m, e = binary_level_length .* ((l - 1), l, (l + 1))
         prev_level = @view(oc.binary[(s + 1):m])
@@ -222,9 +234,9 @@ end
 function mark_invisible_regions!(
     oc::OccupancyGrid; intrinsics, rotations, translations,
 )
-    Backend = get_backend(oc)
+    kab = get_backend(oc)
     res_scale = 0.5f0 .* intrinsics.resolution ./ intrinsics.focal
-    _mark_invisible_regions!(Backend)(
+    _mark_invisible_regions!(kab)(
         oc.density, rotations, translations, res_scale;
         ndrange=length(oc.density))
 end
diff --git a/src/encoding/grid.jl b/src/encoding/grid.jl
index 1db9738..c8b760b 100644
--- a/src/encoding/grid.jl
+++ b/src/encoding/grid.jl
@@ -1,7 +1,8 @@
 include("grid_utils.jl")
 include("grid_kernels.jl")

-struct GridEncoding{O}
+struct GridEncoding{O, T}
+    θ::T
     offset_table::O
     n_dims::UInt32
     n_features::UInt32
@@ -10,9 +11,12 @@ struct GridEncoding{O}
     base_resolution::UInt32
     scale::Float32
 end
+Flux.@functor GridEncoding

-function GridEncoding(
-    Backend; n_levels::Int = 16, scale::Float32 = 1.5f0,
+Flux.trainable(ge::GridEncoding) = (; θ=ge.θ)
+
+function GridEncoding(;
+    n_levels::Int = 16, scale::Float32 = 1.5f0,
     base_resolution::Int = 16, n_features::Int = 2, hashmap_size::Int = 19,
 )
     @assert n_levels < 34 "Too many levels for the offset table."
@@ -39,8 +43,10 @@ function GridEncoding(
     offset_table[end] = offset
     n_params = offset * n_features

+    θ = rand(Float32, n_features, n_params ÷ n_features) .* 2f-4 .- 1f-4
+
     GridEncoding(
-        adapt(Backend, offset_table), UInt32(n_dims), UInt32(n_features),
+        θ, offset_table, UInt32(n_dims), UInt32(n_features),
         UInt32(n_levels), UInt32(n_params), UInt32(base_resolution), scale)
 end

@@ -52,31 +58,26 @@ function _get_kernel_params(ge)
     NPD, NFPL
 end

-function init(ge::GridEncoding)
-    shape = Int64.((ge.n_features, ge.n_params ÷ ge.n_features))
-    adapt(get_backend(ge), rand(Float32, shape) .* 2f-4 .- 1f-4)
-end
-
-function reset!(::GridEncoding, θ)
-    copy!(θ, rand(Float32, size(θ)) .* 2f-4 .- 1f-4)
+function reset!(ge::GridEncoding)
+    copy!(ge.θ, rand(Float32, size(ge.θ)) .* 2f-4 .- 1f-4)
 end

 function get_output_shape(ge::GridEncoding)
     Int.((ge.n_features, ge.n_levels))
 end

-function (ge::GridEncoding)(x, θ)
+function (ge::GridEncoding)(x)
     Backend = get_backend(ge)
     n = size(x, 2)
     y = allocate(Backend, Float32, (get_output_shape(ge)..., n))
     NPD, NFPL = _get_kernel_params(ge)
     grid_kernel!(Backend)(
-        y, nothing, x, θ, ge.offset_table, NPD, NFPL,
+        y, nothing, x, ge.θ, ge.offset_table, NPD, NFPL,
         ge.base_resolution, log2(ge.scale); ndrange=(n, ge.n_levels))
     reshape(y, :, n)
 end

-function (ge::GridEncoding)(x, θ, ::Val{:IG})
+function (ge::GridEncoding)(x, ::Val{:IG})
     Backend = get_backend(ge)
     n = size(x, 2)
     y = allocate(Backend, Float32, (get_output_shape(ge)..., n))
@@ -85,16 +86,16 @@ function (ge::GridEncoding)(x, θ, ::Val{:IG})

     NPD, NFPL = _get_kernel_params(ge)
     grid_kernel!(Backend)(
-        y, ∂y∂x, x, θ, ge.offset_table, NPD, NFPL, ge.base_resolution,
+        y, ∂y∂x, x, ge.θ, ge.offset_table, NPD, NFPL, ge.base_resolution,
         log2(ge.scale); ndrange=(n, ge.n_levels))
     reshape(y, :, n), ∂y∂x
 end

-function ∇(ge::GridEncoding, ∂f∂y, x, θ)
+function ∇(ge::GridEncoding, ∂f∂y, x)
     Backend = get_backend(ge)
     n = size(x, 2)
     NPD, NFPL = _get_kernel_params(ge)
-    ∂grid = KernelAbstractions.zeros(Backend, Float32, size(θ))
+    ∂grid = KernelAbstractions.zeros(Backend, Float32, size(ge.θ))
     ∇grid_kernel!(Backend)(
         ∂grid, ∂f∂y, x, ge.offset_table, NPD, NFPL, ge.base_resolution,
         log2(ge.scale); ndrange=(n, ge.n_levels))
@@ -111,23 +112,22 @@ function ∇grid_input(ge::GridEncoding, ∂L∂y, ∂y∂x)
     ∂L∂x
 end

-function ChainRulesCore.rrule(ge::GridEncoding, x, θ)
+function ChainRulesCore.rrule(ge::GridEncoding, x)
     n = size(x, 2)
     function encode_pullback(Δ)
-        Δ2 = reshape(unthunk(Δ), (get_output_shape(ge)..., n))
-        Tangent{GridEncoding}(), NoTangent(), ∇(ge, Δ2, x, θ)
+        Δ = reshape(unthunk(Δ), (get_output_shape(ge)..., n))
+        Tangent{GridEncoding}(;θ=∇(ge, Δ, x)), NoTangent()
     end
-    ge(x, θ), encode_pullback
+    ge(x), encode_pullback
 end

-function ChainRulesCore.rrule(ge::GridEncoding, x, θ, ::Val{:IG})
+function ChainRulesCore.rrule(ge::GridEncoding, x, ::Val{:IG})
     n = size(x, 2)
-    y, ∂y∂x = ge(x, θ, Val{:IG}())
+    y, ∂y∂x = ge(x, Val{:IG}())
     function encode_pullback(Δ)
-        Δ2 = reshape(unthunk(Δ), (get_output_shape(ge)..., n))
-        (
-            Tangent{GridEncoding}(), @thunk(∇grid_input(ge, Δ2, ∂y∂x)),
-            @thunk(∇(ge, Δ2, x, θ)), NoTangent())
+        Δ = reshape(unthunk(Δ), (get_output_shape(ge)..., n))
+        (Tangent{GridEncoding}(; θ=@thunk(∇(ge, Δ, x))),
+            @thunk(∇grid_input(ge, Δ, ∂y∂x)), NoTangent())
     end
     y, encode_pullback
 end
diff --git a/src/kautils.jl b/src/kautils.jl
index 2729121..f65fdd1 100644
--- a/src/kautils.jl
+++ b/src/kautils.jl
@@ -1,7 +1,7 @@
-# Supported values are: ROC, CUDA.
+# Supported values are: AMD, CUDA.
 const BACKEND_NAME::String = @load_preference("backend", "ROC")

-@static if BACKEND_NAME == "ROC"
+@static if BACKEND_NAME == "AMD"
     using AMDGPU
     AMDGPU.allowscalar(false)
     const Backend::ROCBackend = ROCBackend()
diff --git a/src/models/basic.jl b/src/models/basic.jl
index 88a4b04..5ba0f31 100644
--- a/src/models/basic.jl
+++ b/src/models/basic.jl
@@ -3,35 +3,31 @@ struct BasicField{G <: GridEncoding, D, C}
     density_mlp::D
     color_mlp::C
 end
+Flux.@functor BasicField

-function BasicField(Backend; backbone_size::Int = 16, grid_kwargs...)
-    ge = GridEncoding(Backend; grid_kwargs...)
+function BasicField(;
+    backbone_size::Int = 16, mlp_hidden_size::Int = 64, grid_kwargs...,
+)
+    ge = GridEncoding(; grid_kwargs...)
     density_mlp_input = prod(get_output_shape(ge))
     color_mlp_input = 16 + backbone_size # 16 for spherical harmonics
     density_mlp = Chain(
-        Dense(density_mlp_input => 64, relu),
-        Dense(64 => backbone_size))
+        Dense(density_mlp_input => mlp_hidden_size, relu),
+        Dense(mlp_hidden_size => backbone_size))
     color_mlp = Chain(
-        Dense(color_mlp_input => 64, relu),
-        Dense(64 => 64, relu),
-        Dense(64 => 3, sigmoid))
+        Dense(color_mlp_input => mlp_hidden_size, relu),
+        Dense(mlp_hidden_size => mlp_hidden_size, relu),
+        Dense(mlp_hidden_size => 3, sigmoid))
     BasicField(ge, density_mlp, color_mlp)
 end

 KernelAbstractions.get_backend(b::BasicField) = get_backend(b.grid_encoding)

-function init(b::BasicField)
-    Backend = get_backend(b)
-    θge = init(b.grid_encoding)
-    θdensity = map(l -> adapt(Backend, l), init(b.density_mlp))
-    θcolor = map(l -> adapt(Backend, l), init(b.color_mlp))
-    (; θge, θdensity, θcolor)
-end
-
-function reset!(b::BasicField, θ)
-    reset!(b.grid_encoding, θ.θge)
-    reset!(b.density_mlp, θ.θdensity)
-    reset!(b.color_mlp, θ.θcolor)
+function reset!(b::BasicField)
+    reset!(b.grid_encoding)
+    # TODO reset chain
+    reset!(b.density_mlp)
+    reset!(b.color_mlp)
 end

 function _check_mode(mode)
@@ -40,101 +36,112 @@ function _check_mode(mode)
         "`mode`=`$mode` must be either `Val{:NOIG}()` or `Val{:IG}()`."))
 end

-function (b::BasicField)(points::P, directions::D, θ, mode = Val{:NOIG}()) where {
+function (b::BasicField)(points::P, directions::D, mode = Val{:NOIG}()) where {
     P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32},
 }
     _check_mode(mode)
     if mode == Val{:NOIG}()
-        encoded_points = b.grid_encoding(points, θ.θge)
+        encoded_points = b.grid_encoding(points)
     else
-        encoded_points = b.grid_encoding(points, θ.θge, mode)
+        encoded_points = b.grid_encoding(points, mode)
     end
     encoded_directions = spherical_harmonics(directions)
-    backbone = b.density_mlp(encoded_points, θ.θdensity)
-    rgb = b.color_mlp(vcat(backbone, encoded_directions), θ.θcolor)
+    backbone = b.density_mlp(encoded_points)
+    rgb = b.color_mlp(vcat(backbone, encoded_directions))
     vcat(rgb, reshape(backbone[1, :], 1, :))
 end

-function density(b::BasicField, points::P, θ, mode = Val{:NOIG}()) where P <: AbstractMatrix{Float32}
+function density(b::BasicField, points::P, mode = Val{:NOIG}()) where P <: AbstractMatrix{Float32}
     _check_mode(mode)
     if mode == Val{:NOIG}()
-        encoded_points = b.grid_encoding(points, θ.θge)
+        encoded_points = b.grid_encoding(points)
     else
-        encoded_points = b.grid_encoding(points, θ.θge, mode)
+        encoded_points = b.grid_encoding(points, mode)
     end
-    b.density_mlp(encoded_points, θ.θdensity)[1, :]
+    b.density_mlp(encoded_points)[1, :]
 end

-function _dealloc_density(b::BasicField, points::P, θ) where P <: AbstractMatrix{Float32}
-    Backend = get_backend(b)
-
-    encoded_points = b.grid_encoding(points, θ.θge)
-    tmp = b.density_mlp.layers[1](encoded_points, θ.θdensity[1])
-    sync_free!(Backend, encoded_points)
-    dst = b.density_mlp.layers[2](tmp, θ.θdensity[2])
-    sync_free!(Backend, tmp)
+function _dealloc_density(b::BasicField, points::P) where P <: AbstractMatrix{Float32}
+    encoded_points = b.grid_encoding(points)
+    tmp = b.density_mlp.layers[1](encoded_points)
+    unsafe_free!(encoded_points)
+    dst = b.density_mlp.layers[2](tmp)
+    unsafe_free!(tmp)

     y = dst[1, :]
-    sync_free!(Backend, dst)
+    unsafe_free!(dst)
     return y
 end

-function batched_density(b::BasicField, points::P, θ; batch::Int) where P <: AbstractMatrix{Float32}
+function batched_density(b::BasicField, points::P; batch::Int) where P <: AbstractMatrix{Float32}
     n = size(points, 2)
     n_iterations = ceil(Int, n / batch)

-    Backend = get_backend(b)
-    σ = allocate(Backend, Float32, (n,))
+    kab = get_backend(b)
+    σ = allocate(kab, Float32, (n,))
     for i in 1:n_iterations
         i_start = (i - 1) * batch + 1
         i_end = min(n, i * batch)
-        batch_σ = _dealloc_density(b, @view(points[:, i_start:i_end]), θ)
+        batch_σ = _dealloc_density(b, @view(points[:, i_start:i_end]))
+        if BACKEND_NAME == "AMD"
+            KernelAbstractions.synchronize(kab)
+        end
+
         σ[i_start:i_end] .= batch_σ
-        sync_free!(Backend, batch_σ)
+        unsafe_free!(batch_σ)
     end
     σ
 end

-# TODO eager dealloc density function
-
-struct BasicModel{F, P, O <: Adam}
+mutable struct BasicModel{F, O}
     field::F
-    θ::P
     optimizer::O
 end
+Flux.@functor BasicModel

 function BasicModel(field::BasicField)
-    θ = init(field)
-    BasicModel(field, θ, Adam(get_backend(field), θ; lr=1f-2))
+    BasicModel(field, Adam(1f-2))
 end

 KernelAbstractions.get_backend(m::BasicModel) = get_backend(m.field)

 # TODO simplify types
 function batched_density(m::BasicModel, points::P; batch::Int) where P <: AbstractMatrix
-    batched_density(m.field, points, m.θ; batch)
+    batched_density(m.field, points; batch)
 end

 function reset!(m::BasicModel)
     reset!(m.field, m.θ)
-    reset!(m.optimizer)
+    m.optimizer = Adam(1f-2)
 end

 function (m::BasicModel)(points::P, directions::D) where {
     P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32},
 }
-    m.field(points, directions, m.θ)
+    m.field(points, directions)
 end

 function ∇normals(m::BasicModel, points::P) where P <: AbstractMatrix{Float32}
-    Backend = get_backend(m)
+    kab = get_backend(m)
     Y, back = Zygote.pullback(points) do p
-        density(m.field, p, m.θ, Val{:IG}())
+        density(m.field, p, Val{:IG}())
+    end
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+        GC.gc(false)
+        KernelAbstractions.synchronize(kab)
     end
+
     Δ = KernelAbstractions.ones(Backend, Float32, size(Y))
     ∇ = back(Δ)[1]
     n⃗ = safe_normalize(-∇; dims=1)
     # TODO in-place normalization kernel with negation
-    sync_free!(Backend, Δ, ∇)
+
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+        unsafe_free!.((Δ, ∇))
+        GC.gc(false)
+        KernelAbstractions.synchronize(kab)
+    end
     n⃗
 end

@@ -142,15 +149,23 @@ function step!(
     m::BasicModel, points::P, directions::D;
     bundle::RayBundle, samples::RaySamples, images::Images,
     n_rays::Int, rng_state::UInt64,
-) where {
-    P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32},
-}
-    loss::Float32 = 0f0
-    loss, ∇ = Zygote.withgradient(m.θ) do θ
-        rgba = m.field(points, directions, θ)
+) where {P <: AbstractMatrix{Float32}, D <: AbstractMatrix{Float32}}
+    state = Flux.setup(m.optimizer, m.field) # TODO store state in the model
+    loss, ∇ = Zygote.withgradient(m.field) do nn
+        rgba = nn(points, directions)
         photometric_loss(rgba; bundle, samples, images, n_rays, rng_state)
     end
-    step!(m.optimizer, m.θ, ∇[1]; dispose=true)
+    kab = KernelAbstractions.get_backend(m)
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+        GC.gc(false)
+        KernelAbstractions.synchronize(kab)
+    end
+
+    Flux.Optimise.update!(state, m.field, ∇[1]) # TODO replace model from struct?
+    if BACKEND_NAME == "AMD"
+        KernelAbstractions.synchronize(kab)
+    end
     loss
 end
diff --git a/src/nn/common.jl b/src/models/common.jl
similarity index 100%
rename from src/nn/common.jl
rename to src/models/common.jl
diff --git a/src/nn/adam.jl b/src/nn/adam.jl
deleted file mode 100644
index ae25757..0000000
--- a/src/nn/adam.jl
+++ /dev/null
@@ -1,95 +0,0 @@
-Base.@kwdef mutable struct Adam{T}
-    μ::Vector{T}
-    ν::Vector{T}
-    current_step::UInt32 = UInt32(0)
-
-    # Hyperparameters.
-    lr::Float32 = 1f-2
-    β1::Float32 = 0.9f0
-    β2::Float32 = 0.999f0
-    ϵ::Float32 = 1f-8
-end
-
-KernelAbstractions.get_backend(opt::Adam) = get_backend(first(opt.μ))
-
-function Adam(Backend, θ; kwargs...)
-    μ, ν = [], [] # TODO unstable
-    _add_moments!(μ, ν, θ, Backend)
-    Adam(; μ, ν, kwargs...)
-end
-
-function _add_moments!(μ, ν, θ::T, Backend) where T <: Union{Tuple, NamedTuple}
-    foreach(θᵢ -> _add_moments!(μ, ν, θᵢ, Backend), θ)
-end
-
-function _add_moments!(μ, ν, θ, Backend)
-    push!(μ, KernelAbstractions.zeros(Backend, Float32, length(θ)))
-    push!(ν, KernelAbstractions.zeros(Backend, Float32, length(θ)))
-end
-
-function reset!(opt::Adam)
-    fill!.(opt.μ, 0f0)
-    fill!.(opt.ν, 0f0)
-    opt.current_step = 0x0
-end
-
-"""
-    step!(opt::Adam, θ, ∇; dispose::Bool = true)
-
-Apply update rule to parameters `θ` with gradients `∇`.
-
-# Arguments:
-
-- `dispose::Bool`: Free memory taken by gradients `∇` after update.
-"""
-function step!(opt::Adam, θ, ∇; dispose::Bool)
-    length(θ) == length(∇) || error(
-        "Number of parameters must be the same as the number of gradients, " *
-        "but is instead `$(length(θ))` vs `$(length(∇))`.")
-
-    opt.current_step += 0x1
-    _step!(opt, θ, ∇, 1; dispose)
-    return
-end
-
-function _step!(opt::Adam, θ::T, ∇::G, i; dispose::Bool) where {
-    T <: Union{Tuple, NamedTuple}, G <: Union{Tuple, NamedTuple},
-}
-    for (θᵢ, ∇ᵢ) in zip(θ, ∇)
-        i = _step!(opt, θᵢ, ∇ᵢ, i; dispose)
-    end
-    i
-end
-
-function _step!(opt::Adam, θ::T, ∇::T, i; dispose::Bool) where T <: AbstractArray
-    # @assert !any(isnan.(θ)) "NaN parameters of size $(size(θ))"
-    # @assert !any(isnan.(∇)) "NaN parameters of size $(size(∇))"
-    size(θ) == size(∇) || error(
-        "Shape of parameters and gradients must be the same, " *
-        "but is instead `$(size(θ))` vs `$(size(∇))`.")
-    adam_step_kernel!(get_backend(opt))(
-        opt.μ[i], opt.ν[i], θ, ∇, Float32(opt.current_step),
-        opt.lr, opt.β1, opt.β2, opt.ϵ; ndrange=length(θ))
-
-    dispose && sync_free!(get_backend(opt), ∇)
-
-    i + 1
-end
-
-@kernel function adam_step_kernel!(
-    μ, ν, Θ, @Const(∇), current_step::Float32, lr::Float32,
-    β1::Float32, β2::Float32, ϵ::Float32,
-)
-    i = @index(Global)
-    ∇ᵢ = ∇[i]
-    ωᵢ = Θ[i]
-
-    ∇ᵢ² = ∇ᵢ * ∇ᵢ
-    μᵢ = μ[i] = β1 * μ[i] + (1f0 - β1) * ∇ᵢ
-    νᵢ = ν[i] = β2 * ν[i] + (1f0 - β2) * ∇ᵢ²
-
-    # Debiasing.
-    lr *= √(1f0 - β2^current_step) / (1f0 - β1^current_step)
-    Θ[i] = ωᵢ - (lr * μᵢ) / (√νᵢ + ϵ)
-end
-
diff --git a/src/nn/nn.jl b/src/nn/nn.jl
deleted file mode 100644
index 172c1cc..0000000
--- a/src/nn/nn.jl
+++ /dev/null
@@ -1,90 +0,0 @@
-include("common.jl")
-include("adam.jl")
-
-struct Dense{T, F}
-    activation::F
-    in_channels::Int64
-    out_channels::Int64
-
-    function Dense{T}(
-        mapping::Pair{Int64, Int64}, activation::F = identity,
-    ) where {T <: Union{Float16, Float32}, F}
-        new{T, F}(activation, first(mapping), last(mapping))
-    end
-
-    function Dense(
-        mapping::Pair{Int64, Int64}, activation::F = identity,
-    ) where F
-        new{Float32, F}(activation, first(mapping), last(mapping))
-    end
-end
-
-function (d::Dense{T, typeof(identity)})(x, θ) where T
-    θ * x
-end
-
-function (d::Dense{T, F})(x, θ) where {T, F}
-    d.activation.(θ * x)
-end
-
-get_precision(::Dense{T, F}) where {T, F} = T
-
-function init(d::Dense)
-    glorot_uniform((d.out_channels, d.in_channels), get_precision(d))
-end
-
-function reset!(::Dense, θ)
-    copy!(θ, glorot_uniform(size(θ), eltype(θ)))
-end
-
-struct Chain{L}
-    layers::L
-    Chain(layers...) = new{typeof(layers)}(layers)
-end
-
-# TODO make inferrable at length > 3
-# TODO init on device
-function init(c::Chain)
-    recursive_init((), first(c.layers), Base.tail(c.layers))
-end
-
-function recursive_init(θ, l, c::Tuple)
-    recursive_init((θ..., init(l)), first(c), Base.tail(c))
-end
-function recursive_init(θ, l, ::Tuple{})
-    (θ..., init(l))
-end
-
-function reset!(c::Chain, θ)
-    foreach(l -> reset!(l[1], l[2]), zip(c.layers, θ))
-end
-
-function (c::Chain)(x, θ)
-    recursive_apply(
-        x, first(c.layers), Base.tail(c.layers), first(θ), Base.tail(θ))
-end
-
-function recursive_apply(x, l, c::Tuple, θₗ, θ)
-    recursive_apply(l(x, θₗ), first(c), Base.tail(c), first(θ), Base.tail(θ))
-end
-function recursive_apply(x, l, ::Tuple{}, θₗ, ::Tuple{})
-    l(x, θₗ)
-end
-
-function glorot_uniform(dims, ::Type{T} = Float32; gain = one(T)) where T
-    scale::T = gain * √(24f0 / sum(dims))
-    (rand(T, dims) .- T(0.5f0)) .* scale
-end
-
-function relu(x::T) where T
-    ifelse(x < zero(T), zero(T), x)
-end
-
-function softplus(x::T) where T
-    log1p(exp(-abs(x))) + relu(x)
-end
-
-function sigmoid(x::T) where T
-    t = exp(-abs(x))
-    ifelse(x ≥ zero(T), inv(one(T) + t), t / (one(T) + t))
-end
diff --git a/src/renderer/renderer.jl b/src/renderer/renderer.jl
index 7647559..898c710 100644
--- a/src/renderer/renderer.jl
+++ b/src/renderer/renderer.jl
@@ -128,7 +128,7 @@ function render_tile!(
             offset, r.mode; ndrange=n_hit)
         accumulate!(r.buffer; offset, tile_size)
     end
-    sync_free!(get_backend(r.buffer), bundle)
+    unsafe_free!(bundle)

     return nothing
 end
@@ -192,7 +192,7 @@ function trace(
             camera_origin, camera_forward, r.bbox, min_transmittance,
             UInt32(n_steps); ndrange=n_alive)
-        sync_free!(Backend, samples, span)
+        unsafe_free!.((samples, span))
     end
     n_hit = Int(Array(bundle.hit_counter)[1])
     n_hit, bundle
diff --git a/src/trainer.jl b/src/trainer.jl
index 6786c37..fc260cc 100644
--- a/src/trainer.jl
+++ b/src/trainer.jl
@@ -73,7 +73,7 @@ function step!(t::Trainer)
         bundle, samples, images=t.dataset.images,
         n_rays=t.n_rays, rng_state=t.rng_state)
-    sync_free!(get_backend(t.model), bundle, samples)
+    unsafe_free!.((bundle, samples))

     t.rng_state = advance(t.rng_state)
     t.step += 1
diff --git a/test/Project.toml b/test/Project.toml
index db0ab6f..ce11dfa 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,8 +2,12 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[extras]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
diff --git a/test/grid_encoding.jl b/test/grid_encoding.jl
index 17a8ffa..751ef3d 100644
--- a/test/grid_encoding.jl
+++ b/test/grid_encoding.jl
@@ -26,45 +26,43 @@ end
 end

 @testset "Deterministic result" begin
-    ge = Nerf.GridEncoding(Backend)
-    θ = Nerf.init(ge)
-    fill!(θ, 1f0)
+    ge = Nerf.GridEncoding() |> Flux.gpu
+    fill!(ge.θ, 1f0)

     x = KernelAbstractions.ones(Backend, Float32, (3, 16))
-    y = ge(x, θ)
+    y = ge(x)
     @test sum(y) == 512f0
 end

 @testset "Hashgrid gradients" begin
-    ge = Nerf.GridEncoding(Backend)
-    θ = Nerf.init(ge)
+    ge = Nerf.GridEncoding() |> Flux.gpu

     n = 16
-    x = adapt(Backend, rand(Float32, (3, n)))
+    x = rand(Float32, (3, n)) |> Flux.gpu

-    ∇ = Zygote.gradient(θ) do θ
-        sum(ge(x, θ))
+    ∇ = Zygote.gradient(ge) do ge
+        sum(ge(x))
     end
-    @test size(∇[1]) == size(θ)
+    @test size(∇[1].θ) == size(ge.θ)

-    ∇ = Zygote.gradient(θ) do θ
-        sum(ge(x, θ, Val{:IG}()))
+    ∇ = Zygote.gradient(ge) do ge
+        sum(ge(x, Val{:IG}()))
     end
-    @test size(∇[1]) == size(θ)
+    @test size(∇[1].θ) == size(ge.θ)

     ∇ = Zygote.gradient(x) do x
-        sum(ge(x, θ, Val{:IG}()))
+        sum(ge(x, Val{:IG}()))
     end
     @test size(∇[1]) == (3, n)

-    ∇ = Zygote.gradient(θ, x) do θ, x
-        sum(ge(x, θ, Val{:IG}()))
+    ∇ = Zygote.gradient(ge, x) do ge, x
+        sum(ge(x, Val{:IG}()))
     end
-    @test size(∇[1]) == size(θ)
+    @test size(∇[1].θ) == size(ge.θ)
     @test size(∇[2]) == (3, n)

-    y, back = Zygote.pullback(x) do xi
-        ge(xi, θ, Val{:IG}())
+    y, back = Zygote.pullback(x) do x
+        ge(x, Val{:IG}())
     end
     Δ = KernelAbstractions.ones(Backend, Float32, size(y))
     ∇ = back(Δ)[1]
diff --git a/test/nn.jl b/test/nn.jl
index cd2878a..8cedf1b 100644
--- a/test/nn.jl
+++ b/test/nn.jl
@@ -1,32 +1,3 @@
-@testset "Chain type-stability" begin
-    for T in (Float32, Float16)
-        @inferred Nerf.Chain(
-            Nerf.Dense{T}(3=>64, Nerf.relu),
-            Nerf.Dense{T}(64=>64, Nerf.relu),
-            Nerf.Dense{T}(64=>3))
-        c = Nerf.Chain(
-            Nerf.Dense{T}(3=>64, Nerf.relu),
-            Nerf.Dense{T}(64=>64, Nerf.relu),
-            Nerf.Dense{T}(64=>3))
-
-        # TODO make inferrable at > 3 length
-        @inferred Nerf.init(c)
-        θ = map(l -> adapt(Backend, l), Nerf.init(c))
-
-        n = 16
-        x = adapt(Backend, rand(T, (3, n)))
-        @inferred c(x, θ)
-        y = c(x, θ)
-        @test eltype(y) == T
-
-        ∇ = Zygote.gradient(θ -> sum(c(x, θ)), θ)
-        # TODO look into it?
-        # @inferred Zygote.gradient(θ -> sum(c(x, θ)), θ)
-
-        @test eltype(∇[1][1]) == T
-    end
-end
-
 @testset "Test normalization works correctly" begin
     n = 2
     host_directions = zeros(Float32, 3, n)
diff --git a/test/runtests.jl b/test/runtests.jl
index 089936b..9551eca 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,10 +6,12 @@ using Nerf
 using StaticArrays
 using Test
 using Zygote
+using Flux
 using KernelAbstractions

 const Backend = Nerf.Backend
-@info "[Nerf.jl] Testing on backend: $Backend"
+Nerf.AMDGPU.versioninfo()
+@info "[Nerf.jl] Testing on backend: $Backend, Flux GPU: $(Flux.GPU_BACKEND)"

 const DEFAULT_CONFIG_FILE::String = joinpath(
     pkgdir(Nerf), "data", "raccoon_sofa2", "transforms.json")
@@ -36,7 +38,7 @@ const DEFAULT_CONFIG_FILE::String = joinpath(
     @testset "Sampler" begin
         include("sampler.jl")
     end
-    @testset "Renderer" begin
-        include("renderer.jl")
-    end
+    # @testset "Renderer" begin
+    #     include("renderer.jl")
+    # end
 end
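
For reference, the optimiser handling that `step!` switches to in this diff is Flux's explicit-gradient training API (`Flux.setup`, `withgradient`, `update!`). Below is a minimal, self-contained sketch of that pattern on a toy CPU model; the `Chain`, the random data, and the `mse` loss are illustrative stand-ins rather than Nerf's `BasicModel` and `photometric_loss`, and the snippet assumes Flux >= 0.13.6.

using Flux, Zygote

# Toy stand-in model: 3 inputs -> 4 outputs, echoing the Dense/Chain layers used in BasicField.
model = Chain(Dense(3 => 64, relu), Dense(64 => 4, sigmoid))
opt_state = Flux.setup(Flux.Adam(1f-2), model)  # optimiser state lives outside the model

x = rand(Float32, 3, 512)  # stand-in for encoded points
y = rand(Float32, 4, 512)  # stand-in for target RGBA

loss, grads = Zygote.withgradient(model) do m
    Flux.Losses.mse(m(x), y)
end
Flux.update!(opt_state, model, grads[1])  # in-place parameter update
@show loss

Moving the model to the GPU first (`model |> Flux.gpu`, with GPU arrays for `x` and `y`) gives the same flow as the `step!` implementation above.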