from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from shap_e.models.generation.transformer import Transformer
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict

from .base import VectorEncoder


class MultiviewTransformerEncoder(VectorEncoder):
    """
    Encode cameras and views using a transformer model with extra output
    token(s) used to extract a latent vector.

    A self-contained sketch of the patch/camera token layout appears at the
    bottom of this file.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        d_latent: int = 512,
        latent_ctx: int = 1,
        num_views: int = 20,
        image_size: int = 256,
        patch_size: int = 32,
        use_depth: bool = False,
        max_depth: float = 5.0,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        pos_emb_init_scale: float = 1.0,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            latent_bottleneck=latent_bottleneck,
            d_latent=d_latent,
        )
        self.num_views = num_views
        self.image_size = image_size
        self.patch_size = patch_size
        self.use_depth = use_depth
        self.max_depth = max_depth
        self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2)
        self.latent_ctx = latent_ctx
        self.width = width
        assert d_latent % latent_ctx == 0
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.n_ctx + latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),
        )
        self.register_parameter(
            "pos_emb",
            nn.Parameter(
                pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype)
            ),
        )
        self.patch_emb = nn.Conv2d(
            in_channels=3 if not use_depth else 4,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
            dtype=dtype,
        )
        self.camera_emb = nn.Sequential(
            nn.Linear(
                3 * 4 + 1, width, device=device, dtype=dtype
            ),  # input is x+y+z+origin (3 floats each) plus x_fov, matching cameras_to_tensor
            nn.GELU(),
            nn.Linear(width, width, device=device, dtype=dtype),
        )
        self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)
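
    # With the default hyperparameters (num_views=20, image_size=256, patch_size=32),
    # each view contributes (256 // 32) ** 2 = 64 patch tokens plus one camera token,
    # so n_ctx = 20 * 65 = 1300 context tokens. The transformer additionally sees
    # `latent_ctx` learned output tokens, whose final hidden states are projected and
    # flattened into the d_latent-dimensional latent.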
    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
        batch_size, num_views, _, _, _ = all_views.shape

        # Patchify every view with a strided convolution, then arrange the patches
        # as a token sequence per view.
        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # Append one embedded camera token per view, flatten all views into a single
        # sequence, and add the learned positional embedding.
        cameras_proj = self.camera_emb(all_cameras).reshape([batch_size, num_views, 1, self.width])
        h = torch.cat([views_proj, cameras_proj], dim=2).reshape([batch_size, -1, self.width])
        h = h + self.pos_emb

        # Concatenate the learned output token(s), run the transformer backbone, and
        # read the latent off the output-token positions.
        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        h = h[:, self.n_ctx :]
        h = self.output_proj(h).flatten(1)
        return h
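
    # The flattened [batch_size, d_latent] vector returned above is what the
    # VectorEncoder base class uses, via `params_proj`, to produce the parameters
    # whose shapes are given by `param_shapes`.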

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(views, torch.Tensor):
            return views
        tensor_batch = []
        for inner_list in views:
            assert len(inner_list) == self.num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    # Map uint8 pixel values [0, 255] to floats in [-1, 1].
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(depths, torch.Tensor):
            return depths
        tensor_batch = []
        for inner_list in depths:
            assert len(inner_list) == self.num_views
            inner_batch = []
            for arr in inner_list:
                # Clamp to max_depth and rescale so depth values land in [-1, 1].
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras
        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    # Pack the x/y/z axes and the origin (3 floats each) plus the
                    # horizontal field of view, matching camera_emb's 3*4+1 inputs.
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
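

if __name__ == "__main__":
    # Minimal, self-contained sketch of the patch/camera token layout used above,
    # with fewer views so it runs quickly on CPU. It mirrors the embedding logic of
    # encode_to_vector but does not construct a full MultiviewTransformerEncoder,
    # which would also require the shap-e param_shapes / params_proj configuration.
    batch_size, num_views, image_size, patch_size, width = 2, 4, 256, 32, 512
    views = torch.randn(batch_size, num_views, 3, image_size, image_size)
    cameras = torch.randn(batch_size, num_views, 3 * 4 + 1)

    patch_emb = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size)
    camera_emb = nn.Sequential(
        nn.Linear(3 * 4 + 1, width), nn.GELU(), nn.Linear(width, width)
    )

    # Patchify each view, then append one camera token per view and flatten.
    views_proj = patch_emb(views.reshape(batch_size * num_views, 3, image_size, image_size))
    views_proj = views_proj.reshape(batch_size, num_views, width, -1).permute(0, 1, 3, 2)
    cameras_proj = camera_emb(cameras).reshape(batch_size, num_views, 1, width)
    tokens = torch.cat([views_proj, cameras_proj], dim=2).reshape(batch_size, -1, width)

    # 4 views * ((256 // 32) ** 2 + 1) = 4 * 65 = 260 context tokens per example.
    print(tokens.shape)  # torch.Size([2, 260, 512])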