From 8625e7c15526d8510a2292f92165979268d0e945 Mon Sep 17 00:00:00 2001 From: Alex Nichol Date: Fri, 5 May 2023 13:14:59 -0700 Subject: [PATCH] example of creating meshes --- shap_e/examples/sample_text_to_3d.ipynb | 15 +++++++++++++++ shap_e/models/stf/renderer.py | 20 ++++++++++++-------- shap_e/util/notebooks.py | 16 ++++++++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/shap_e/examples/sample_text_to_3d.ipynb b/shap_e/examples/sample_text_to_3d.ipynb index 555579c..48615e8 100644 --- a/shap_e/examples/sample_text_to_3d.ipynb +++ b/shap_e/examples/sample_text_to_3d.ipynb @@ -80,6 +80,21 @@ " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", " display(gif_widget(images))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85a4dce4", + "metadata": {}, + "outputs": [], + "source": [ + "# Example of saving the latents as meshes.\n", + "from shap_e.util.notebooks import decode_latent_mesh\n", + "\n", + "for i, latent in enumerate(latents):\n", + " with open(f'example_mesh_{i}.ply', 'wb') as f:\n", + " decode_latent_mesh(xm, latent).tri_mesh().write_ply(f)" + ] } ], "metadata": { diff --git a/shap_e/models/stf/renderer.py b/shap_e/models/stf/renderer.py index 8a27707..099de74 100644 --- a/shap_e/models/stf/renderer.py +++ b/shap_e/models/stf/renderer.py @@ -256,6 +256,16 @@ def render_views_from_stf( if output_srgb: tf_out.channels = _convert_srgb_to_linear(tf_out.channels) + # Make sure the raw meshes have colors. + with torch.autocast(device_type, enabled=False): + textures = tf_out.channels.float() + assert len(textures.shape) == 3 and textures.shape[-1] == len( + texture_channels + ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))} + args = dict( options=options, texture_channels=texture_channels, @@ -315,6 +325,8 @@ def _render_with_pytorch3d( raw_meshes: List[TorchMesh], tf_out: AttrDict, ): + _ = tf_out + # Lazy import because pytorch3d is installed lazily. from shap_e.rendering.pytorch3d_util import ( blender_uniform_lights, @@ -328,14 +340,6 @@ def _render_with_pytorch3d( device_type = device.type with torch.autocast(device_type, enabled=False): - textures = tf_out.channels.float() - assert len(textures.shape) == 3 and textures.shape[-1] == len( - texture_channels - ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" - for m, texture in zip(raw_meshes, textures): - texture = texture[: len(m.verts)] - m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))} - meshes = convert_meshes(raw_meshes) lights = blender_uniform_lights( diff --git a/shap_e/util/notebooks.py b/shap_e/util/notebooks.py index bfe479d..08ca9c2 100644 --- a/shap_e/util/notebooks.py +++ b/shap_e/util/notebooks.py @@ -9,6 +9,7 @@ from PIL import Image from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera from shap_e.models.transmitter.base import Transmitter, VectorDecoder +from shap_e.rendering.torch_mesh import TorchMesh from shap_e.util.collections import AttrDict @@ -60,6 +61,21 @@ def decode_latent_images( return [Image.fromarray(x) for x in arr] +@torch.no_grad() +def decode_latent_mesh( + xm: Union[Transmitter, VectorDecoder], + latent: torch.Tensor, +) -> TorchMesh: + decoded = xm.renderer.render_views( + AttrDict(cameras=create_pan_cameras(2, latent.device)), # lowest resolution possible + params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params( + latent[None] + ), + options=AttrDict(rendering_mode="stf", render_with_direction=False), + ) + return decoded.raw_meshes[0] + + def gif_widget(images): writer = io.BytesIO() images[0].save(