Browse Source

example of creating meshes

main
Alex Nichol 2 years ago
parent
commit
8625e7c155
  1. 15
      shap_e/examples/sample_text_to_3d.ipynb
  2. 20
      shap_e/models/stf/renderer.py
  3. 16
      shap_e/util/notebooks.py

15
shap_e/examples/sample_text_to_3d.ipynb

@ -80,6 +80,21 @@
" images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
" display(gif_widget(images))" " display(gif_widget(images))"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85a4dce4",
"metadata": {},
"outputs": [],
"source": [
"# Example of saving the latents as meshes.\n",
"from shap_e.util.notebooks import decode_latent_mesh\n",
"\n",
"for i, latent in enumerate(latents):\n",
" with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
" decode_latent_mesh(xm, latent).tri_mesh().write_ply(f)"
]
} }
], ],
"metadata": { "metadata": {

20
shap_e/models/stf/renderer.py

@ -256,6 +256,16 @@ def render_views_from_stf(
if output_srgb: if output_srgb:
tf_out.channels = _convert_srgb_to_linear(tf_out.channels) tf_out.channels = _convert_srgb_to_linear(tf_out.channels)
# Make sure the raw meshes have colors.
with torch.autocast(device_type, enabled=False):
textures = tf_out.channels.float()
assert len(textures.shape) == 3 and textures.shape[-1] == len(
texture_channels
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}
args = dict( args = dict(
options=options, options=options,
texture_channels=texture_channels, texture_channels=texture_channels,
@ -315,6 +325,8 @@ def _render_with_pytorch3d(
raw_meshes: List[TorchMesh], raw_meshes: List[TorchMesh],
tf_out: AttrDict, tf_out: AttrDict,
): ):
_ = tf_out
# Lazy import because pytorch3d is installed lazily. # Lazy import because pytorch3d is installed lazily.
from shap_e.rendering.pytorch3d_util import ( from shap_e.rendering.pytorch3d_util import (
blender_uniform_lights, blender_uniform_lights,
@ -328,14 +340,6 @@ def _render_with_pytorch3d(
device_type = device.type device_type = device.type
with torch.autocast(device_type, enabled=False): with torch.autocast(device_type, enabled=False):
textures = tf_out.channels.float()
assert len(textures.shape) == 3 and textures.shape[-1] == len(
texture_channels
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}
meshes = convert_meshes(raw_meshes) meshes = convert_meshes(raw_meshes)
lights = blender_uniform_lights( lights = blender_uniform_lights(

16
shap_e/util/notebooks.py

@ -9,6 +9,7 @@ from PIL import Image
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict from shap_e.util.collections import AttrDict
@ -60,6 +61,21 @@ def decode_latent_images(
return [Image.fromarray(x) for x in arr] return [Image.fromarray(x) for x in arr]
@torch.no_grad()
def decode_latent_mesh(
xm: Union[Transmitter, VectorDecoder],
latent: torch.Tensor,
) -> TorchMesh:
decoded = xm.renderer.render_views(
AttrDict(cameras=create_pan_cameras(2, latent.device)), # lowest resolution possible
params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
latent[None]
),
options=AttrDict(rendering_mode="stf", render_with_direction=False),
)
return decoded.raw_meshes[0]
def gif_widget(images): def gif_widget(images):
writer = io.BytesIO() writer = io.BytesIO()
images[0].save( images[0].save(

Loading…
Cancel
Save