diff --git a/shap_e/rendering/mesh.py b/shap_e/rendering/mesh.py index 7a0c2a1..f120ddf 100644 --- a/shap_e/rendering/mesh.py +++ b/shap_e/rendering/mesh.py @@ -89,17 +89,13 @@ class TriMesh: def write_obj(self, raw_f: BinaryIO): if self.has_vertex_colors(): - vertex_colors = np.stack([self.vertex_channels[x] - for x in "RGB"], axis=1) + vertex_colors = np.stack([self.vertex_channels[x] for x in "RGB"], axis=1) vertices = [ "{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(self.verts.tolist(), vertex_colors.tolist()) ] else: - vertices = [ - "{} {} {}".format(*coord) - for coord in self.verts.tolist() - ] + vertices = ["{} {} {}".format(*coord) for coord in self.verts.tolist()] faces = [ "f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1))