from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.camera import DifferentiableProjectiveCamera
from shap_e.models.nn.encoding import (
    MultiviewPointCloudEmbedding,
    MultiviewPoseEmbedding,
    PosEmbLinear,
)
from shap_e.models.nn.ops import PointSetEmbedding
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict

from .base import ChannelsEncoder


class TransformerChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        latent_scale: float = 1.0,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype

        self.n_ctx = n_ctx

        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + self.latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)
        self.latent_scale = latent_scale

    @abstractmethod
    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        pass

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        h = self.encode_input(batch, options=options)
        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
        h = self.ln_pre(h)
        h = self.backbone(h)
        h = h[:, -self.latent_ctx :]
        h = self.ln_post(h)
        h = self.output_proj(h)
        return h
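
# Rough shape flow through TransformerChannelsEncoder.encode_to_channels (latent_ctx and
# d_latent come from the ChannelsEncoder base class):
#   encode_input(batch)          -> [batch, n_ctx, width]
#   + learned output_tokens      -> [batch, n_ctx + latent_ctx, width]
#   ln_pre -> backbone -> keep last latent_ctx tokens -> ln_post -> output_proj
#                                -> [batch, latent_ctx, d_latent]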


class PerceiverChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        min_unrolls: int,
        max_unrolls: int,
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        width: int = 512,
        layers: int = 12,
        xattn_layers: int = 1,
        heads: int = 8,
        init_scale: float = 0.25,
        # Training hparams
        inner_batch_size: Union[int, List[int]] = 1,
        data_ctx: int = 1,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype

        if isinstance(inner_batch_size, int):
            inner_batch_size = [inner_batch_size]
        self.inner_batch_size = inner_batch_size
        self.data_ctx = data_ctx
        self.min_unrolls = min_unrolls
        self.max_unrolls = max_unrolls

        encoder_fn = lambda inner_batch_size: SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            n_data=inner_batch_size,
            width=width,
            layers=xattn_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.encoder = (
            encoder_fn(self.inner_batch_size[0])
            if len(self.inner_batch_size) == 1
            else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size])
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            layers=layers - xattn_layers,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)

    @abstractmethod
    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        h, it = self.get_h_and_iterator(batch, options=options)
        n_unrolls = self.get_n_unrolls()

        for _ in range(n_unrolls):
            # Each unroll cross-attends to a fresh inner batch of data, then refines with the processor
            data = next(it)
            if isinstance(data, tuple):
                for data_i, encoder_i in zip(data, self.encoder):
                    h = encoder_i(h, data_i)
            else:
                h = self.encoder(h, data)
            h = self.processor(h)

        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
        return h

    def get_n_unrolls(self):
        if self.training:
            # Sample the unroll count on rank 0 and broadcast it so all ranks stay in sync
            n_unrolls = torch.randint(
                self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
            )
            dist.broadcast(n_unrolls, 0)
            n_unrolls = n_unrolls.item()
        else:
            n_unrolls = self.max_unrolls
        return n_unrolls
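
# Note on the unroll loop above: each unroll pulls one inner batch from the iterator returned by
# get_h_and_iterator and cross-attends the query/latent tokens to it before the processor runs.
# get_n_unrolls samples the unroll count on rank 0 and broadcasts it, so the training path assumes
# torch.distributed has been initialized.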


@dataclass
class DatasetIterator:
    """
    Infinite iterator over the dataset dimension of `embs` that yields contiguous
    mini-batches of size `batch_size`, reshuffling whenever the data is exhausted.
    """

    embs: torch.Tensor  # [batch_size, dataset_size, *shape]
    batch_size: int

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        _outer_batch_size, dataset_size, *_shape = self.embs.shape

        while True:
            start = self.idx
            self.idx += self.batch_size
            end = self.idx
            if end <= dataset_size:
                break
            self._reset()

        return self.embs[:, start:end]

    def _reset(self):
        self._shuffle()
        self.idx = 0  # pylint: disable=attribute-defined-outside-init

    def _shuffle(self):
        # Shuffle the dataset dimension with an independent permutation per batch element
        outer_batch_size, dataset_size, *shape = self.embs.shape
        idx = torch.stack(
            [
                torch.randperm(dataset_size, device=self.embs.device)
                for _ in range(outer_batch_size)
            ],
            dim=0,
        )
        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))
        idx = torch.broadcast_to(idx, self.embs.shape)
        self.embs = torch.gather(self.embs, 1, idx)
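
# Minimal usage sketch for DatasetIterator (illustrative only; shapes are arbitrary):
#
#   embs = torch.randn(2, 10, 4)                         # [outer_batch, dataset_size, feature]
#   it = iter(DatasetIterator(embs=embs, batch_size=3))
#   chunk = next(it)                                     # [2, 3, 4], taken from a fresh shuffle
#
# The iterator never raises StopIteration; it reshuffles the dataset dimension whenever the next
# window would run past the end.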


class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        input_channels: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.input_proj = nn.Linear(
            input_channels, self.width, device=self.device, dtype=self.dtype
        )

    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        points = batch.points
        h = self.input_proj(points.permute(0, 2, 1))  # NCL -> NLC
        return h
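
# Expected input layout: batch.points is [batch_size, input_channels, n_points] (channels-first);
# with the default input_channels=6 this is presumably xyz plus three extra channels such as RGB.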


class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # point conv hyperparameters
        pointconv_radius: float = 0.5,
        pointconv_samples: int = 32,
        pointconv_hidden: Optional[List[int]] = None,
        pointconv_patch_size: int = 1,
        pointconv_stride: int = 1,
        pointconv_padding_mode: str = "zeros",
        use_pointconv: bool = False,
        # other hyperparameters
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert cross_attention_dataset in (
            "pcl",
            "multiview",
            "dense_pose_multiview",
            "multiview_pcl",
            "pcl_and_multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        )
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb,
            input_channels,
            self.width,
            device=self.device,
            dtype=self.dtype,
        )
        self.use_pointconv = use_pointconv
        if use_pointconv:
            if pointconv_hidden is None:
                pointconv_hidden = [self.width]
            self.point_conv = PointSetEmbedding(
                n_point=self.data_ctx,
                radius=pointconv_radius,
                n_sample=pointconv_samples,
                d_input=self.input_proj.weight.shape[0],
                d_hidden=pointconv_hidden,
                patch_size=pointconv_patch_size,
                stride=pointconv_stride,
                padding_mode=pointconv_padding_mode,
                fps_method=fps_method,
                device=self.device,
                dtype=self.dtype,
            )
        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            pos_ctx = (image_size // patch_size) ** 2
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        # The multiview dataset uses the first (and typically only) inner batch size
                        pos_ctx * self.inner_batch_size[0],
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input size is for origin+x+y+z+fov
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )
        elif self.cross_attention_dataset == "dense_pose_multiview":
            # The number of output features is halved, because a patch_size of
            # 32 ends up with a large patch_emb weight.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.use_depth = use_depth
            self.max_depth = max_depth
            self.mv_pose_embed = MultiviewPoseEmbedding(
                posemb_version="nerf",
                n_channels=4 if self.use_depth else 3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            pos_ctx = (image_size // patch_size) ** 2
            # Positional embedding is unnecessary because pose information is baked into each pixel
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )

        elif (
            self.cross_attention_dataset == "multiview_pcl"
            or self.cross_attention_dataset == "incorrect_multiview_pcl"
        ):
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )

        elif (
            self.cross_attention_dataset == "pcl_and_multiview_pcl"
            or self.cross_attention_dataset == "pcl_and_incorrect_multiview_pcl"
        ):
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
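
    # Which auxiliary modules exist depends on cross_attention_dataset (see __init__ above):
    # "multiview" builds patch_emb + camera_emb + a learned pos_emb, "dense_pose_multiview"
    # builds mv_pose_embed + patch_emb, and the *multiview_pcl variants build mv_pcl_embed +
    # patch_emb. The plain "pcl" mode only needs input_proj (or point_conv).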

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        if self.use_pointconv:
            points = self.input_proj(points).permute(0, 2, 1)  # NLC -> NCL
            xyz = batch.points[:, :3]
            data_tokens = self.point_conv(xyz, points).permute(0, 2, 1)  # NCL -> NLC
        else:
            fps_samples = self.sample_pcl_fps(points)
            data_tokens = self.input_proj(fps_samples)
        batch_size = points.shape[0]
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator
        dataset_fn = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
            "dense_pose_multiview": self.get_dense_pose_multiview_dataset,
            "pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset,
            "multiview_pcl": self.get_multiview_pcl_dataset,
        }[self.cross_attention_dataset]
        it = dataset_fn(batch, options=options)

        return h, it
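
    # The iterator built above yields one tensor per unroll for the single-dataset modes and a
    # tuple of tensors for "pcl_and_multiview_pcl", matching the tuple branch in
    # PerceiverChannelsEncoder.encode_to_channels.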

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict[str, Any]] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options
        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        dataset_emb = self.input_proj(points)
        assert dataset_emb.shape[1] >= inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

    def get_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def get_dense_pose_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_dense_pose_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width)
                yield views

        return gen()

    def get_pcl_and_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        use_distance: bool = True,
    ) -> Iterable:
        _ = options

        pcl_it = self.get_pcl_dataset(
            batch, options=options, inner_batch_size=self.inner_batch_size[0]
        )
        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= self.inner_batch_size[1]

        multiview_pcl_it = iter(
            DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])
        )

        def gen():
            while True:
                pcl = next(pcl_it)
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    self.inner_batch_size[1],
                    n_patches,
                    self.width,
                )
                yield pcl, multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def get_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
        use_distance: bool = True,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= inner_batch_size

        multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    inner_batch_size,
                    n_patches,
                    self.width,
                )
                yield multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # During training, drop the camera conditioning for a random subset of examples
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        # Modulate patch features with a per-view scale and shift predicted from the camera
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj

    def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            depths = self.depths_to_tensor(batch.depths)
            all_views = torch.cat([all_views, depths], dim=2)

        dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
        position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        all_view_poses = self.mv_pose_embed(all_views, position, direction)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        depths = self.raw_depths_to_tensor(batch.depths)
        all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)
        # Foreground mask derived from the per-view alpha channel
        mask = all_view_alphas >= 0.999

        dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)

        origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        if use_distance:
            # Rescale depths by the ray/z-axis dot product so the unprojection below moves
            # each pixel the right distance along its ray
            ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, keepdim=True)
            depths = depths / ray_depth_factor
        # Unproject every pixel to a 3D point: origin + depth * direction
        position = origin + depths * direction
        all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def view_alphas_to_tensor(
        self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].
        """
        if isinstance(view_alphas, torch.Tensor):
            return view_alphas

        tensor_batch = []
        num_views = len(view_alphas[0])
        for inner_list in view_alphas:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                tensor = (
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 255.0
                )
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor)
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def raw_depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth)
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras
        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()

    def dense_pose_cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of (rays, z_directions) where
            - rays: [batch, num_views, height, width, 2, 3] tensor of camera information.
            - z_directions: [batch, num_views, 3] tensor of camera z directions.
        """
        if isinstance(cameras, torch.Tensor):
            raise NotImplementedError

        for inner_list in cameras:
            assert len(inner_list) == len(cameras[0])

        camera = cameras[0][0]
        flat_camera = DifferentiableProjectiveCamera(
            origin=torch.from_numpy(
                np.stack(
                    [cam.origin for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            x=torch.from_numpy(
                np.stack(
                    [cam.x for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            y=torch.from_numpy(
                np.stack(
                    [cam.y for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            z=torch.from_numpy(
                np.stack(
                    [cam.z for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            width=camera.width,
            height=camera.height,
            x_fov=camera.x_fov,
            y_fov=camera.y_fov,
        )
        batch_size = len(cameras) * len(cameras[0])
        coords = (
            flat_camera.image_coords()
            .to(flat_camera.origin.device)
            .unsqueeze(0)
            .repeat(batch_size, 1, 1)
        )
        rays = flat_camera.camera_rays(coords)
        return (
            rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(
                self.device
            ),
            flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),
        )


def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.

    :param points: batch of shape [N x num_points x (3 + n_channels)].
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
        points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x (3 + n_channels)].
    """
    n_points = points.shape[1]
    if n_points == data_ctx:
        return points
    if method == "first":
        return points[:, :data_ctx]
    elif method == "fps":
        batch = points.cpu().split(1, dim=0)
        fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
        return torch.cat(fps, dim=0).to(points.device)
    else:
        raise ValueError(f"unsupported farthest-point sampling method: {method}")


def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    :param example: [1, n_points, 3 + n_channels]
    :param n_samples: number of points to keep.
    :return: [1, n_samples, 3 + n_channels]
    """
    points = example.cpu().squeeze(0).numpy()
    coords, raw_channels = points[:, :3], points[:, 3:]
    n_points, n_channels = raw_channels.shape
    assert n_samples <= n_points
    channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)}
    max_points = min(32768, n_points)
    fps_pcl = (
        PointCloud(coords=coords, channels=channels)
        .random_sample(max_points)
        .farthest_point_sample(n_samples)
    )
    fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1)
    fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1)
    fps = torch.from_numpy(fps).unsqueeze(0)
    assert fps.shape == (1, n_samples, 3 + n_channels)
    return fps
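
# Minimal sketch of how the sampling helpers above might be used (illustrative only):
#
#   points = torch.rand(2, 4096, 6)                                # [N, num_points, 3 + n_channels]
#   subset = sample_pcl_fps(points, data_ctx=1024)                 # [2, 1024, 6] via FPS
#   first = sample_pcl_fps(points, data_ctx=1024, method="first")  # assumes pre-sorted points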