Skip to content

Commit f00ef66

Browse files
bottlerfacebook-github-bot
authored andcommitted
NeRF training: avoid caching unused visualization data.
Summary: If we are not visualizing the training with visdom, then there are a couple of outputs of the coarse rendering step which are not small and are returned by the renderer but never used. We don't need to bother transferring them to the CPU. Reviewed By: nikhilaravi Differential Revision: D28939958 fbshipit-source-id: 7e0d6681d6524f7fb57b6b20164580006120de80
1 parent 7204a4c commit f00ef66

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

projects/nerf/nerf/nerf_renderer.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
n_layers_xyz: int = 8,
6565
append_xyz: Tuple[int] = (5,),
6666
density_noise_std: float = 0.0,
67+
visualization: bool = False,
6768
):
6869
"""
6970
Args:
@@ -102,6 +103,7 @@ def __init__(
102103
density_noise_std: The standard deviation of the random normal noise
103104
added to the output of the occupancy MLP.
104105
Active only when `self.training==True`.
106+
visualization: whether to store extra output for visualization.
105107
"""
106108

107109
super().__init__()
@@ -159,6 +161,7 @@ def __init__(
159161
self._density_noise_std = density_noise_std
160162
self._chunk_size_test = chunk_size_test
161163
self._image_size = image_size
164+
self.visualization = visualization
162165

163166
def precache_rays(
164167
self,
@@ -248,16 +251,15 @@ def _process_ray_chunk(
248251
else:
249252
raise ValueError(f"No such rendering pass {renderer_pass}")
250253

251-
return {
252-
"rgb_fine": rgb_fine,
253-
"rgb_coarse": rgb_coarse,
254-
"rgb_gt": rgb_gt,
254+
out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
255+
if self.visualization:
255256
# Store the coarse rays/weights only for visualization purposes.
256-
"coarse_ray_bundle": type(coarse_ray_bundle)(
257+
out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
257258
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
258-
),
259-
"coarse_weights": coarse_weights.detach().cpu(),
260-
}
259+
)
260+
out["coarse_weights"] = coarse_weights.detach().cpu()
261+
262+
return out
261263

262264
def forward(
263265
self,

projects/nerf/train_nerf.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def main(cfg: DictConfig):
5252
n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
5353
n_layers_xyz=cfg.implicit_function.n_layers_xyz,
5454
density_noise_std=cfg.implicit_function.density_noise_std,
55+
visualization=cfg.visualization.visdom,
5556
)
5657

5758
# Move the model to the relevant device.
@@ -195,17 +196,18 @@ def lr_lambda(epoch):
195196
stats.print(stat_set="train")
196197

197198
# Update the visualization cache.
198-
visuals_cache.append(
199-
{
200-
"camera": camera.cpu(),
201-
"camera_idx": camera_idx,
202-
"image": image.cpu().detach(),
203-
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
204-
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
205-
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
206-
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
207-
}
208-
)
199+
if viz is not None:
200+
visuals_cache.append(
201+
{
202+
"camera": camera.cpu(),
203+
"camera_idx": camera_idx,
204+
"image": image.cpu().detach(),
205+
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
206+
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
207+
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
208+
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
209+
}
210+
)
209211

210212
# Adjust the learning rate.
211213
lr_scheduler.step()

0 commit comments

Comments
 (0)