# a fork of shap-e for gc
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from shap_e.models.generation.transformer import Transformer
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict

from .base import VectorEncoder


class MultiviewTransformerEncoder(VectorEncoder):
    """
    Encode cameras and views using a transformer model with extra output
    token(s) used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        d_latent: int = 512,
        latent_ctx: int = 1,
        num_views: int = 20,
        image_size: int = 256,
        patch_size: int = 32,
        use_depth: bool = False,
        max_depth: float = 5.0,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        pos_emb_init_scale: float = 1.0,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            latent_bottleneck=latent_bottleneck,
            d_latent=d_latent,
        )
        self.num_views = num_views
        self.image_size = image_size
        self.patch_size = patch_size
        self.use_depth = use_depth
        self.max_depth = max_depth
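        # Context length: each view contributes 1 camera token plus
        # (image_size // patch_size) ** 2 patch tokens, so the defaults give
        # 20 * (1 + (256 // 32) ** 2) = 20 * 65 = 1300 tokens.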
        self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2)
        self.latent_ctx = latent_ctx
        self.width = width
        assert d_latent % latent_ctx == 0
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.n_ctx + latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),
        )
        self.register_parameter(
            "pos_emb",
            nn.Parameter(
                pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype)
            ),
        )
        self.patch_emb = nn.Conv2d(
            in_channels=3 if not use_depth else 4,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
            dtype=dtype,
        )
        self.camera_emb = nn.Sequential(
            nn.Linear(
                3 * 4 + 1, width, device=device, dtype=dtype
            ),  # input size is for origin+x+y+z+fov
            nn.GELU(),
            nn.Linear(width, width, device=device, dtype=dtype),
        )
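        # Each of the latent_ctx output tokens is projected from `width` down to
        # d_latent // latent_ctx channels, so the flattened latent has length d_latent.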
        self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)

    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options

        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
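        # all_views: [batch_size, num_views, 3 (or 4 with depth), image_size, image_size]
        # all_cameras: [batch_size, num_views, 3 * 4 + 1]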
        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        cameras_proj = self.camera_emb(all_cameras).reshape([batch_size, num_views, 1, self.width])

        h = torch.cat([views_proj, cameras_proj], dim=2).reshape([batch_size, -1, self.width])
        h = h + self.pos_emb
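        # The learned output token(s) are appended after the positional embedding is
        # added, so no pos_emb is applied to them.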
        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
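        # Keep only the trailing latent token(s), project each to d_latent // latent_ctx
        # channels, and flatten into a single d_latent-dimensional vector per example.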
        h = h[:, self.n_ctx :]
        h = self.output_proj(h).flatten(1)
        return h

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        for inner_list in views:
            assert len(inner_list) == self.num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        for inner_list in depths:
            assert len(inner_list) == self.num_views
            inner_batch = []
            for arr in inner_list:
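                # Clamp raw depth values at max_depth and rescale to [-1, 1] so the
                # range matches the RGB channels.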
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras

        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
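                # Pack the camera frame (x, y, z axes and origin, 3 values each) plus the
                # horizontal field of view into a single 3 * 4 + 1 = 13-dim feature.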
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
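

# Hedged usage sketch (added for illustration, not part of the upstream shap-e file):
# a minimal, self-contained replica of the tokenization step in `encode_to_vector`,
# using plain torch modules rather than a fully constructed MultiviewTransformerEncoder,
# since building one requires shap-e's `param_shapes` / `params_proj` configs defined
# elsewhere in the codebase. Shapes follow the defaults above.
if __name__ == "__main__":
    batch_size, num_views, image_size, patch_size, width = 2, 20, 256, 32, 512

    # Random stand-ins for RGB views in [-1, 1] and 3*4+1-dim camera features.
    views = torch.randn(batch_size, num_views, 3, image_size, image_size)
    cameras = torch.randn(batch_size, num_views, 3 * 4 + 1)

    patch_emb = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size)
    camera_emb = nn.Linear(3 * 4 + 1, width)

    # One token per image patch plus one camera token per view, as in encode_to_vector.
    patch_tokens = patch_emb(views.reshape(batch_size * num_views, 3, image_size, image_size))
    patch_tokens = patch_tokens.reshape(batch_size, num_views, width, -1).permute(0, 1, 3, 2)
    camera_tokens = camera_emb(cameras).reshape(batch_size, num_views, 1, width)
    seq = torch.cat([patch_tokens, camera_tokens], dim=2).reshape(batch_size, -1, width)

    # 20 views * (1 camera token + (256 // 32) ** 2 patch tokens) = 1300 context tokens.
    print(seq.shape)  # torch.Size([2, 1300, 512])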