| 
 | 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.  | 
 | 2 | + | 
 | 3 | +from typing import Dict, List, Optional, Tuple  | 
 | 4 | + | 
 | 5 | +import torch  | 
 | 6 | +from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE  | 
 | 7 | +from pytorch3d.transforms import Translate  | 
 | 8 | + | 
 | 9 | + | 
 | 10 | +EPS = 0.00001  | 
 | 11 | + | 
 | 12 | + | 
 | 13 | +class Cube:  | 
 | 14 | +    def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1):  | 
 | 15 | +        """  | 
 | 16 | +        Initializes a cube given the bottom front left vertex coordinate  | 
 | 17 | +        and the cube spacing  | 
 | 18 | +
  | 
 | 19 | +        Edge and vertex convention:  | 
 | 20 | +
  | 
 | 21 | +                    v4_______e4____________v5  | 
 | 22 | +                    /|                    /|  | 
 | 23 | +                   / |                   / |  | 
 | 24 | +                e7/  |                e5/  |  | 
 | 25 | +                 /___|______e6_________/   |  | 
 | 26 | +              v7|    |                 |v6 |e9  | 
 | 27 | +                |    |                 |   |  | 
 | 28 | +                |    |e8               |e10|  | 
 | 29 | +             e11|    |                 |   |  | 
 | 30 | +                |    |_________________|___|  | 
 | 31 | +                |   / v0      e0       |   /v1  | 
 | 32 | +                |  /                   |  /  | 
 | 33 | +                | /e3                  | /e1  | 
 | 34 | +                |/_____________________|/  | 
 | 35 | +                v3         e2          v2  | 
 | 36 | +
  | 
 | 37 | +        Args:  | 
 | 38 | +            bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex  | 
 | 39 | +                of the cube in (x, y, z) format  | 
 | 40 | +            spacing: the length of each edge of the cube  | 
 | 41 | +        """  | 
 | 42 | +        # match corner orders to algorithm convention  | 
 | 43 | +        if len(bfl_vertex) != 3:  | 
 | 44 | +            msg = "The vertex {} is size {} instead of size 3".format(  | 
 | 45 | +                bfl_vertex, len(bfl_vertex)  | 
 | 46 | +            )  | 
 | 47 | +            raise ValueError(msg)  | 
 | 48 | + | 
 | 49 | +        x, y, z = bfl_vertex  | 
 | 50 | +        self.vertices = torch.tensor(  | 
 | 51 | +            [  | 
 | 52 | +                [x, y, z + spacing],  | 
 | 53 | +                [x + spacing, y, z + spacing],  | 
 | 54 | +                [x + spacing, y, z],  | 
 | 55 | +                [x, y, z],  | 
 | 56 | +                [x, y + spacing, z + spacing],  | 
 | 57 | +                [x + spacing, y + spacing, z + spacing],  | 
 | 58 | +                [x + spacing, y + spacing, z],  | 
 | 59 | +                [x, y + spacing, z],  | 
 | 60 | +            ]  | 
 | 61 | +        )  | 
 | 62 | + | 
 | 63 | +    def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:  | 
 | 64 | +        """  | 
 | 65 | +        Calculates the cube_index in the range 0-255 to index  | 
 | 66 | +        into EDGE_TABLE and FACE_TABLE  | 
 | 67 | +        Args:  | 
 | 68 | +            volume_data: the 3D scalar data  | 
 | 69 | +            isolevel: the isosurface value used as a threshold  | 
 | 70 | +                for determining whether a point is inside/outside  | 
 | 71 | +                the volume  | 
 | 72 | +        """  | 
 | 73 | +        cube_index = 0  | 
 | 74 | +        bit = 1  | 
 | 75 | +        for index in range(len(self.vertices)):  | 
 | 76 | +            vertex = self.vertices[index]  | 
 | 77 | +            value = _get_value(vertex, volume_data)  | 
 | 78 | +            if value < isolevel:  | 
 | 79 | +                cube_index |= bit  | 
 | 80 | +            bit *= 2  | 
 | 81 | +        return cube_index  | 
 | 82 | + | 
 | 83 | + | 
 | 84 | +def marching_cubes_naive(  | 
 | 85 | +    volume_data_batch: torch.Tensor,  | 
 | 86 | +    isolevel: Optional[float] = None,  | 
 | 87 | +    spacing: int = 1,  | 
 | 88 | +    return_local_coords: bool = True,  | 
 | 89 | +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:  | 
 | 90 | +    """  | 
 | 91 | +    Runs the classic marching cubes algorithm, iterating over  | 
 | 92 | +    the coordinates of the volume_data and using a given isolevel  | 
 | 93 | +    for determining intersected edges of cubes of size `spacing`.  | 
 | 94 | +    Returns vertices and faces of the obtained mesh.  | 
 | 95 | +    This operation is non-differentiable.  | 
 | 96 | +
  | 
 | 97 | +    This is a naive implementation, and is not optimized for efficiency.  | 
 | 98 | +
  | 
 | 99 | +    Args:  | 
 | 100 | +        volume_data_batch: a Tensor of size (N, D, H, W) corresponding to  | 
 | 101 | +            a batch of 3D scalar fields  | 
 | 102 | +        isolevel: the isosurface value to use as the threshold to determine  | 
 | 103 | +            whether points are within a volume. If None, then the average of the  | 
 | 104 | +            maximum and minimum value of the scalar field will be used.  | 
 | 105 | +        spacing: an integer specifying the cube size to use  | 
 | 106 | +        return_local_coords: bool. If True the output vertices will be in local coordinates in  | 
 | 107 | +        the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range  | 
 | 108 | +        [0, W-1] x [0, H-1] x [0, D-1]  | 
 | 109 | +    Returns:  | 
 | 110 | +        verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices.  | 
 | 111 | +        faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces.  | 
 | 112 | +    """  | 
 | 113 | +    volume_data_batch = volume_data_batch.detach().cpu()  | 
 | 114 | +    batched_verts, batched_faces = [], []  | 
 | 115 | +    D, H, W = volume_data_batch.shape[1:]  | 
 | 116 | +    # pyre-ignore [16]  | 
 | 117 | +    volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]  | 
 | 118 | + | 
 | 119 | +    if return_local_coords:  | 
 | 120 | +        # Convert from local coordinates in the range [-1, 1] range to  | 
 | 121 | +        # world coordinates in the range [0, D-1], [0, H-1], [0, W-1]  | 
 | 122 | +        local_to_world_transform = Translate(  | 
 | 123 | +            x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device  | 
 | 124 | +        ).scale((volume_size_xyz - 1) * spacing * 0.5)  | 
 | 125 | +        # Perform the inverse to go from world to local  | 
 | 126 | +        world_to_local_transform = local_to_world_transform.inverse()  | 
 | 127 | + | 
 | 128 | +    for i in range(len(volume_data_batch)):  | 
 | 129 | +        volume_data = volume_data_batch[i]  | 
 | 130 | +        curr_isolevel = (  | 
 | 131 | +            ((volume_data.max() + volume_data.min()) / 2).item()  | 
 | 132 | +            if isolevel is None  | 
 | 133 | +            else isolevel  | 
 | 134 | +        )  | 
 | 135 | +        edge_vertices_to_index = {}  | 
 | 136 | +        vertex_coords_to_index = {}  | 
 | 137 | +        verts, faces = [], []  | 
 | 138 | +        # Use length - spacing for the bounds since we are using  | 
 | 139 | +        # cubes of size spacing, with the lowest x,y,z values  | 
 | 140 | +        # (bottom front left)  | 
 | 141 | +        for x in range(0, W - spacing, spacing):  | 
 | 142 | +            for y in range(0, H - spacing, spacing):  | 
 | 143 | +                for z in range(0, D - spacing, spacing):  | 
 | 144 | +                    cube = Cube((x, y, z), spacing)  | 
 | 145 | +                    new_verts, new_faces = polygonise(  | 
 | 146 | +                        cube,  | 
 | 147 | +                        curr_isolevel,  | 
 | 148 | +                        volume_data,  | 
 | 149 | +                        edge_vertices_to_index,  | 
 | 150 | +                        vertex_coords_to_index,  | 
 | 151 | +                    )  | 
 | 152 | +                    verts.extend(new_verts)  | 
 | 153 | +                    faces.extend(new_faces)  | 
 | 154 | +        if len(faces) > 0 and len(verts) > 0:  | 
 | 155 | +            verts = torch.tensor(verts, dtype=torch.float32)  | 
 | 156 | +            # Convert vertices from world to local coords  | 
 | 157 | +            if return_local_coords:  | 
 | 158 | +                verts = world_to_local_transform.transform_points(verts[None, ...])  | 
 | 159 | +                verts = verts.squeeze()  | 
 | 160 | +            batched_verts.append(verts)  | 
 | 161 | +            batched_faces.append(torch.tensor(faces, dtype=torch.int64))  | 
 | 162 | +    return batched_verts, batched_faces  | 
 | 163 | + | 
 | 164 | + | 
 | 165 | +def polygonise(  | 
 | 166 | +    cube: Cube,  | 
 | 167 | +    isolevel: float,  | 
 | 168 | +    volume_data: torch.Tensor,  | 
 | 169 | +    edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],  | 
 | 170 | +    vertex_coords_to_index: Dict[Tuple[float, float, float], int],  | 
 | 171 | +) -> Tuple[list, list]:  | 
 | 172 | +    """  | 
 | 173 | +    Runs the classic marching cubes algorithm for one Cube in the volume.  | 
 | 174 | +    Returns the vertices and faces for the given cube.  | 
 | 175 | +
  | 
 | 176 | +    Args:  | 
 | 177 | +        cube: a Cube indicating the cube being examined for edges that intersect  | 
 | 178 | +            the volume data.  | 
 | 179 | +        isolevel: the isosurface value to use as the threshold to determine  | 
 | 180 | +            whether points are within a volume.  | 
 | 181 | +        volume_data: a Tensor of shape (D, H, W) corresponding to  | 
 | 182 | +            a 3D scalar field  | 
 | 183 | +        edge_vertices_to_index: A dictionary which maps an edge's two coordinates  | 
 | 184 | +            to the index of its interpolated point, if that interpolated point  | 
 | 185 | +            has already been used by a previous point  | 
 | 186 | +        vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding  | 
 | 187 | +            index of that vertex, if that point has already been marked as a vertex.  | 
 | 188 | +    Returns:  | 
 | 189 | +        verts: List of triangle vertices for the given cube in the volume  | 
 | 190 | +        faces: List of triangle faces for the given cube in the volume  | 
 | 191 | +    """  | 
 | 192 | +    num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1  | 
 | 193 | +    verts, faces = [], []  | 
 | 194 | +    cube_index = cube.get_index(volume_data, isolevel)  | 
 | 195 | +    edges = EDGE_TABLE[cube_index]  | 
 | 196 | +    edge_indices = _get_edge_indices(edges)  | 
 | 197 | +    if len(edge_indices) == 0:  | 
 | 198 | +        return [], []  | 
 | 199 | + | 
 | 200 | +    new_verts, edge_index_to_point_index = _calculate_interp_vertices(  | 
 | 201 | +        edge_indices,  | 
 | 202 | +        volume_data,  | 
 | 203 | +        cube,  | 
 | 204 | +        isolevel,  | 
 | 205 | +        edge_vertices_to_index,  | 
 | 206 | +        vertex_coords_to_index,  | 
 | 207 | +        num_existing_verts,  | 
 | 208 | +    )  | 
 | 209 | + | 
 | 210 | +    # Create faces  | 
 | 211 | +    face_triangles = FACE_TABLE[cube_index]  | 
 | 212 | +    for i in range(0, len(face_triangles), 3):  | 
 | 213 | +        tri1 = edge_index_to_point_index[face_triangles[i]]  | 
 | 214 | +        tri2 = edge_index_to_point_index[face_triangles[i + 1]]  | 
 | 215 | +        tri3 = edge_index_to_point_index[face_triangles[i + 2]]  | 
 | 216 | +        if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:  | 
 | 217 | +            faces.append([tri1, tri2, tri3])  | 
 | 218 | + | 
 | 219 | +    verts += new_verts  | 
 | 220 | +    return verts, faces  | 
 | 221 | + | 
 | 222 | + | 
 | 223 | +def _get_edge_indices(edges: int) -> List[int]:  | 
 | 224 | +    """  | 
 | 225 | +    Finds which edge numbers are intersected given the bit representation  | 
 | 226 | +    detailed in marching_cubes_data.EDGE_TABLE.  | 
 | 227 | +
  | 
 | 228 | +    Args:  | 
 | 229 | +        edges: an integer corresponding to the value at cube_index  | 
 | 230 | +            from the EDGE_TABLE in marching_cubes_data.py  | 
 | 231 | +
  | 
 | 232 | +    Returns:  | 
 | 233 | +        edge_indices: A list of edge indices  | 
 | 234 | +    """  | 
 | 235 | +    if edges == 0:  | 
 | 236 | +        return []  | 
 | 237 | + | 
 | 238 | +    edge_indices = []  | 
 | 239 | +    for i in range(12):  | 
 | 240 | +        if edges & (2 ** i):  | 
 | 241 | +            edge_indices.append(i)  | 
 | 242 | +    return edge_indices  | 
 | 243 | + | 
 | 244 | + | 
 | 245 | +def _calculate_interp_vertices(  | 
 | 246 | +    edge_indices: List[int],  | 
 | 247 | +    volume_data: torch.Tensor,  | 
 | 248 | +    cube: Cube,  | 
 | 249 | +    isolevel: float,  | 
 | 250 | +    edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],  | 
 | 251 | +    vertex_coords_to_index: Dict[Tuple[float, float, float], int],  | 
 | 252 | +    num_existing_verts: int,  | 
 | 253 | +) -> Tuple[List, Dict[int, int]]:  | 
 | 254 | +    """  | 
 | 255 | +    Finds the interpolated vertices for the intersected edges, either referencing  | 
 | 256 | +    previous calculations or newly calculating and storing the new interpolated  | 
 | 257 | +    points.  | 
 | 258 | +
  | 
 | 259 | +    Args:  | 
 | 260 | +        edge_indices: the numbers of the edges which are intersected. See the  | 
 | 261 | +            Cube class for more detail on the edge numbering convention.  | 
 | 262 | +        volume_data: a Tensor of size (D, H, W) corresponding to  | 
 | 263 | +            a 3D scalar field  | 
 | 264 | +        cube: a Cube indicating the cube being examined for edges that intersect  | 
 | 265 | +            the volume  | 
 | 266 | +        isolevel: the isosurface value to use as the threshold to determine  | 
 | 267 | +            whether points are within a volume.  | 
 | 268 | +        edge_vertices_to_index: A dictionary which maps an edge's two coordinates  | 
 | 269 | +            to the index of its interpolated point, if that interpolated point  | 
 | 270 | +            has already been used by a previous point  | 
 | 271 | +        vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding  | 
 | 272 | +            index of that vertex, if that point has already been marked as a vertex.  | 
 | 273 | +        num_existing_verts: the number of vertices that have been found in previous  | 
 | 274 | +            calls to polygonise for the given volume_data in the above function, marching_cubes.  | 
 | 275 | +            This is equal to the 1 + the maximum value in edge_vertices_to_index.  | 
 | 276 | +    Returns:  | 
 | 277 | +        interp_points: a list of new interpolated points  | 
 | 278 | +        edge_index_to_point_index: a dictionary mapping an edge number to the index in the  | 
 | 279 | +            marching cubes' vertices list of the interpolated point on that edge. To be precise,  | 
 | 280 | +            it refers to the index within the vertices list after interp_points  | 
 | 281 | +            has been appended to the verts list constructed in the marching_cubes_naive  | 
 | 282 | +            function.  | 
 | 283 | +    """  | 
 | 284 | +    interp_points = []  | 
 | 285 | +    edge_index_to_point_index = {}  | 
 | 286 | +    for edge_index in edge_indices:  | 
 | 287 | +        v1, v2 = EDGE_TO_VERTICES[edge_index]  | 
 | 288 | +        point1, point2 = cube.vertices[v1], cube.vertices[v2]  | 
 | 289 | +        p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())  | 
 | 290 | +        if (p_tuple1, p_tuple2) in edge_vertices_to_index:  | 
 | 291 | +            edge_index_to_point_index[edge_index] = edge_vertices_to_index[  | 
 | 292 | +                (p_tuple1, p_tuple2)  | 
 | 293 | +            ]  | 
 | 294 | +        else:  | 
 | 295 | +            val1, val2 = _get_value(point1, volume_data), _get_value(  | 
 | 296 | +                point2, volume_data  | 
 | 297 | +            )  | 
 | 298 | + | 
 | 299 | +            point = None  | 
 | 300 | +            if abs(isolevel - val1) < EPS:  | 
 | 301 | +                point = point1  | 
 | 302 | + | 
 | 303 | +            if abs(isolevel - val2) < EPS:  | 
 | 304 | +                point = point2  | 
 | 305 | + | 
 | 306 | +            if abs(val1 - val2) < EPS:  | 
 | 307 | +                point = point1  | 
 | 308 | + | 
 | 309 | +            if point is None:  | 
 | 310 | +                mu = (isolevel - val1) / (val2 - val1)  | 
 | 311 | +                x1, y1, z1 = point1  | 
 | 312 | +                x2, y2, z2 = point2  | 
 | 313 | +                x = x1 + mu * (x2 - x1)  | 
 | 314 | +                y = y1 + mu * (y2 - y1)  | 
 | 315 | +                z = z1 + mu * (z2 - z1)  | 
 | 316 | +            else:  | 
 | 317 | +                x, y, z = point  | 
 | 318 | + | 
 | 319 | +            x, y, z = x.item(), y.item(), z.item()  # for dictionary keys  | 
 | 320 | + | 
 | 321 | +            vert_index = None  | 
 | 322 | +            if (x, y, z) in vertex_coords_to_index:  | 
 | 323 | +                vert_index = vertex_coords_to_index[(x, y, z)]  | 
 | 324 | +            else:  | 
 | 325 | +                vert_index = num_existing_verts + len(interp_points)  | 
 | 326 | +                interp_points.append([x, y, z])  | 
 | 327 | +                vertex_coords_to_index[(x, y, z)] = vert_index  | 
 | 328 | + | 
 | 329 | +            edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index  | 
 | 330 | +            edge_index_to_point_index[edge_index] = vert_index  | 
 | 331 | + | 
 | 332 | +    return interp_points, edge_index_to_point_index  | 
 | 333 | + | 
 | 334 | + | 
 | 335 | +def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:  | 
 | 336 | +    """  | 
 | 337 | +    Gets the value at a given coordinate point in the scalar field.  | 
 | 338 | +
  | 
 | 339 | +    Args:  | 
 | 340 | +        point: data of shape (3) corresponding to an xyz coordinate.  | 
 | 341 | +        volume_data: a Tensor of size (D, H, W) corresponding to  | 
 | 342 | +            a 3D scalar field  | 
 | 343 | +    Returns:  | 
 | 344 | +        data: scalar value in the volume at the given point  | 
 | 345 | +    """  | 
 | 346 | +    x, y, z = point  | 
 | 347 | +    return volume_data[z][y][x]  | 
0 commit comments