You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
8.8 KiB
241 lines
8.8 KiB
2 years ago
|
from abc import abstractmethod
|
||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from shap_e.models.nn.camera import (
|
||
|
DifferentiableCamera,
|
||
|
DifferentiableProjectiveCamera,
|
||
|
get_image_coords,
|
||
|
projective_camera_frame,
|
||
|
)
|
||
|
from shap_e.models.nn.meta import MetaModule
|
||
|
from shap_e.util.collections import AttrDict
|
||
|
|
||
|
|
||
|
class Renderer(MetaModule):
    """
    Abstract interface for modules that turn camera/view descriptions into
    rendered outputs by invoking the appropriate models. Those models are
    constructed elsewhere and registered in this module.
    """

    @abstractmethod
    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Produce a rendering of one or more views that gradients can flow
        through.

        :param batch: view description. Optional sizing entries:

            - height: Optional[int]
            - width: Optional[int]
            - inner_batch_size or ray_batch_size: Optional[int], number of
              rays rendered per chunk (defaults to 4096 rays)

            The views themselves are given in one of two ways. Either as
            poses that use a default up direction:

            - poses: [batch_size x *shape x 2 x 3], where
              poses[:, ..., 0, :] holds the camera positions and
              poses[:, ..., 1, :] holds the z-axis (pointing toward the
              object) of the camera frame.
            - camera: DifferentiableCamera. For simplicity the same camera
              position is assumed across the batch; batched cameras could
              be supported eventually.

            or as a batch of arbitrary poses:

            - cameras: DifferentiableCameraBatch of shape
              [batch_size x *shape].

        :param params: Meta parameters
        :param options: Optional[Dict]
        """
|
||
|
|
||
|
|
||
|
class RayRenderer(Renderer):
    """
    A Renderer whose views are assembled from individually rendered rays.
    Subclasses implement `render_rays`; `render_views` and `forward` are
    built on top of it. As with Renderer, the models are instantiated
    outside but registered in this module.
    """

    @abstractmethod
    def render_rays(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Render a batch of rays.

        :param batch: has
            - rays: [batch_size x ... x 2 x 3], the origin and direction of
              each ray.
            - radii (optional): [batch_size x ... x 1], the "thickness" of
              each ray.
        :param params: Meta parameters
        :param options: Optional[Dict]
        """

    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
        device: torch.device = torch.device("cuda"),
    ) -> AttrDict:
        """
        Render whole views by delegating to `render_views_from_rays` with
        this object's `render_rays`.

        NOTE(review): the `device` argument is accepted but not consulted;
        `self.device` is what gets forwarded.
        """
        return render_views_from_rays(
            self.render_rays,
            batch,
            params=params,
            options=options,
            device=self.device,
        )

    def forward(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Dispatch to ray rendering or view rendering depending on the keys
        present in `batch`.

        :param batch: must contain either

            - rays: [batch_size x ... x 2 x 3], the origin and direction of
              each ray.

            or

            - poses: [batch_size x 2 x 3], where poses[:, 0] are the camera
              positions and poses[:, 1] are the z-axis (toward the object)
              of the camera frame.
            - camera: an instance of Camera that implements camera_rays

            or

            - cameras: DifferentiableCameraBatch of shape
              [batch_size x *shape].

            With either of the two view-based forms, these may also be
            given:

            - height: Optional[int]
            - width: Optional[int]
            - ray_batch_size or inner_batch_size: Optional[int], defaults
              to 4096 rays

        :param params: a dictionary of optional meta parameters.
        :param options: A Dict of other hyperparameters that could be
            related to rendering or debugging

        :return: a dictionary containing

            - channels: [batch_size, *shape, n_channels]
            - distances: [batch_size, *shape, 1]
            - transmittance: [batch_size, *shape, 1]
            - aux_losses: Dict[str, torch.Tensor]

        :raises NotImplementedError: if none of "rays", "poses" or
            "cameras" is present in `batch`.
        """
        if "rays" in batch:
            # Explicit rays are mutually exclusive with view-style keys.
            view_only_keys = ("poses", "camera", "height", "width")
            for name in view_only_keys:
                assert name not in batch
            return self.render_rays(batch, params=params, options=options)

        uses_poses = "poses" in batch
        if not (uses_poses or "cameras" in batch):
            raise NotImplementedError

        assert "rays" not in batch
        if uses_poses:
            # Poses need a shared camera to build the frames from.
            assert "camera" in batch
        else:
            # A camera batch already carries everything; no extra camera.
            assert "camera" not in batch
        return self.render_views(batch, params=params, options=options)
|
||
|
|
||
|
|
||
|
def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, List[int]]:
    """
    Extract a flattened camera plus batch shape information from `batch`.

    Exactly one of two specifications is expected:

    - "poses" (with "camera"): a tensor of shape
      [batch_size x *inner_shape x 2 x 3]. It is flattened to
      [batch_size * prod(inner_shape) x 2 x 3], split into position
      (index 0) and direction (index 1), and handed to
      `projective_camera_frame` together with `batch.camera`.
    - "cameras": an object exposing `.shape` ([batch_size, *inner_shape])
      and `.flat_camera`, which is used as the camera directly.

    If both "height" and "width" are present the camera is resized via
    `camera.resize_image(width, height)`; if only one of them is present
    no resizing happens.

    :param batch: mapping with attribute access holding the keys above.
    :return: (camera, batch_size, inner_shape), where inner_shape is the
        list of dimensions between the batch dimension and the pose/camera
        payload (possibly empty).
    :raises ValueError: if neither "poses" nor "cameras" is in `batch`.
    :raises AssertionError: if both specifications are mixed, or the poses
        tensor does not end in dimensions (2, 3).
    """
    if "poses" in batch:
        assert "cameras" not in batch
        batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape
        assert n_vecs == 2 and spatial_dim == 3
        inner_batch_size = int(np.prod(inner_shape))
        # reshape rather than view: identical for contiguous input, but also
        # accepts non-contiguous pose tensors (e.g. after a permute/slice).
        poses = batch.poses.reshape(batch_size * inner_batch_size, 2, 3)
        position, direction = poses[:, 0], poses[:, 1]
        camera = projective_camera_frame(position, direction, batch.camera)
    elif "cameras" in batch:
        assert "camera" not in batch
        batch_size, *inner_shape = batch.cameras.shape
        camera = batch.cameras.flat_camera
    else:
        raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}')
    if "height" in batch and "width" in batch:
        camera = camera.resize_image(batch.width, batch.height)
    return camera, batch_size, inner_shape
|
||
|
|
||
|
|
||
|
def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]):
    """
    Return a new list made of the entries of `val_list` followed by `output`.

    A `val_list` of None is treated as an empty list, so the result is
    `[output]`. The list passed in is never mutated.

    :param val_list: values accumulated so far, or None if there are none.
    :param output: the value to place at the end.
    :return: a freshly built list ending in `output`.
    """
    accumulated = [] if val_list is None else val_list
    return accumulated + [output]
|
||
|
|
||
|
|
||
|
def render_views_from_rays(
    render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict],
    batch: AttrDict,
    params: Optional[Dict] = None,
    options: Optional[Dict] = None,
    device: torch.device = torch.device("cuda"),
) -> AttrDict:
    """
    Render full views by casting one ray per pixel and rendering the rays in
    fixed-size chunks.

    :param render_rays: called as
        `render_rays(rays_batch, params=params, options=options)` where
        `rays_batch` has "rays" ([batch_size x chunk x 2 x 3]) and "radii"
        ([batch_size x chunk x 1]). For projective cameras its result must
        expose `distances`, which is used to derive `depth`.
    :param batch: camera specification understood by
        `get_camera_from_batch`. May also hold "ray_batch_size" (or, as a
        fallback, "inner_batch_size") giving the chunk size; defaults to
        4096. The number of rays per batch element
        (prod(inner_shape) * height * width) must be divisible by it.
    :param params: forwarded unchanged to `render_rays`.
    :param options: forwarded unchanged to `render_rays`.
    :param device: device the pixel coordinates are moved to before rays
        are generated.
    :return: an AttrDict in which every non-"aux_losses" entry is the
        concatenation of the per-chunk results, viewed as
        [batch_size, *inner_shape, height, width, -1], and "aux_losses"
        holds each loss summed over chunks and divided by the chunk count.
    :raises AssertionError: if the ray count is not divisible by the chunk
        size.
    """
    camera, batch_size, inner_shape = get_camera_from_batch(batch)
    inner_batch_size = int(np.prod(inner_shape))
    rays_per_item = inner_batch_size * camera.height * camera.width

    coords = get_image_coords(camera.width, camera.height).to(device)
    coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
    rays = camera.camera_rays(coords)

    # mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200
    directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[
        ..., 1, :
    ]
    # Distance between the directions of horizontally adjacent pixels; this
    # yields width - 1 values per row.
    neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1)
    # Pad back to `width` columns. The reference implementation repeats the
    # second-to-last column; when only a single column of distances exists
    # (width == 2) that slice would be empty, so repeat the last one instead.
    if neighbor_dists.shape[3] > 1:
        pad_column = neighbor_dists[:, :, :, -2:-1]
    else:
        pad_column = neighbor_dists[:, :, :, -1:]
    neighbor_dists = torch.cat([neighbor_dists, pad_column], dim=3)
    radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1)

    rays = rays.view(batch_size, rays_per_item, 2, 3)

    is_projective = isinstance(camera, DifferentiableProjectiveCamera)
    if is_projective:
        # Compute the camera z direction corresponding to every ray's pixel.
        # Used for depth computations below.
        # The leading dimension must be batch_size (not 1): the tensor holds
        # batch_size * rays_per_item vectors, so any other value makes the
        # final reshape fail as soon as batch_size > 1.
        z_directions = (
            (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True))
            .reshape([batch_size, inner_batch_size, 1, 3])
            .repeat(1, 1, camera.width * camera.height, 1)
            .reshape(batch_size, rays_per_item, 3)
        )

    ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096))
    assert rays.shape[1] % ray_batch_size == 0, (
        f"number of rays ({rays.shape[1]}) must be divisible by "
        f"ray_batch_size ({ray_batch_size})"
    )
    n_batches = rays.shape[1] // ray_batch_size

    output_list = AttrDict(aux_losses=dict())

    for idx in range(n_batches):
        chunk = slice(idx * ray_batch_size, (idx + 1) * ray_batch_size)
        rays_batch = AttrDict(
            rays=rays[:, chunk],
            radii=radii[:, chunk],
        )
        output = render_rays(rays_batch, params=params, options=options)

        if is_projective:
            # Depth is the ray distance projected onto the camera z-axis.
            z_batch = z_directions[:, chunk]
            ray_directions = rays_batch.rays[:, :, 1]
            z_dots = (ray_directions * z_batch).sum(-1, keepdim=True)
            output.depth = output.distances * z_dots

        output_list = output_list.combine(output, append_tensor)

    def _resize(val_list: List[torch.Tensor]):
        # Stitch the per-chunk results back together along the ray axis and
        # restore the image layout.
        val = torch.cat(val_list, dim=1)
        assert val.shape[1] == rays_per_item
        return val.view(batch_size, *inner_shape, camera.height, camera.width, -1)

    def _avg(_key: str, loss_list: List[torch.Tensor]):
        return sum(loss_list) / n_batches

    output = AttrDict(
        {name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"}
    )
    output.aux_losses = output_list.aux_losses.map(_avg)

    return output
|