from dataclasses import dataclass, field from typing import Dict, Optional import torch from .mesh import TriMesh @dataclass class TorchMesh: """ A 3D triangle mesh with optional data at the vertices and faces. """ # [N x 3] array of vertex coordinates. verts: torch.Tensor # [M x 3] array of triangles, pointing to indices in verts. faces: torch.Tensor # Extra data per vertex and face. vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) def tri_mesh(self) -> TriMesh: """ Create a CPU version of the mesh. """ return TriMesh( verts=self.verts.detach().cpu().numpy(), faces=self.faces.cpu().numpy(), vertex_channels=( {k: v.detach().cpu().numpy() for k, v in self.vertex_channels.items()} if self.vertex_channels is not None else None ), face_channels=( {k: v.detach().cpu().numpy() for k, v in self.face_channels.items()} if self.face_channels is not None else None ), )