You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
254 lines
9.8 KiB
254 lines
9.8 KiB
2 years ago
|
from dataclasses import dataclass
|
||
|
from functools import lru_cache
|
||
|
from typing import Tuple
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from ._mc_table import MC_TABLE
|
||
|
from .torch_mesh import TorchMesh
|
||
|
|
||
|
|
||
|
def marching_cubes(
    field: torch.Tensor,
    min_point: torch.Tensor,
    size: torch.Tensor,
) -> TorchMesh:
    """
    For a signed distance field, produce a mesh using marching cubes.

    :param field: a 3D tensor of field values, where negative values correspond
                  to the outside of the shape (values > 0 are treated as inside,
                  zero counts as outside). The dimensions correspond to the
                  x, y, and z directions, respectively.
    :param min_point: a tensor of shape [3] containing the point corresponding
                      to (0, 0, 0) in the field.
    :param size: a tensor of shape [3] containing the per-axis distance from the
                 (0, 0, 0) field corner to the (-1, -1, -1) field corner.
    :return: a TorchMesh whose vertices lie on grid edges where the field
             changes sign (linearly interpolated between the edge's two
             samples, then mapped by min_point and size), and whose faces are
             an [N x 3] tensor of indices into those vertices.
    """
    assert len(field.shape) == 3, "input must be a 3D scalar field"
    dev = field.device

    grid_size = field.shape
    grid_size_tensor = torch.tensor(grid_size).to(size)
    lut = _lookup_table(dev)

    # Grid coordinates are stored as floats so that edge midpoints (k + 0.5)
    # can be represented. float16/bfloat16 cannot hold these exactly once the
    # grid gets moderately large (bfloat16 already fails above index 128),
    # which would make floor/ceil below pick wrong or identical endpoints and
    # produce NaN/inf vertices. Use at least float32 for the coordinate math;
    # float32/float64 fields keep their own dtype as before.
    coord_dtype = torch.promote_types(field.dtype, torch.float32)

    # Create bitmasks between 0 and 255 (inclusive) indicating the state
    # of the eight corners of each cube. The corner at offset (dx, dy, dz)
    # contributes bit dx + 2*dy + 4*dz.
    bitmasks = (field > 0).to(torch.uint8)
    bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
    bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
    bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)

    # (i, j, k) coordinates of every grid corner, stored as floats.
    corner_coords = _index_grid(grid_size, device=dev, dtype=coord_dtype)

    # Compute all vertices across all edges in the grid, even though we will
    # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
    # These are all midpoints, and don't account for interpolation (which is
    # done later based on the used edge midpoints).
    edge_midpoints = torch.cat(
        [
            ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
        ],
        dim=0,
    )

    # Flat array of [X, Y, Z] indices for each cube, in the same row-major
    # order as bitmasks.reshape(-1).
    flat_cube_indices = _index_grid(
        tuple(n - 1 for n in grid_size), device=dev, dtype=torch.long
    ).reshape(-1, 3)

    # Create a flat array mapping each cube to 12 global edge indices.
    edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

    # Apply the LUT to figure out the triangles. Cast to long so indexing
    # treats the uint8 bitmasks as indices rather than as a mask.
    flat_bitmasks = bitmasks.reshape(-1).long()
    local_tris = lut.cases[flat_bitmasks]  # [num_cubes x 5 x 3] local edge ids
    local_masks = lut.masks[flat_bitmasks]  # [num_cubes x 5] used-slot flags

    # Compute the global edge indices for the triangles.
    global_tris = torch.gather(
        edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)
    ).reshape(local_tris.shape)
    # Select the used triangles for each cube.
    selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

    # Now we have a bunch of indices into the full list of possible vertices,
    # but we want to reduce this list to only the used vertices.
    used_vertex_indices = torch.unique(selected_tris.reshape(-1))
    used_edge_midpoints = edge_midpoints[used_vertex_indices]
    old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
    old_index_to_new_index[used_vertex_indices] = torch.arange(
        len(used_vertex_indices), device=dev, dtype=torch.long
    )

    # Rewrite the triangles to use the new (compacted) indices.
    selected_tris = old_index_to_new_index[selected_tris]

    # Compute the actual interpolated coordinates corresponding to edge
    # midpoints. Each midpoint has exactly one half-integer coordinate, so
    # floor/ceil recover the edge's two endpoint corners.
    v1 = torch.floor(used_edge_midpoints).to(torch.long)
    v2 = torch.ceil(used_edge_midpoints).to(torch.long)
    s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
    s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
    p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
    p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
    # The signs of s1 and s2 should be different. We want to find
    # t such that t*s2 + (1-t)*s1 = 0.
    t = (s1 / (s1 - s2))[:, None]
    verts = t * p2 + (1 - t) * p1

    return TorchMesh(verts=verts, faces=selected_tris)


def _index_grid(
    shape: Tuple[int, ...], device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """
    Build a [*shape, 3] tensor whose entry at [i, j, k] is (i, j, k).

    :param shape: the three grid extents.
    :param device: device on which to allocate the result.
    :param dtype: dtype of the result.
    """
    nx, ny, nz = shape
    xs = torch.arange(nx, device=device, dtype=dtype)[:, None, None]
    ys = torch.arange(ny, device=device, dtype=dtype)[None, :, None]
    zs = torch.arange(nz, device=device, dtype=dtype)[None, None, :]
    # stack() copies the broadcast views into one contiguous tensor.
    return torch.stack(torch.broadcast_tensors(xs, ys, zs), dim=-1)
|
||
|
|
||
|
|
||
|
def _create_flat_edge_indices(
|
||
|
flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int]
|
||
|
) -> torch.Tensor:
|
||
|
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
|
||
|
y_offset = num_xs
|
||
|
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
|
||
|
z_offset = num_xs + num_ys
|
||
|
return torch.stack(
|
||
|
[
|
||
|
# Edges spanning x-axis.
|
||
|
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2],
|
||
|
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||
|
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2],
|
||
|
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
+ 1,
|
||
|
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||
|
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
+ 1,
|
||
|
# Edges spanning y-axis.
|
||
|
(
|
||
|
y_offset
|
||
|
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
(
|
||
|
y_offset
|
||
|
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
(
|
||
|
y_offset
|
||
|
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
+ 1
|
||
|
),
|
||
|
(
|
||
|
y_offset
|
||
|
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||
|
+ flat_cube_indices[:, 1] * grid_size[2]
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
+ 1
|
||
|
),
|
||
|
# Edges spanning z-axis.
|
||
|
(
|
||
|
z_offset
|
||
|
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
(
|
||
|
z_offset
|
||
|
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
(
|
||
|
z_offset
|
||
|
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||
|
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
(
|
||
|
z_offset
|
||
|
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||
|
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||
|
+ flat_cube_indices[:, 2]
|
||
|
),
|
||
|
],
|
||
|
dim=-1,
|
||
|
)
|
||
|
|
||
|
|
||
|
@dataclass
class McLookupTable:
    """
    Marching-cubes triangle table, indexed by a cell's 8-bit corner mask.

    Each triangle vertex is stored as a local edge index in the range 0-11.
    In the cell below, numbers next to ``+`` marks are corner indices and
    numbers placed on the lines are edge indices:

            6 +----------3-----------+ 7
             /|                     /|
            6 |                    7 |
           /  10                  /  11
        4 +----------2-----------+ 5 |
          |   |                  |   |
          8   |                  9   |
          | 2 +----------1-------|---+ 3
          |  /                   |  /
          | 4                    | 5
          |/                     |/
        0 +----------0-----------+ 1
    """

    # [256 x 5 x 3] long tensor: up to 5 triangles per case, each given as
    # three local edge indices (unused slots hold 0).
    cases: torch.Tensor
    # [256 x 5] bool tensor: True where the matching triangle slot is used.
    masks: torch.Tensor
|
||
|
|
||
|
|
||
|
@lru_cache(maxsize=9)  # if there's more than 8 GPUs and a CPU, don't bother caching
def _lookup_table(device: torch.device) -> McLookupTable:
    """
    Build the marching-cubes lookup table from MC_TABLE on the given device.

    Each MC_TABLE entry (position = case index) is a list of triangles. Each
    triangle is a flat sequence of corner indices read in pairs, and each
    pair names one cell edge, which is converted to a local edge index 0-11.
    Unused triangle slots keep edge index 0 and a False mask.

    The table is assembled with plain Python lists and transferred to the
    device in a single call. Writing it element by element into a (possibly
    GPU) tensor would issue one tiny device write per entry.

    :param device: the device on which to place the resulting tensors.
    :return: a McLookupTable with cases [256 x 5 x 3] (long) and
             masks [256 x 5] (bool).
    """
    num_cases, max_tris = 256, 5

    # Corner pair (lower index first) -> local edge index.
    edge_to_index = {
        (0, 1): 0,
        (2, 3): 1,
        (4, 5): 2,
        (6, 7): 3,
        (0, 2): 4,
        (1, 3): 5,
        (4, 6): 6,
        (5, 7): 7,
        (0, 4): 8,
        (1, 5): 9,
        (2, 6): 10,
        (3, 7): 11,
    }

    cases = [[[0, 0, 0] for _ in range(max_tris)] for _ in range(num_cases)]
    masks = [[False] * max_tris for _ in range(num_cases)]

    for i, case in enumerate(MC_TABLE):
        for j, tri in enumerate(case):
            for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
                cases[i][j][k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
            masks[i][j] = True

    return McLookupTable(
        cases=torch.tensor(cases, device=device, dtype=torch.long),
        masks=torch.tensor(masks, device=device, dtype=torch.bool),
    )
|