@@ -596,6 +596,7 @@ def __init__(
596596 verts_uvs : Union [torch .Tensor , List [torch .Tensor ], Tuple [torch .Tensor ]],
597597 padding_mode : str = "border" ,
598598 align_corners : bool = True ,
599+ sampling_mode : str = "bilinear" ,
599600 ) -> None :
600601 """
601602 Textures are represented as a per mesh texture map and uv coordinates for each
@@ -613,6 +614,9 @@ def __init__(
613614 indicate the centers of the edge pixels in the maps.
614615 padding_mode: padding mode for outside grid values
615616 ("zeros", "border" or "reflection").
617+ sampling_mode: type of interpolation used to sample the texture.
618+ Corresponds to the mode parameter in PyTorch's
619+ grid_sample ("nearest" or "bilinear").
616620
617621 The align_corners and padding_mode arguments correspond to the arguments
618622 of the `grid_sample` torch function. There is an informative illustration of
@@ -641,6 +645,7 @@ def __init__(
641645 """
642646 self .padding_mode = padding_mode
643647 self .align_corners = align_corners
648+ self .sampling_mode = sampling_mode
644649 if isinstance (faces_uvs , (list , tuple )):
645650 for fv in faces_uvs :
646651 if fv .ndim != 2 or fv .shape [- 1 ] != 3 :
@@ -749,6 +754,9 @@ def clone(self) -> "TexturesUV":
749754 self .maps_padded ().clone (),
750755 self .faces_uvs_padded ().clone (),
751756 self .verts_uvs_padded ().clone (),
757+ align_corners = self .align_corners ,
758+ padding_mode = self .padding_mode ,
759+ sampling_mode = self .sampling_mode ,
752760 )
753761 if self ._maps_list is not None :
754762 tex ._maps_list = [m .clone () for m in self ._maps_list ]
@@ -770,6 +778,9 @@ def detach(self) -> "TexturesUV":
770778 self .maps_padded ().detach (),
771779 self .faces_uvs_padded ().detach (),
772780 self .verts_uvs_padded ().detach (),
781+ align_corners = self .align_corners ,
782+ padding_mode = self .padding_mode ,
783+ sampling_mode = self .sampling_mode ,
773784 )
774785 if self ._maps_list is not None :
775786 tex ._maps_list = [m .detach () for m in self ._maps_list ]
@@ -801,6 +812,7 @@ def __getitem__(self, index) -> "TexturesUV":
801812 maps = maps ,
802813 padding_mode = self .padding_mode ,
803814 align_corners = self .align_corners ,
815+ sampling_mode = self .sampling_mode ,
804816 )
805817 elif all (torch .is_tensor (f ) for f in [faces_uvs , verts_uvs , maps ]):
806818 new_tex = self .__class__ (
@@ -809,6 +821,7 @@ def __getitem__(self, index) -> "TexturesUV":
809821 maps = [maps ],
810822 padding_mode = self .padding_mode ,
811823 align_corners = self .align_corners ,
824+ sampling_mode = self .sampling_mode ,
812825 )
813826 else :
814827 raise ValueError ("Not all values are provided in the correct format" )
@@ -889,6 +902,7 @@ def extend(self, N: int) -> "TexturesUV":
889902 verts_uvs = new_props ["verts_uvs_padded" ],
890903 padding_mode = self .padding_mode ,
891904 align_corners = self .align_corners ,
905+ sampling_mode = self .sampling_mode ,
892906 )
893907
894908 new_tex ._num_faces_per_mesh = new_props ["_num_faces_per_mesh" ]
@@ -966,6 +980,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
966980 texels = F .grid_sample (
967981 texture_maps ,
968982 pixel_uvs ,
983+ mode = self .sampling_mode ,
969984 align_corners = self .align_corners ,
970985 padding_mode = self .padding_mode ,
971986 )
@@ -1003,6 +1018,7 @@ def faces_verts_textures_packed(self) -> torch.Tensor:
10031018 textures = F .grid_sample (
10041019 texture_maps ,
10051020 faces_verts_uvs ,
1021+ mode = self .sampling_mode ,
10061022 align_corners = self .align_corners ,
10071023 padding_mode = self .padding_mode ,
10081024 ) # NxCxmax(Fi)x3
@@ -1060,6 +1076,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
10601076 faces_uvs = faces_uvs_list ,
10611077 padding_mode = self .padding_mode ,
10621078 align_corners = self .align_corners ,
1079+ sampling_mode = self .sampling_mode ,
10631080 )
10641081 new_tex ._num_faces_per_mesh = num_faces_per_mesh
10651082 return new_tex
@@ -1227,6 +1244,7 @@ def join_scene(self) -> "TexturesUV":
12271244 faces_uvs = [torch .cat (faces_uvs_merged )],
12281245 align_corners = self .align_corners ,
12291246 padding_mode = self .padding_mode ,
1247+ sampling_mode = self .sampling_mode ,
12301248 )
12311249
12321250 def centers_for_image (self , index : int ) -> torch .Tensor :
@@ -1259,6 +1277,7 @@ def centers_for_image(self, index: int) -> torch.Tensor:
12591277 torch .flip (coords .to (texture_image ), [2 ]),
12601278 # Convert from [0, 1] -> [-1, 1] range expected by grid sample
12611279 verts_uvs [:, None ] * 2.0 - 1 ,
1280+ mode = self .sampling_mode ,
12621281 align_corners = self .align_corners ,
12631282 padding_mode = self .padding_mode ,
12641283 ).cpu ()
0 commit comments