a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

427 lines
15 KiB

2 years ago
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import torch
from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.encoding import PosEmbLinear
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import VectorEncoder
from .channels_encoder import DatasetIterator, sample_pcl_fps
class PointCloudTransformerEncoder(VectorEncoder):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.

    Input points are projected to ``width``-dimensional tokens, ``latent_ctx``
    learned output tokens are appended after them, and the whole sequence is
    run through a transformer backbone. The final hidden states of the output
    tokens are projected to ``d_latent // latent_ctx`` features each and
    flattened into a single latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        d_latent: int = 512,
        latent_ctx: int = 1,
        input_channels: int = 6,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        pos_emb: Optional[str] = None,
    ):
        """
        :param d_latent: total size of the produced latent vector.
        :param latent_ctx: number of learned output tokens; must divide d_latent.
        :param input_channels: number of channels per input point.
        :param n_ctx: number of input points the backbone is sized for; the
            backbone is built with ``n_ctx + latent_ctx`` context.
        :param width: transformer hidden size.
        :param pos_emb: positional-embedding type passed to ``PosEmbLinear``.
        The remaining arguments are forwarded to ``VectorEncoder`` or the
        ``Transformer`` backbone.
        """
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            latent_bottleneck=latent_bottleneck,
            d_latent=d_latent,
        )
        self.input_channels = input_channels
        self.n_ctx = n_ctx
        self.latent_ctx = latent_ctx
        # Each output token contributes an equal-sized slice of the latent.
        assert d_latent % latent_ctx == 0

        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),
        )
        self.input_proj = PosEmbLinear(pos_emb, input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)

    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Encode ``batch.points`` into a latent vector.

        :param batch: must provide ``points`` laid out channels-first
            ([batch, channels, n_points]).
        :param options: unused.
        :return: a [batch, d_latent] tensor.
        """
        _ = options
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        h = self.input_proj(points)
        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        # The output tokens were appended last, so they are always the trailing
        # latent_ctx positions. Slicing from n_ctx would silently return the
        # wrong tokens whenever the point count differs from n_ctx.
        h = h[:, -self.latent_ctx :]
        h = self.output_proj(h).flatten(1)
        return h
class PerceiverEncoder(VectorEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.

    Subclasses implement ``get_h_and_iterator`` to supply the initial tokens
    and an iterator of data chunks. ``encode_to_vector`` then alternates a
    cross-attention step over one chunk (``encoder``) with a self-attention
    step (``processor``). The number of unrolls is chosen by ``get_n_unrolls``.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        d_latent: int = 512,
        latent_ctx: int = 1,
        width: int = 512,
        layers: int = 12,
        xattn_layers: int = 1,
        heads: int = 8,
        init_scale: float = 0.25,
        # Training hparams
        inner_batch_size: int = 1,
        data_ctx: int = 1,
        min_unrolls: int,
        max_unrolls: int,
    ):
        """
        :param latent_ctx: number of learned output tokens placed after the
            ``data_ctx`` data-derived tokens.
        :param layers: total layer budget. ``xattn_layers`` go to the
            cross-attention encoder and ``layers - xattn_layers`` go to the
            self-attention processor.
        :param inner_batch_size: passed to ``SimplePerceiver`` as ``n_data``;
            presumably the number of data items cross-attended per unroll.
        :param data_ctx: number of data-derived query tokens.
        :param min_unrolls: inclusive lower bound on unrolls sampled in training.
        :param max_unrolls: inclusive upper bound in training; always used in eval.
        """
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            latent_bottleneck=latent_bottleneck,
            d_latent=d_latent,
        )
        self.width = width
        self.device = device
        self.dtype = dtype
        self.latent_ctx = latent_ctx

        self.inner_batch_size = inner_batch_size
        self.data_ctx = data_ctx
        self.min_unrolls = min_unrolls
        self.max_unrolls = max_unrolls

        self.encoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            n_data=self.inner_batch_size,
            width=width,
            layers=xattn_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            layers=layers - xattn_layers,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent // self.latent_ctx, device=device, dtype=dtype)

    @abstractmethod
    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """

    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Run the perceiver over ``batch`` and flatten the output tokens.

        :return: a [batch_size, latent_ctx * (d_latent // latent_ctx)] tensor.
        """
        h, it = self.get_h_and_iterator(batch, options=options)
        n_unrolls = self.get_n_unrolls()
        for _ in range(n_unrolls):
            data = next(it)
            h = self.encoder(h, data)
            h = self.processor(h)
        # The output tokens occupy the last latent_ctx positions.
        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
        return h.flatten(1)

    def get_n_unrolls(self):
        """
        Return how many encoder/processor passes to run.

        In training mode the count is sampled uniformly from
        [min_unrolls, max_unrolls]. When a torch.distributed process group is
        initialized, rank 0's sample is broadcast so all ranks agree. In eval
        mode, max_unrolls is always used.
        """
        if not self.training:
            return self.max_unrolls
        n_unrolls = torch.randint(
            self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
        )
        # dist.broadcast raises without an initialized default process group
        # (e.g. single-process training), so only synchronize when one exists.
        if dist.is_available() and dist.is_initialized():
            dist.broadcast(n_unrolls, 0)
        return n_unrolls.item()
class PointCloudPerceiverEncoder(PerceiverEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.

    The initial query tokens are ``data_ctx`` points subsampled from the input
    point cloud (see ``sample_pcl_fps``), followed by the learned output tokens.
    Each unroll cross-attends to the next chunk of either the full point cloud
    (``cross_attention_dataset="pcl"``) or camera-modulated image-patch
    embeddings of rendered views (``cross_attention_dataset="multiview"``).
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # other hyperparameters
        **kwargs,
    ):
        """
        :param cross_attention_dataset: "pcl" or "multiview"; selects what the
            encoder cross-attends to.
        :param fps_method: "fps" or "first"; forwarded to ``sample_pcl_fps``.
        :param input_channels: number of channels per input point.
        :param pos_emb: positional-embedding type passed to ``PosEmbLinear``.
        :param image_size: side length that views are resized to (multiview only).
        :param patch_size: side length of the square, non-overlapping patches.
        :param pose_dropout: probability, during training only, of zeroing the
            camera conditioning for a batch element.
        :param use_depth: append a depth channel to each view.
        :param max_depth: depths are clamped to this value before being scaled
            to [-1, 1].
        :param kwargs: forwarded to ``PerceiverEncoder.__init__``.
        """
        super().__init__(**kwargs)

        assert cross_attention_dataset in ("pcl", "multiview")
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb, input_channels, self.width, device=self.device, dtype=self.dtype
        )
        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            # One learned positional embedding per patch of every view in an
            # inner batch.
            pos_ctx = (image_size // patch_size) ** 2
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        pos_ctx * self.inner_batch_size,
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            # Produces a per-view (scale, shift) pair that modulates the patch
            # embeddings in encode_views.
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input is x, y, z, origin (3 values each) + x_fov, in the order built by cameras_to_tensor
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        fps_samples = self.sample_pcl_fps(points)
        batch_size = points.shape[0]
        data_tokens = self.input_proj(fps_samples)
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator
        dataset_fn = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
        }[self.cross_attention_dataset]
        it = dataset_fn(batch, options=options)

        return h, it

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        """Subsample ``data_ctx`` points using the configured ``fps_method``."""
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self, batch: AttrDict, options: Optional[AttrDict[str, Any]] = None
    ) -> Iterable:
        """
        Return an iterator over chunks of ``inner_batch_size`` projected points.
        """
        _ = options
        dataset_emb = self.input_proj(batch.points.permute(0, 2, 1))  # NCL -> NLC
        assert dataset_emb.shape[1] >= self.inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))

    def get_multiview_dataset(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Iterable:
        """
        Return a generator over chunks of ``inner_batch_size`` views. Each chunk
        is flattened to [batch_size, inner_batch_size * n_patches, width] and
        has the learned positional embedding added.
        """
        _ = options

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= self.inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, self.inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        Patch-embed every view and modulate it with its camera embedding.

        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            # depths_to_tensor returns tensor inputs unchanged, so move them to
            # the views' device before concatenating along the channel axis.
            all_depths = self.depths_to_tensor(batch.depths).to(self.device)
            all_views = torch.cat([all_views, all_depths], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # In training, drop the camera conditioning for whole batch elements with
        # probability pose_dropout. Zeroed scale and shift leave views_proj unchanged.
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].

        Tensor inputs are returned unchanged. Nested lists of images are resized
        to ``image_size``, converted to RGB, and rescaled from [0, 255].
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].

        Tensor inputs are returned unchanged. For nested lists, each element is
        passed to ``torch.from_numpy`` (so it must be a NumPy array), clamped to
        ``max_depth``, and resized to ``image_size`` with nearest-neighbour
        interpolation.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.

        Each row is laid out as x, y, z, origin (3 values each) followed by
        x_fov. Tensor inputs are returned unchanged.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras

        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()