from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from shap_e.models.nn.camera import (
    DifferentiableCamera,
    DifferentiableProjectiveCamera,
    get_image_coords,
    projective_camera_frame,
)
from shap_e.models.nn.meta import MetaModule
from shap_e.util.collections import AttrDict


class Renderer(MetaModule):
    """
    A rendering abstraction that can render rays and views by calling the
    appropriate models. The models are instantiated outside but registered in
    this module.
    """

    @abstractmethod
    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Returns a backproppable rendering of a view

        :param batch: contains
            - height: Optional[int]
            - width: Optional[int]
            - inner_batch_size or ray_batch_size: Optional[int] defaults to 4096 rays

            And additionally, to specify poses with a default up direction:
            - poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera
              positions, and poses[:, ..., 1, :] are the z-axis (toward the object) of
              the camera frame.
            - camera: DifferentiableCamera. Assumes the same camera position
              across batch for simplicity. Could eventually support
              batched cameras.

            or to specify a batch of arbitrary poses:
            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].

        :param params: Meta parameters
        :param options: Optional[Dict]
        """


class RayRenderer(Renderer):
    """
    A rendering abstraction that can render rays and views by calling the
    appropriate models. The models are instantiated outside but registered in
    this module.
    """

    @abstractmethod
    def render_rays(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        :param batch: has
            - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
            - radii (optional): [batch_size x ... x 1] the "thickness" of each ray.
        :param options: Optional[Dict]
        """

    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
        device: torch.device = torch.device("cuda"),
    ) -> AttrDict:
        output = render_views_from_rays(
            self.render_rays,
            batch,
            params=params,
            options=options,
            device=self.device,
        )
        return output

    def forward(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        :param batch: must contain either

            - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.

            or

            - poses: [batch_size x 2 x 3] where poses[:, 0] are the camera
              positions, and poses[:, 1] are the z-axis (toward the object) of
              the camera frame.
            - camera: an instance of Camera that implements camera_rays

            or

            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].

            With either of the two camera-based options above (poses or cameras),
            the following may also be specified:

            - height: Optional[int]
            - width: Optional[int]
            - ray_batch_size or inner_batch_size: Optional[int] defaults to 4096 rays

        :param params: a dictionary of optional meta parameters.
        :param options: a Dict of other hyperparameters that could be
            related to rendering or debugging.

        :return: a dictionary containing
            - channels: [batch_size, *shape, n_channels]
            - distances: [batch_size, *shape, 1]
            - transmittance: [batch_size, *shape, 1]
            - aux_losses: Dict[str, torch.Tensor]
        """
if "rays" in batch:
for key in ["poses", "camera", "height", "width"]:
assert key not in batch
return self.render_rays(batch, params=params, options=options)
elif "poses" in batch or "cameras" in batch:
assert "rays" not in batch
if "poses" in batch:
assert "camera" in batch
else:
assert "camera" not in batch
return self.render_views(batch, params=params, options=options)
raise NotImplementedError
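

# A minimal usage sketch (assuming a hypothetical concrete subclass,
# `MyRayRenderer`, that implements `render_rays`); shapes follow the
# docstring of `RayRenderer.forward`:
#
#     renderer = MyRayRenderer(...)
#     # Option 1: explicit rays of shape [batch_size x n_rays x 2 x 3].
#     out = renderer(AttrDict(rays=rays))
#     # Option 2: camera poses of shape [batch_size x 2 x 3] plus a shared
#     # camera; height/width are optional overrides of the camera resolution.
#     out = renderer(AttrDict(poses=poses, camera=camera, height=64, width=64))
#     out.channels  # [batch_size, *shape, n_channels]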


def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]:
    if "poses" in batch:
        assert "cameras" not in batch
        batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape
        assert n_vecs == 2 and spatial_dim == 3
        inner_batch_size = int(np.prod(inner_shape))
        poses = batch.poses.view(batch_size * inner_batch_size, 2, 3)
        position, direction = poses[:, 0], poses[:, 1]
        camera = projective_camera_frame(position, direction, batch.camera)
    elif "cameras" in batch:
        assert "camera" not in batch
        batch_size, *inner_shape = batch.cameras.shape
        camera = batch.cameras.flat_camera
    else:
        raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}')
    if "height" in batch and "width" in batch:
        camera = camera.resize_image(batch.width, batch.height)
    return camera, batch_size, inner_shape


def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]):
    """Reduction used with AttrDict.combine to accumulate per-chunk outputs into lists."""
    if val_list is None:
        return [output]
    return val_list + [output]


def render_views_from_rays(
    render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict],
    batch: AttrDict,
    params: Optional[Dict] = None,
    options: Optional[Dict] = None,
    device: torch.device = torch.device("cuda"),
) -> AttrDict:
    camera, batch_size, inner_shape = get_camera_from_batch(batch)
    inner_batch_size = int(np.prod(inner_shape))
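
    # The pixel grid is shared by every camera, so one coordinate grid is
    # broadcast across all batch_size * inner_batch_size flattened cameras
    # before being turned into per-pixel (origin, direction) rays.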
    coords = get_image_coords(camera.width, camera.height).to(device)
    coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
    rays = camera.camera_rays(coords)

    # mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200
    directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[
        ..., 1, :
    ]
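    # Each ray's radius comes from the distance between horizontally adjacent
    # pixel ray directions (the last column reuses its neighbor's spacing),
    # scaled by 2 / sqrt(12) following the mip-NeRF pixel-footprint convention.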
    neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1)
    neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3)
    radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1)

    rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3)

    if isinstance(camera, DifferentiableProjectiveCamera):
        # Compute the camera z direction corresponding to every ray's pixel.
        # Used for depth computations below.
        z_directions = (
            (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True))
            .reshape([batch_size, inner_batch_size, 1, 3])
            .repeat(1, 1, camera.width * camera.height, 1)
            .reshape(1, inner_batch_size * camera.height * camera.width, 3)
        )
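
    # Render the rays in fixed-size chunks to bound peak memory; the total ray
    # count per batch element must therefore be divisible by the chunk size.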
    ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096))
    assert rays.shape[1] % ray_batch_size == 0
    n_batches = rays.shape[1] // ray_batch_size

    output_list = AttrDict(aux_losses=dict())

    for idx in range(n_batches):
        rays_batch = AttrDict(
            rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
            radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
        )
        output = render_rays(rays_batch, params=params, options=options)
        if isinstance(camera, DifferentiableProjectiveCamera):
            z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
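            # distances are measured along each ray; the dot product between the
            # ray direction and the normalized camera z axis rescales them into
            # depth along the camera's viewing direction.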
            ray_directions = rays_batch.rays[:, :, 1]
            z_dots = (ray_directions * z_batch).sum(-1, keepdim=True)
            output.depth = output.distances * z_dots

        output_list = output_list.combine(output, append_tensor)
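
    # Stitch the per-chunk outputs back together: concatenate each field along
    # the ray axis and reshape it into an image grid, while averaging the
    # auxiliary losses over chunks.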
    def _resize(val_list: List[torch.Tensor]):
        val = torch.cat(val_list, dim=1)
        assert val.shape[1] == inner_batch_size * camera.height * camera.width
        return val.view(batch_size, *inner_shape, camera.height, camera.width, -1)

    def _avg(_key: str, loss_list: List[torch.Tensor]):
        return sum(loss_list) / n_batches

    output = AttrDict(
        {name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"}
    )
    output.aux_losses = output_list.aux_losses.map(_avg)

    return output