from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple

import torch

from ._mc_table import MC_TABLE
from .torch_mesh import TorchMesh


def marching_cubes(
    field: torch.Tensor,
    min_point: torch.Tensor,
    size: torch.Tensor,
) -> TorchMesh:
    """
    For a signed distance field, produce a mesh using marching cubes.

    :param field: a 3D tensor of field values, where negative values correspond
                  to the outside of the shape. The dimensions correspond to the
                  x, y, and z directions, respectively.
    :param min_point: a tensor of shape [3] containing the point corresponding
                      to (0, 0, 0) in the field.
    :param size: a tensor of shape [3] containing the per-axis distance from the
                 (0, 0, 0) field corner and the (-1, -1, -1) field corner.
    :return: a TorchMesh built from the interpolated vertex positions (`verts`)
             and the triangles indexing into them (`faces`).
    """
    assert len(field.shape) == 3, "input must be a 3D scalar field"
    dev = field.device

    grid_size = field.shape
    grid_size_tensor = torch.tensor(grid_size).to(size)
    lut = _lookup_table(dev)

    # Create bitmasks between 0 and 255 (inclusive) indicating the state
    # of the eight corners of each cube. Bit 0 is the x neighbour, bit 1 the
    # y neighbour and bit 2 the z neighbour, so corner c sits at offset
    # (c & 1, (c >> 1) & 1, (c >> 2) & 1) from the cube's minimum corner.
    bitmasks = (field > 0).to(torch.uint8)
    bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
    bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
    bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)

    # Compute corner coordinates (in grid units) across the entire grid.
    corner_coords = _index_grid(grid_size, dev, field.dtype)

    # Compute all vertices across all edges in the grid, even though we will
    # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
    # These are all midpoints, and don't account for interpolation (which is
    # done later based on the used edge midpoints).
    edge_midpoints = torch.cat(
        [
            ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
        ],
        dim=0,
    )

    # Create a flat array of [X, Y, Z] indices for each cube.
    cube_shape = (grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1)
    flat_cube_indices = _index_grid(cube_shape, dev, torch.long).reshape(-1, 3)

    # Create a flat array mapping each cube to 12 global edge indices.
    edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

    # Apply the LUT to figure out the triangles.
    # The bitmasks must be cast to long, otherwise indexing with a uint8
    # tensor is interpreted as a mask rather than as indices.
    flat_bitmasks = bitmasks.reshape(-1).long()
    local_tris = lut.cases[flat_bitmasks]
    local_masks = lut.masks[flat_bitmasks]

    # Compute the global edge indices for the triangles.
    global_tris = torch.gather(
        edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)
    ).reshape(local_tris.shape)

    # Select the used triangles for each cube.
    selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

    # Now we have a bunch of indices into the full list of possible vertices,
    # but we want to reduce this list to only the used vertices. The inverse
    # mapping returned by torch.unique gives, for every triangle corner, its
    # position in the sorted list of used vertices, i.e. the rewritten
    # triangle indices, without a lookup table sized to every edge in the grid.
    used_vertex_indices, selected_tris = torch.unique(selected_tris, return_inverse=True)
    used_edge_midpoints = edge_midpoints[used_vertex_indices]

    # Compute the actual interpolated coordinates corresponding to edge midpoints.
    # Each midpoint has exactly one half-integer coordinate, so floor/ceil
    # recover the two grid corners at the ends of the edge.
    v1 = torch.floor(used_edge_midpoints).to(torch.long)
    v2 = torch.ceil(used_edge_midpoints).to(torch.long)
    s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
    s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
    p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
    p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point

    # The signs of s1 and s2 should be different. We want to find
    # t such that t*s2 + (1-t)*s1 = 0.
    t = (s1 / (s1 - s2))[:, None]
    verts = t * p2 + (1 - t) * p1

    return TorchMesh(verts=verts, faces=selected_tris)


def _index_grid(
    shape: Tuple[int, int, int], device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """
    Build a [shape[0] x shape[1] x shape[2] x 3] tensor whose entry at
    (i, j, k) is the coordinate triple (i, j, k).
    """
    grid = torch.empty(shape[0], shape[1], shape[2], 3, device=device, dtype=dtype)
    # Every element is overwritten below; each assignment broadcasts a 1D
    # range along one axis of the grid.
    grid[..., 0] = torch.arange(shape[0], device=device, dtype=dtype)[:, None, None]
    grid[..., 1] = torch.arange(shape[1], device=device, dtype=dtype)[:, None]
    grid[..., 2] = torch.arange(shape[2], device=device, dtype=dtype)
    return grid


def _create_flat_edge_indices(
    flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Map every cube to the global indices of its 12 edges.

    Global edge indices follow the layout of the concatenated edge list used
    by marching_cubes: first all x-spanning edges (flattened over a grid of
    shape [X-1, Y, Z]), then all y-spanning edges ([X, Y-1, Z]), then all
    z-spanning edges ([X, Y, Z-1]).

    :param flat_cube_indices: an [N x 3] long tensor of (x, y, z) cube indices.
    :param grid_size: the (X, Y, Z) number of corners along each axis.
    :return: an [N x 12] tensor; column e holds the global index of local
             edge e, using the edge numbering shown on McLookupTable.
    """
    nx, ny, nz = grid_size[0], grid_size[1], grid_size[2]
    x = flat_cube_indices[:, 0]
    y = flat_cube_indices[:, 1]
    z = flat_cube_indices[:, 2]

    # The y-spanning edges start after all x-spanning edges, and the
    # z-spanning edges start after all x- and y-spanning edges.
    y_offset = (nx - 1) * ny * nz
    z_offset = y_offset + nx * (ny - 1) * nz

    def x_edge(yy: torch.Tensor, zz: torch.Tensor) -> torch.Tensor:
        # Edge from (x, yy, zz) to (x + 1, yy, zz).
        return x * (ny * nz) + yy * nz + zz

    def y_edge(xx: torch.Tensor, zz: torch.Tensor) -> torch.Tensor:
        # Edge from (xx, y, zz) to (xx, y + 1, zz).
        return y_offset + xx * ((ny - 1) * nz) + y * nz + zz

    def z_edge(xx: torch.Tensor, yy: torch.Tensor) -> torch.Tensor:
        # Edge from (xx, yy, z) to (xx, yy, z + 1).
        return z_offset + xx * (ny * (nz - 1)) + yy * (nz - 1) + z

    return torch.stack(
        [
            # Edges spanning x-axis (local edges 0-3).
            x_edge(y, z),
            x_edge(y + 1, z),
            x_edge(y, z + 1),
            x_edge(y + 1, z + 1),
            # Edges spanning y-axis (local edges 4-7).
            y_edge(x, z),
            y_edge(x + 1, z),
            y_edge(x, z + 1),
            y_edge(x + 1, z + 1),
            # Edges spanning z-axis (local edges 8-11).
            z_edge(x, y),
            z_edge(x + 1, y),
            z_edge(x, y + 1),
            z_edge(x + 1, y + 1),
        ],
        dim=-1,
    )


@dataclass
class McLookupTable:
    # Coordinates in triangles are represented as edge indices from 0-11.
    # Here is an MC cell with both corner and edge indices marked.
    #
    #        6 + ---------- 3 ----------+ 7
    #         /|                       /|
    #        6 |                      7 |
    #       /  |                     /  |
    #    4 +--------- 2 ------------+ 5 |
    #      |  10                    |   |
    #      |   |                    |  11
    #      |   |                    |   |
    #      8   | 2                  9   | 3
    #      |   +--------- 1 --------|---+
    #      |  /                     |  /
    #      | 4                      | 5
    #      |/                       |/
    #      +---------- 0 -----------+
    #     0                          1
    cases: torch.Tensor  # [256 x 5 x 3] long tensor
    masks: torch.Tensor  # [256 x 5] bool tensor


@lru_cache(maxsize=9)  # if there's more than 8 GPUs and a CPU, don't bother caching
def _lookup_table(device: torch.device) -> McLookupTable:
    """
    Build (and cache per device) the marching cubes lookup table.

    For each of the 256 corner bitmasks, `cases` holds up to 5 triangles
    expressed as triples of local edge indices, and `masks` marks which of
    those 5 triangle slots are actually used. Unused slots are left as zeros.

    :param device: the device on which to create the table tensors.
    :return: a McLookupTable with tensors on the given device.
    """
    # Each triangle corner in MC_TABLE is given as a pair of cube corners;
    # this maps the (sorted) pair to the index of the edge between them.
    edge_to_index = {
        (0, 1): 0,
        (2, 3): 1,
        (4, 5): 2,
        (6, 7): 3,
        (0, 2): 4,
        (1, 3): 5,
        (4, 6): 6,
        (5, 7): 7,
        (0, 4): 8,
        (1, 5): 9,
        (2, 6): 10,
        (3, 7): 11,
    }

    # Fill plain Python lists and move them to the device in one transfer,
    # rather than issuing one tiny device write per table entry.
    cases = [[[0, 0, 0] for _ in range(5)] for _ in range(256)]
    masks = [[False] * 5 for _ in range(256)]
    for i, case in enumerate(MC_TABLE):
        for j, tri in enumerate(case):
            for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
                cases[i][j][k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
            masks[i][j] = True

    return McLookupTable(
        cases=torch.tensor(cases, device=device, dtype=torch.long),
        masks=torch.tensor(masks, device=device, dtype=torch.bool),
    )