66import torch .nn as nn
77from common_testing import TestCaseMixin , get_random_cuda_device
88from pytorch3d .renderer import (
9+ AlphaCompositor ,
910 BlendParams ,
1011 HardGouraudShader ,
1112 Materials ,
1213 MeshRasterizer ,
1314 MeshRenderer ,
1415 PointLights ,
16+ PointsRasterizationSettings ,
17+ PointsRasterizer ,
18+ PointsRenderer ,
1519 RasterizationSettings ,
1620 SoftPhongShader ,
1721 TexturesVertex ,
1822)
1923from pytorch3d .renderer .cameras import FoVPerspectiveCameras , look_at_view_transform
20- from pytorch3d .structures . meshes import Meshes
24+ from pytorch3d .structures import Meshes , Pointclouds
2125from pytorch3d .utils .ico_sphere import ico_sphere
2226
2327
2731print ("GPUs: %s" % ", " .join (GPU_LIST ))
2832
2933
30- class TestRenderMultiGPU (TestCaseMixin , unittest .TestCase ):
34+ class TestRenderMeshesMultiGPU (TestCaseMixin , unittest .TestCase ):
3135 def _check_mesh_renderer_props_on_device (self , renderer , device ):
3236 """
3337 Helper function to check that all the properties of the mesh
@@ -99,7 +103,7 @@ def test_mesh_renderer_to(self):
99103 # This also tests that background_color is correctly moved to
100104 # the new device
101105 device2 = torch .device ("cuda:0" )
102- renderer .to (device2 )
106+ renderer = renderer .to (device2 )
103107 mesh = mesh .to (device2 )
104108 self ._check_mesh_renderer_props_on_device (renderer , device2 )
105109 output_images = renderer (mesh )
@@ -137,7 +141,7 @@ def init_render(self):
137141
138142 def forward (self , verts , texs ):
139143 batch_size = verts .size (0 )
140- self .renderer .to (verts .device )
144+ self .renderer = self . renderer .to (verts .device )
141145 tex = TexturesVertex (verts_features = texs )
142146 faces = self .faces .expand (batch_size , - 1 , - 1 ).to (verts .device )
143147 mesh = Meshes (verts , faces , tex ).to (verts .device )
@@ -157,3 +161,53 @@ def forward(self, verts, texs):
157161 # Test a few iterations
158162 for _ in range (100 ):
159163 model (verts , texs )
164+
165+
166+ class TestRenderPointssMultiGPU (TestCaseMixin , unittest .TestCase ):
167+ def _check_points_renderer_props_on_device (self , renderer , device ):
168+ """
169+ Helper function to check that all the properties have
170+ been moved to the correct device.
171+ """
172+ # Cameras
173+ self .assertEqual (renderer .rasterizer .cameras .device , device )
174+ self .assertEqual (renderer .rasterizer .cameras .R .device , device )
175+ self .assertEqual (renderer .rasterizer .cameras .T .device , device )
176+
177+ def test_points_renderer_to (self ):
178+ """
179+ Test moving all the tensors in the points renderer to a new device.
180+ """
181+
182+ device1 = torch .device ("cpu" )
183+
184+ R , T = look_at_view_transform (1500 , 0.0 , 0.0 )
185+
186+ raster_settings = PointsRasterizationSettings (
187+ image_size = 256 , radius = 0.001 , points_per_pixel = 1
188+ )
189+ cameras = FoVPerspectiveCameras (
190+ device = device1 , R = R , T = T , aspect_ratio = 1.0 , fov = 60.0 , zfar = 100
191+ )
192+ rasterizer = PointsRasterizer (cameras = cameras , raster_settings = raster_settings )
193+
194+ renderer = PointsRenderer (rasterizer = rasterizer , compositor = AlphaCompositor ())
195+
196+ mesh = ico_sphere (2 , device1 )
197+ verts_padded = mesh .verts_padded ()
198+ pointclouds = Pointclouds (
199+ points = verts_padded , features = torch .randn_like (verts_padded )
200+ )
201+ self ._check_points_renderer_props_on_device (renderer , device1 )
202+
203+ # Test rendering on cpu
204+ output_images = renderer (pointclouds )
205+ self .assertEqual (output_images .device , device1 )
206+
207+ # Move renderer and pointclouds to another device and re render
208+ device2 = torch .device ("cuda:0" )
209+ renderer = renderer .to (device2 )
210+ pointclouds = pointclouds .to (device2 )
211+ self ._check_points_renderer_props_on_device (renderer , device2 )
212+ output_images = renderer (pointclouds )
213+ self .assertEqual (output_images .device , device2 )
0 commit comments