import base64
import io
from typing import List, Union

import ipywidgets as widgets
import numpy as np
import torch
from PIL import Image

from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict


def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
    """Build 20 cameras on a circular orbit around the origin.

    Args:
        size: Width and height, in pixels, of every rendered view.
        device: Torch device on which the camera tensors are allocated.

    Returns:
        A ``DifferentiableCameraBatch`` of shape ``(1, 20)``. Every view uses a
        0.7 field of view in both x and y.
    """
    origins = []
    xs = []
    ys = []
    zs = []
    # NOTE: np.linspace includes both endpoints, so theta=0 and theta=2*pi
    # produce the same pose; the first and last views are identical.
    for theta in np.linspace(0, 2 * np.pi, num=20):
        # Viewing direction. It rotates with theta, and the constant -0.5
        # third component gives every view the same elevation.
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        # Step 4 units back along the viewing direction, so each camera
        # looks toward the origin.
        origin = -z * 4
        # x is orthogonal to z by construction; y completes the camera frame.
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableCameraBatch(
        shape=(1, len(xs)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )


def _latent_to_params(xm: Union[Transmitter, VectorDecoder], latent: torch.Tensor):
    """Map a single latent to renderer parameters via the model's bottleneck."""
    # A Transmitter delegates the mapping to its encoder; a bare VectorDecoder
    # provides `bottleneck_to_params` itself.
    owner = xm.encoder if isinstance(xm, Transmitter) else xm
    # latent[None] prepends a batch dimension of size 1.
    return owner.bottleneck_to_params(latent[None])


@torch.no_grad()
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
) -> List[Image.Image]:
    """Render ``latent`` from each camera in ``cameras`` and return PIL images.

    Args:
        xm: Transmitter or VectorDecoder whose renderer and bottleneck are used.
        latent: A single (unbatched) latent; a batch dimension is added here.
        cameras: Views to render, e.g. from ``create_pan_cameras``.
        rendering_mode: Passed through to the renderer options ("stf" by default).

    Returns:
        One ``PIL.Image`` per rendered view of the first (only) batch element.
    """
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=_latent_to_params(xm, latent),
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Clamp to the 8-bit range before casting so out-of-range values saturate
    # instead of wrapping; [0] drops the batch dimension added above.
    # Assumes channels are laid out (batch, view, H, W, C) as Image.fromarray
    # expects per view -- TODO confirm against the renderer.
    arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return [Image.fromarray(x) for x in arr]


@torch.no_grad()
def decode_latent_mesh(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode ``latent`` into a mesh using the "stf" rendering path.

    Args:
        xm: Transmitter or VectorDecoder whose renderer and bottleneck are used.
        latent: A single (unbatched) latent; a batch dimension is added here.

    Returns:
        The first entry of the renderer's ``raw_meshes`` output.
    """
    decoded = xm.renderer.render_views(
        # Only the mesh is wanted; the views themselves are discarded.
        AttrDict(cameras=create_pan_cameras(2, latent.device)),  # lowest resolution possible
        params=_latent_to_params(xm, latent),
        options=AttrDict(rendering_mode="stf", render_with_direction=False),
    )
    return decoded.raw_meshes[0]


def gif_widget(images) -> widgets.HTML:
    """Encode ``images`` as a looping animated GIF and show it in an HTML widget.

    Args:
        images: Non-empty sequence of ``PIL.Image`` frames. Each frame is shown
            for 100 ms, and the animation loops forever (``loop=0``).

    Returns:
        An ``ipywidgets.HTML`` containing an ``<img>`` whose source is the GIF
        embedded as a base64 data URI.
    """
    with io.BytesIO() as writer:
        images[0].save(
            writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0
        )
        data = base64.b64encode(writer.getvalue()).decode("ascii")
    # Embed the GIF inline so the widget needs no file on disk.
    return widgets.HTML(f'<img src="data:image/gif;base64,{data}" />')