a fork of shap-e for gc
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

960 lines
34 KiB

2 years ago
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import torch
from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.camera import DifferentiableProjectiveCamera
from shap_e.models.nn.encoding import (
MultiviewPointCloudEmbedding,
MultiviewPoseEmbedding,
PosEmbLinear,
)
from shap_e.models.nn.ops import PointSetEmbedding
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import ChannelsEncoder
class TransformerChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.

    Subclasses implement :meth:`encode_input`, which maps a batch to a
    sequence of input tokens of width ``width``. ``self.latent_ctx`` learned
    output tokens (``latent_ctx`` is defined outside this class) are appended
    to that sequence, the whole sequence runs through the transformer
    backbone, and the outputs at the appended positions are projected to
    ``d_latent`` channels.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        latent_scale: float = 1.0,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype
        self.n_ctx = n_ctx

        # The backbone sees the n_ctx input tokens followed by the latent tokens.
        total_ctx = n_ctx + self.latent_ctx
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=total_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        # Learned queries whose backbone outputs become the latent vector.
        self.output_tokens = nn.Parameter(
            torch.randn(self.latent_ctx, width, device=device, dtype=dtype)
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)
        # NOTE(review): latent_scale is stored but never read in this class;
        # confirm whether the ChannelsEncoder base or callers use it.
        self.latent_scale = latent_scale

    @abstractmethod
    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Turn ``batch`` into a [batch_size, seq_len, width] sequence of input
        tokens (shape as implied by the concatenation in encode_to_channels).
        """

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Run the encoder and return the projected latent tokens.

        :return: [batch_size, latent_ctx, d_latent] tensor.
        """
        input_tokens = self.encode_input(batch, options=options)
        queries = self.output_tokens[None].repeat(len(input_tokens), 1, 1)
        sequence = torch.cat([input_tokens, queries], dim=1)
        hidden = self.backbone(self.ln_pre(sequence))
        # Keep only the outputs at the appended latent-token positions.
        latent_hidden = self.ln_post(hidden[:, -self.latent_ctx :])
        return self.output_proj(latent_hidden)
class PerceiverChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.

    Subclasses implement :meth:`get_h_and_iterator`, which returns the initial
    token sequence ``h`` and an iterator over data to cross-attend to. Each
    "unroll" pulls one item from that iterator, applies the cross-attention
    encoder(s) to ``h`` with that data, then the self-attention processor.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        min_unrolls: int,
        max_unrolls: int,
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        width: int = 512,
        layers: int = 12,
        xattn_layers: int = 1,
        heads: int = 8,
        init_scale: float = 0.25,
        # Training hparams
        inner_batch_size: Union[int, List[int]] = 1,
        data_ctx: int = 1,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype

        # Normalize to a list: one entry per cross-attention encoder.
        if isinstance(inner_batch_size, int):
            inner_batch_size = [inner_batch_size]
        self.inner_batch_size = inner_batch_size
        self.data_ctx = data_ctx
        self.min_unrolls = min_unrolls
        self.max_unrolls = max_unrolls

        def encoder_fn(n_data: int) -> SimplePerceiver:
            # One cross-attention encoder consuming n_data data tokens.
            return SimplePerceiver(
                device=device,
                dtype=dtype,
                n_ctx=self.data_ctx + self.latent_ctx,
                n_data=n_data,
                width=width,
                layers=xattn_layers,
                heads=heads,
                init_scale=init_scale,
            )

        # A single encoder when there is one data stream, otherwise a
        # ModuleList zipped against the tuple items the iterator yields.
        self.encoder = (
            encoder_fn(self.inner_batch_size[0])
            if len(self.inner_batch_size) == 1
            else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size])
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            layers=layers - xattn_layers,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)

    @abstractmethod
    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Encode ``batch`` into latent channels.

        :return: output_proj applied to the last ``latent_ctx`` tokens of the
            final sequence, i.e. a [batch_size, latent_ctx, d_latent] tensor.
        """
        h, it = self.get_h_and_iterator(batch, options=options)
        n_unrolls = self.get_n_unrolls()
        for _ in range(n_unrolls):
            data = next(it)
            if isinstance(data, tuple):
                # Several data streams: each goes through its own encoder.
                for data_i, encoder_i in zip(data, self.encoder):
                    h = encoder_i(h, data_i)
            else:
                h = self.encoder(h, data)
            h = self.processor(h)
        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
        return h

    def get_n_unrolls(self) -> int:
        """
        Return the number of encoder/processor passes to run.

        In training mode the count is drawn uniformly from
        [min_unrolls, max_unrolls]. When a torch.distributed process group is
        initialized, rank 0's draw is broadcast so every rank uses the same
        count. In eval mode the count is always ``max_unrolls``.
        """
        if not self.training:
            return self.max_unrolls
        n_unrolls = torch.randint(
            self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
        )
        # dist.broadcast requires an initialized default process group; skip
        # the sync in single-process runs instead of crashing.
        if dist.is_available() and dist.is_initialized():
            dist.broadcast(n_unrolls, 0)
        return n_unrolls.item()
@dataclass
class DatasetIterator:
    """
    Endless iterator over random mini-batches taken along dim 1 of ``embs``.

    Each pass independently shuffles dim 1 for every row of dim 0, then
    yields consecutive ``batch_size``-wide slices ``embs[:, start:end]``.
    When the next slice would run past the end, a new shuffled pass starts,
    so a trailing partial slice is never returned. Start iteration with
    ``iter(...)``; it performs the first shuffle and initializes the cursor.
    """

    embs: torch.Tensor  # [batch_size, dataset_size, *shape]
    batch_size: int

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        _outer_batch_size, dataset_size, *_shape = self.embs.shape
        if self.batch_size > dataset_size:
            # No full slice can ever fit; the loop below would reshuffle forever.
            raise ValueError(
                f"batch_size ({self.batch_size}) exceeds dataset size ({dataset_size})"
            )
        while True:
            start = self.idx
            self.idx += self.batch_size
            end = self.idx
            if end <= dataset_size:
                break
            self._reset()
        return self.embs[:, start:end]

    def _reset(self):
        """Reshuffle and move the cursor back to the start."""
        self._shuffle()
        self.idx = 0  # pylint: disable=attribute-defined-outside-init

    def _shuffle(self):
        """Apply an independent random permutation along dim 1 for each row of dim 0."""
        outer_batch_size, dataset_size, *shape = self.embs.shape
        idx = torch.stack(
            [
                torch.randperm(dataset_size, device=self.embs.device)
                for _ in range(outer_batch_size)
            ],
            dim=0,
        )
        # Expand the [outer, dataset] permutation over the trailing dims for gather.
        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))
        idx = torch.broadcast_to(idx, self.embs.shape)
        self.embs = torch.gather(self.embs, 1, idx)
class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
    """
    Transformer channels encoder whose input tokens are a per-point linear
    projection of ``batch.points`` (one token per point).
    """

    def __init__(
        self,
        *,
        input_channels: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        # Maps each point's input_channels features to a width-dim token.
        self.input_proj = nn.Linear(
            input_channels,
            self.width,
            device=self.device,
            dtype=self.dtype,
        )

    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Project the points into token space.

        :param batch: provides ``points`` laid out channels-first (NCL).
        :param options: unused.
        :return: [N, L, width] token tensor.
        """
        del options  # unused; accepted for interface compatibility
        channels_last = batch.points.permute(0, 2, 1)  # NCL -> NLC
        return self.input_proj(channels_last)
class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.

    The initial data tokens come from the point cloud (FPS-subsampled and
    projected, or embedded with a point-set conv). The data the perceiver
    cross-attends to on each unroll is selected by ``cross_attention_dataset``:
    raw points, multiview images, dense-pose multiview images, multiview
    point clouds, or points plus multiview point clouds.
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # point conv hyperparameters
        pointconv_radius: float = 0.5,
        pointconv_samples: int = 32,
        pointconv_hidden: Optional[List[int]] = None,
        pointconv_patch_size: int = 1,
        pointconv_stride: int = 1,
        pointconv_padding_mode: str = "zeros",
        use_pointconv: bool = False,
        # other hyperparameters
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert cross_attention_dataset in (
            "pcl",
            "multiview",
            "dense_pose_multiview",
            "multiview_pcl",
            "pcl_and_multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        )
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb,
            input_channels,
            self.width,
            device=self.device,
            dtype=self.dtype,
        )
        self.use_pointconv = use_pointconv
        if use_pointconv:
            if pointconv_hidden is None:
                pointconv_hidden = [self.width]
            self.point_conv = PointSetEmbedding(
                n_point=self.data_ctx,
                radius=pointconv_radius,
                n_sample=pointconv_samples,
                d_input=self.input_proj.weight.shape[0],
                d_hidden=pointconv_hidden,
                patch_size=pointconv_patch_size,
                stride=pointconv_stride,
                padding_mode=pointconv_padding_mode,
                fps_method=fps_method,
                device=self.device,
                dtype=self.dtype,
            )

        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            pos_ctx = (image_size // patch_size) ** 2
            # One positional embedding per patch of every view in an inner
            # batch. get_multiview_dataset adds it to inner_batch_size[0]
            # flattened views. self.inner_batch_size is always a list (the
            # parent __init__ normalizes it), so index the first entry.
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        pos_ctx * self.inner_batch_size[0],
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input size is for origin+x+y+z+fov
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )
        elif self.cross_attention_dataset == "dense_pose_multiview":
            # The number of output features is halved, because a patch_size of
            # 32 ends up with a large patch_emb weight.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.use_depth = use_depth
            self.max_depth = max_depth
            self.mv_pose_embed = MultiviewPoseEmbedding(
                posemb_version="nerf",
                n_channels=4 if self.use_depth else 3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            # Positional embedding is unnecessary because pose information is baked into each pixel
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
        elif self.cross_attention_dataset in (
            "multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        ):
            # All multiview point-cloud variants share the same embedding modules.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        if self.use_pointconv:
            points = self.input_proj(points).permute(0, 2, 1)  # NLC -> NCL
            xyz = batch.points[:, :3]
            data_tokens = self.point_conv(xyz, points).permute(0, 2, 1)  # NCL -> NLC
        else:
            fps_samples = self.sample_pcl_fps(points)
            data_tokens = self.input_proj(fps_samples)
        batch_size = points.shape[0]
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator. Every value accepted by
        # __init__ must have an entry here.
        # NOTE(review): the "incorrect" variants are assumed to be the same
        # pipelines without the z-depth -> ray-distance correction
        # (use_distance=False) — confirm against training configs.
        dataset_fn = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
            "dense_pose_multiview": self.get_dense_pose_multiview_dataset,
            "pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset,
            "multiview_pcl": self.get_multiview_pcl_dataset,
            "pcl_and_incorrect_multiview_pcl": partial(
                self.get_pcl_and_multiview_pcl_dataset, use_distance=False
            ),
            "incorrect_multiview_pcl": partial(
                self.get_multiview_pcl_dataset, use_distance=False
            ),
        }[self.cross_attention_dataset]
        it = dataset_fn(batch, options=options)

        return h, it

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        """Subsample ``points`` (NLC) to ``data_ctx`` points with this encoder's fps_method."""
        # Resolves to the module-level sample_pcl_fps, not this method.
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """Iterate random ``inner_batch_size``-point chunks of the projected point cloud."""
        _ = options
        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        dataset_emb = self.input_proj(points)
        assert dataset_emb.shape[1] >= inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

    def get_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """
        Iterate random groups of ``inner_batch_size`` encoded views, each
        flattened to [batch_size, inner_batch_size * n_patches, width] with
        the learned patch positional embedding added.
        """
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                # Compare against the int inner_batch_size, not the list attribute.
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def get_dense_pose_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """Like get_multiview_dataset, but over dense-pose view embeddings (no pos_emb)."""
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_dense_pose_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width)
                yield views

        return gen()

    def get_pcl_and_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        use_distance: bool = True,
    ) -> Iterable:
        """
        Yield ``(pcl_chunk, multiview_pcl_chunk)`` tuples, one item per
        cross-attention encoder; requires ``len(self.inner_batch_size) >= 2``.
        """
        _ = options

        pcl_it = self.get_pcl_dataset(
            batch, options=options, inner_batch_size=self.inner_batch_size[0]
        )

        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= self.inner_batch_size[1]

        multiview_pcl_it = iter(
            DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])
        )

        def gen():
            while True:
                pcl = next(pcl_it)
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    self.inner_batch_size[1],
                    n_patches,
                    self.width,
                )
                yield pcl, multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def get_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
        use_distance: bool = True,
    ) -> Iterable:
        """Iterate random groups of encoded multiview point-cloud views, flattened per example."""
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= inner_batch_size

        multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    inner_batch_size,
                    n_patches,
                    self.width,
                )
                yield multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        pose_dropout = self.pose_dropout if self.training else 0.0
        # One mask value per example: camera conditioning is dropped for all
        # of an example's views together.
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        # Scale-and-shift each patch token by the camera features; a zeroed
        # camera_proj leaves views_proj unchanged.
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj

    def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            depths = self.depths_to_tensor(batch.depths)
            all_views = torch.cat([all_views, depths], dim=2)

        dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)
        # [B, V, H, W, 2, 3] -> [B, V, 2, 3, H, W]
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
        position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        all_view_poses = self.mv_pose_embed(all_views, position, direction)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        depths = self.raw_depths_to_tensor(batch.depths)
        all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)
        mask = all_view_alphas >= 0.999

        dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)
        # [B, V, H, W, 2, 3] -> [B, V, 2, 3, H, W]
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)

        origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        if use_distance:
            # Divide by dot(direction, camera_z) so that origin + depth * direction
            # lands at the given depth along the camera z axis (presumably the
            # raw depths are z-depths — confirm against the renderer).
            ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, keepdim=True)
            depths = depths / ray_depth_factor
        position = origin + depths * direction
        all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
        Tensors are returned unchanged.
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        # [B, V, H, W, C] -> [B, V, C, H, W]
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
        Depths are clamped to max_depth before rescaling. Tensors are returned unchanged.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def view_alphas_to_tensor(
        self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].
        Tensors are returned unchanged.
        """
        if isinstance(view_alphas, torch.Tensor):
            return view_alphas

        tensor_batch = []
        num_views = len(view_alphas[0])
        for inner_list in view_alphas:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                tensor = (
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 255.0
                )
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor)
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def raw_depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor of depths
        clamped to max_depth (not rescaled). Tensors are returned unchanged.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth)
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information
        (x, y, z, origin, x_fov per camera). Tensors are returned unchanged.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras

        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()

    def dense_pose_cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of (rays, z_directions) where
            - rays: [batch, num_views, height, width, 2, 3] tensor of camera information.
            - z_directions: [batch, num_views, 3] tensor of camera z directions.

        All cameras must share the first camera's width, height and fov,
        which are read from ``cameras[0][0]``.

        :raises NotImplementedError: if ``cameras`` is already a tensor.
        """
        if isinstance(cameras, torch.Tensor):
            raise NotImplementedError

        for inner_list in cameras:
            assert len(inner_list) == len(cameras[0])

        camera = cameras[0][0]

        def stacked(attr: str) -> torch.Tensor:
            # One camera attribute stacked over every view of every example.
            values = [getattr(cam, attr) for inner_list in cameras for cam in inner_list]
            return torch.from_numpy(np.stack(values, axis=0)).to(self.device)

        flat_camera = DifferentiableProjectiveCamera(
            origin=stacked("origin"),
            x=stacked("x"),
            y=stacked("y"),
            z=stacked("z"),
            width=camera.width,
            height=camera.height,
            x_fov=camera.x_fov,
            y_fov=camera.y_fov,
        )
        batch_size = len(cameras) * len(cameras[0])
        coords = (
            flat_camera.image_coords()
            .to(flat_camera.origin.device)
            .unsqueeze(0)
            .repeat(batch_size, 1, 1)
        )
        rays = flat_camera.camera_rays(coords)
        return (
            rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(
                self.device
            ),
            flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),
        )
def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.

    :param points: batch of shape [N x num_points x C]. For method 'fps' the
        first three channels are treated as xyz coordinates.
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
        points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x C]. When no
        subsampling is needed, ``points`` itself is returned.
    :raises ValueError: for an unsupported method (only checked when
        subsampling is actually needed).
    """
    n_points = points.shape[1]
    if n_points <= data_ctx:
        # Nothing to drop. FPS cannot draw more points than exist, so
        # 'fps' now agrees with 'first' instead of failing in sample_fps.
        return points

    if method == "first":
        return points[:, :data_ctx]
    elif method == "fps":
        # FPS runs per example on the CPU; results go back to the input device.
        batch = points.cpu().split(1, dim=0)
        fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
        return torch.cat(fps, dim=0).to(points.device)
    else:
        raise ValueError(f"unsupported farthest-point sampling method: {method}")
def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    Farthest-point sample a single point cloud.

    :param example: [1, n_points, 3 + n_channels] tensor. The first three
        columns are coordinates; the remaining columns are per-point channels.
    :param n_samples: number of points to keep; must not exceed n_points.
    :return: [1, n_samples, 3 + n_channels] CPU tensor.
    """
    flat = example.cpu().squeeze(0).numpy()
    xyz = flat[:, :3]
    extra = flat[:, 3:]
    total_points, channel_count = extra.shape
    assert n_samples <= total_points

    keys = [str(i) for i in range(channel_count)]
    cloud = PointCloud(coords=xyz, channels={key: extra[:, i] for i, key in enumerate(keys)})
    # Random-subsample to at most 32768 candidates before FPS (presumably
    # to bound FPS cost on very dense clouds).
    candidates = cloud.random_sample(min(32768, total_points))
    sampled = candidates.farthest_point_sample(n_samples)

    sampled_extra = np.stack([sampled.channels[key] for key in keys], axis=1)
    merged = np.concatenate([sampled.coords, sampled_extra], axis=1)
    result = torch.from_numpy(merged).unsqueeze(0)
    assert result.shape == (1, n_samples, 3 + channel_count)
    return result