99
1010import torch
1111from pytorch3d import _C
12+ from pytorch3d .common .datatypes import Device
1213
1314
1415# Example functions for blending the top K colors per pixel using the outputs
@@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
3738 background_color : Union [torch .Tensor , Sequence [float ]] = (1.0 , 1.0 , 1.0 )
3839
3940
41+ def _get_background_color (
42+ blend_params : BlendParams , device : Device , dtype = torch .float32
43+ ) -> torch .Tensor :
44+ background_color_ = blend_params .background_color
45+ if isinstance (background_color_ , torch .Tensor ):
46+ background_color = background_color_ .to (device )
47+ else :
48+ background_color = torch .tensor (background_color_ , dtype = dtype , device = device )
49+ return background_color
50+
51+
4052def hard_rgb_blend (
4153 colors : torch .Tensor , fragments , blend_params : BlendParams
4254) -> torch .Tensor :
@@ -57,18 +69,11 @@ def hard_rgb_blend(
5769 Returns:
5870 RGBA pixel_colors: (N, H, W, 4)
5971 """
60- N , H , W , K = fragments .pix_to_face .shape
61- device = fragments .pix_to_face .device
72+ background_color = _get_background_color (blend_params , fragments .pix_to_face .device )
6273
6374 # Mask for the background.
6475 is_background = fragments .pix_to_face [..., 0 ] < 0 # (N, H, W)
6576
66- background_color_ = blend_params .background_color
67- if isinstance (background_color_ , torch .Tensor ):
68- background_color = background_color_ .to (device )
69- else :
70- background_color = colors .new_tensor (background_color_ )
71-
7277 # Find out how much background_color needs to be expanded to be used for masked_scatter.
7378 num_background_pixels = is_background .sum ()
7479
@@ -182,13 +187,8 @@ def softmax_rgb_blend(
182187 """
183188
184189 N , H , W , K = fragments .pix_to_face .shape
185- device = fragments .pix_to_face .device
186190 pixel_colors = torch .ones ((N , H , W , 4 ), dtype = colors .dtype , device = colors .device )
187- background_ = blend_params .background_color
188- if not isinstance (background_ , torch .Tensor ):
189- background = torch .tensor (background_ , dtype = torch .float32 , device = device )
190- else :
191- background = background_ .to (device )
191+ background_color = _get_background_color (blend_params , fragments .pix_to_face .device )
192192
193193 # Weight for background color
194194 eps = 1e-10
@@ -233,7 +233,7 @@ def softmax_rgb_blend(
233233
234234 # Sum: weights * textures + background color
235235 weighted_colors = (weights_num [..., None ] * colors ).sum (dim = - 2 )
236- weighted_background = delta * background
236+ weighted_background = delta * background_color
237237 pixel_colors [..., :3 ] = (weighted_colors + weighted_background ) / denom
238238 pixel_colors [..., 3 ] = 1.0 - alpha
239239
0 commit comments