from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import torch

from shap_e.models.nn.utils import sample_pmf
from shap_e.models.volume import Volume, VolumeRange
from shap_e.util.collections import AttrDict

from .model import NeRFModel, Query


def render_rays(
    rays: torch.Tensor,
    parts: List["RayVolumeIntegral"],
    void_model: NeRFModel,
    shared: bool = False,
    prev_raw_outputs: Optional[List[AttrDict]] = None,
    render_with_direction: bool = True,
    importance_sampling_options: Optional[Dict[str, Any]] = None,
) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]:
    """
    Perform volumetric rendering over a partition of possible t's in the union
    of rendering volumes (written below with some abuse of notations)

        C(r) := sum(
            transmittance(t[i]) * integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t[i], t[i + 1]],
            )
            for i in range(len(parts))
        ) + transmittance(t[-1]) * void_model(t[-1]).channels

    where

    1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the
       probability of light passing through the volume specified by [t[0], s].
       (transmittance of 1 means light can pass freely)
    2) density and channels are obtained by evaluating the appropriate
       part.model at time t.
    3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
       (parts[i].volume \\ union(part.volume for part in parts[:i])) at the
       surface of the shell (if bounded). If the ray does not intersect, the
       integral over this segment is evaluated as 0 and
       transmittance(t[i + 1]) := transmittance(t[i]).
    4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that
       is evaluated by the void_model (i.e. we consider this space to be
       empty).

    :param rays: [batch_size x ... x 2 x 3] origin and direction.
    :param parts: disjoint volume integrals.
    :param void_model: use this model to integrate over the empty space.
    :param shared: if True, all RayVolumeIntegrals are calculated with the
        same model.
    :param prev_raw_outputs: raw outputs from the previous rendering step.
    :param render_with_direction: use the incoming ray direction when querying
        the model.
    :param importance_sampling_options: extra kwargs forwarded to
        ImportanceRaySampler.
    :return: A tuple of
        - AttrDict containing the rendered `channels`, `distances`, and the
          `aux_losses`
        - A list of importance samplers for additional fine-grained rendering
        - A list of raw outputs for each interval
    """
    if importance_sampling_options is None:
        importance_sampling_options = {}

    origin, direc = rays[..., 0, :], rays[..., 1, :]

    if prev_raw_outputs is None:
        prev_raw_outputs = [None] * len(parts)

    samplers = []
    raw_outputs = []
    t0 = None
    results = None

    for part_i, prev_raw_i in zip(parts, prev_raw_outputs):
        # Integrate over [t[i], t[i + 1]]
        results_i = part_i.render_rays(
            origin,
            direc,
            t0=t0,
            prev_raw=prev_raw_i,
            shared=shared,
            render_with_direction=render_with_direction,
        )

        # Create an importance sampler for (optional) fine rendering
        samplers.append(
            ImportanceRaySampler(
                results_i.volume_range, results_i.raw, **importance_sampling_options
            )
        )
        raw_outputs.append(results_i.raw)

        # Pass t[i + 1] as the start of integration for the next interval.
        t0 = results_i.volume_range.next_t0()

        # Combine the results from [t[0], t[i]] and [t[i], t[i + 1]]
        results = results_i if results is None else results.combine(results_i)

    # While integrating out [t[-1], math.inf] is the correct thing to do, this
    # erases a lot of useful information. Also, void_model is meant to predict
    # the channels at t=math.inf.
    #
    # Add the void background over [t[-1], math.inf] to complete integration.
    # results = results.combine(
    #     RayVolumeIntegralResults(
    #         output=AttrDict(
    #             channels=void_model(origin, direc),
    #             distances=torch.zeros_like(t0),
    #             aux_losses=AttrDict(),
    #         ),
    #         volume_range=VolumeRange(
    #             t0=t0,
    #             t1=torch.full_like(t0, math.inf),
    #             intersected=torch.full_like(results.volume_range.intersected, True),
    #         ),
    #         # Void space extends to infinity. It is assumed that no light
    #         # passes beyond the void.
    #         transmittance=torch.zeros_like(results_i.transmittance),
    #     )
    # )
    results.output.channels = results.output.channels + results.transmittance * void_model(
        Query(origin, direc)
    )

    return results, samplers, raw_outputs
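
# A minimal usage sketch of the coarse-to-fine loop above (illustrative only,
# not part of the module API; `model`, `volume`, `void`, and `rays` are
# assumed to be a pre-built NeRFModel, Volume, background NeRFModel, and a
# [batch_size, ..., 2, 3] ray tensor):
#
#     coarse = RayVolumeIntegral(
#         model=model, volume=volume, sampler=StratifiedRaySampler(), n_samples=64
#     )
#     results, samplers, raw = render_rays(rays, [coarse], void)
#
#     # Re-render with the returned importance sampler; shared=True lets the
#     # fine pass reuse the coarse pass's raw outputs for the same model.
#     fine = RayVolumeIntegral(
#         model=model, volume=volume, sampler=samplers[0], n_samples=128
#     )
#     fine_results, _, _ = render_rays(
#         rays, [fine], void, shared=True, prev_raw_outputs=raw
#     )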


@dataclass
class RayVolumeIntegralResults:
    """
    Stores the relevant state and results of

        integrate(
            lambda t: density(t) * channels(t) * transmittance(t),
            [t0, t1],
        )
    """

    # Rendered output and auxiliary losses
    # output.channels has shape [batch_size, *inner_shape, n_channels]
    output: AttrDict

    """
    Optional values
    """

    # Raw values contain the sampled `ts`, `density`, `channels`, etc.
    raw: Optional[AttrDict] = None

    # Integration
    volume_range: Optional[VolumeRange] = None

    # If a ray intersects, the transmittance from t0 to t1 (e.g. the
    # probability that the ray passes through this volume).
    # has shape [batch_size, *inner_shape, 1]
    transmittance: Optional[torch.Tensor] = None

    def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
        """
        Combines the integration results of `self` over [t0, t1] and
        `cur` over [t1, t2] to produce a new set of results over [t0, t2]
        by using a similar equation to (4) in NeRF++:

            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t2],
            )
          = integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t1],
            ) + transmittance(t1) * integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t1, t2],
            )
        """
        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)

        def _combine_fn(
            prev_val: Optional[torch.Tensor],
            cur_val: Optional[torch.Tensor],
            *,
            prev_transmittance: torch.Tensor,
        ):
            assert prev_val is not None
            if cur_val is None:
                # cur_output.aux_losses are empty for the void_model.
                return prev_val
            return prev_val + prev_transmittance * cur_val

        output = self.output.combine(
            cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)
        )

        combined = RayVolumeIntegralResults(
            output=output,
            volume_range=self.volume_range.extend(cur.volume_range),
            transmittance=self.transmittance * cur.transmittance,
        )
        return combined
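
# Worked example of the composition rule in `combine` (a sketch): if the first
# volume renders color C_a with transmittance T_a = 0.6 and the second renders
# C_b with T_b = 0.5, the combined result is
#
#     channels      = C_a + 0.6 * C_b   (the second volume only receives the
#                                        light that survived the first)
#     transmittance = 0.6 * 0.5 = 0.3
#
# which matches the piecewise integration identity from the docstring above.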


@dataclass
class RayVolumeIntegral:
    model: NeRFModel
    volume: Volume
    sampler: "RaySampler"
    n_samples: int

    def render_rays(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0: Optional[torch.Tensor] = None,
        prev_raw: Optional[AttrDict] = None,
        shared: bool = False,
        render_with_direction: bool = True,
    ) -> "RayVolumeIntegralResults":
        """
        Perform volumetric rendering over the given volume.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0: Optional [batch_size, *shape, 1]
        :param prev_raw: the raw outputs when using multiple levels with this
            model.
        :param shared: if True, the same model is used for all
            RayVolumeIntegral's.
        :param render_with_direction: use the incoming ray direction when
            querying the model.

        :return: RayVolumeIntegralResults
        """
        # 1. Intersect the rays with the current volume and sample ts to
        #    integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=t0)
        ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)

        if prev_raw is not None and not shared:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values

        # Shape sanity checks
        batch_size, *_shape, _t0_dim = vrange.t0.shape
        _, *ts_shape, _ts_dim = ts.shape

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        optional_directions = directions if render_with_direction else None

        mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2
        raw = self.model(
            Query(
                position=positions,
                direction=optional_directions,
                t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),
                t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),
            )
        )
        raw.ts = ts

        if prev_raw is not None and shared:
            # We can append the additional queries to previous raw outputs
            # before integration.
            copy = prev_raw.copy()
            result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2)
            merge_results = partial(self._merge_results, dim=-2, indices=result.indices)
            raw = raw.combine(copy, merge_results)
            raw.ts = result.values

        # 3. Integrate the raw results
        output, transmittance = self.integrate_samples(vrange, raw)

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(
            vrange.intersected, transmittance, torch.ones_like(transmittance)
        )

        def _mask_fn(_key: str, tensor: torch.Tensor):
            return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        output = output.map(map_fn=_mask_fn, should_map=_is_tensor)

        return RayVolumeIntegralResults(
            output=output,
            raw=raw,
            volume_range=vrange,
            transmittance=transmittance,
        )

    def integrate_samples(
        self,
        volume_range: VolumeRange,
        raw: AttrDict,
    ) -> Tuple[AttrDict, torch.Tensor]:
        """
        Integrate the raw.channels along with other aux_losses and values to
        produce the final output dictionary containing rendered `channels`,
        estimated `distances`, and `aux_losses`.

        :param volume_range: specifies the integral range [t0, t1]
        :param raw: a dict of function evaluations at ts. Should have

            density: torch.Tensor [batch_size, *shape, n_samples, 1]
            channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
            aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1]
                for each key}
            no_weight_grad_aux_losses: an optional set of losses for which the
                weights should be detached before integration.

            After the call, integrate_samples populates some intermediate
            calculations for later use, like

            weights: torch.Tensor [batch_size, *shape, n_samples, 1]
                (density * transmittance)[i] weight for each rgb output at
                [..., i, :].
        :returns: a tuple of (
            a dictionary of rendered outputs and aux_losses,
            transmittance of this volume,
        )
        """
        # 1. Calculate the weights
        _, _, dt = volume_range.partition(raw.ts)
        ddensity = raw.density * dt

        mass = torch.cumsum(ddensity, dim=-2)
        transmittance = torch.exp(-mass[..., -1, :])

        alphas = 1.0 - torch.exp(-ddensity)
        Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
        # This is the probability of light hitting and reflecting off of
        # something at depth [..., i, :].
        weights = alphas * Ts

        # 2. Integrate all results
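
# The quadrature in `integrate_samples` above is the standard NeRF
# discretization (a sketch in equation form, matching the code):
#
#     delta_i = t_{i+1} - t_i
#     alpha_i = 1 - exp(-density_i * delta_i)          # opacity of bin i
#     T_i     = exp(-sum_{j < i} density_j * delta_j)  # light reaching bin i
#     w_i     = alpha_i * T_i                          # contribution of bin i
#
#     channels      = sum_i w_i * channels_i
#     transmittance = exp(-sum_i density_i * delta_i)  # light leaving volume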
        def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
            if key == "density":
                # Omit integrating the density, because we don't need it
                return None
            return torch.sum(samples * weights, dim=-2)

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        if raw.no_weight_grad_aux_losses:
            extra_aux_losses = raw.no_weight_grad_aux_losses.map(
                partial(_integrate, weights=weights.detach()), should_map=_is_tensor
            )
        else:
            extra_aux_losses = {}
        output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
        if "no_weight_grad_aux_losses" in output:
            del output["no_weight_grad_aux_losses"]
        output.aux_losses.update(extra_aux_losses)

        # Integrating the ts yields the distance away from the origin; rename
        # the variable.
        output.distances = output.ts
        del output["ts"]
        del output["density"]

        assert output.distances.shape == (*output.channels.shape[:-1], 1)
        assert output.channels.shape[:-1] == raw.channels.shape[:-2]
        assert output.channels.shape[-1] == raw.channels.shape[-1]

        # 3. Reduce loss
        def _reduce_loss(_key: str, loss: torch.Tensor):
            return loss.view(loss.shape[0], -1).sum(dim=-1)

        # 4. Store other useful calculations
        raw.weights = weights

        output.aux_losses = output.aux_losses.map(_reduce_loss)

        return output, transmittance

    def _merge_results(
        self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor
    ):
        """
        :param a: [..., n_a, ...]. The other dictionary containing the b's may
            contain extra tensors from earlier calculations, so a can be None.
        :param b: [..., n_b, ...]
        :param dim: dimension to merge
        :param indices: how the merged results should be sorted at the end
        :return: a concatted and sorted tensor of size [..., n_a + n_b, ...]
        """
        if a is None:
            return None

        merged = torch.cat([a, b], dim=dim)
        return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))


class RaySampler(ABC):
    @abstractmethod
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """


class StratifiedRaySampler(RaySampler):
    """
    Instead of fixed intervals, a sample is drawn uniformly at random from
    each interval.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: "linear" samples ts linearly in depth, "geometric"
            interpolates in log-depth, and "harmonic" interpolates in
            reciprocal depth, which samples closer points more densely.
        """
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
        t_rand = torch.rand_like(ts)

        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)
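
# Stratification sketch: with n_samples=4 over [t0, t1] = [1, 3] in "linear"
# mode, the deterministic grid is [1.0, 1.667, 2.333, 3.0]; each t is then
# jittered uniformly within its bin (bounded by the neighboring midpoints and
# the endpoints), so samples stay ordered and cover the interval without the
# banding artifacts of a fixed grid.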


class ImportanceRaySampler(RaySampler):
    """
    Given the initial estimate of densities, this samples more from
    regions/bins expected to have objects.
    """

    def __init__(
        self,
        volume_range: VolumeRange,
        raw: AttrDict,
        blur_pool: bool = False,
        alpha: float = 1e-5,
    ):
        """
        :param volume_range: the range in which a ray intersects the given
            volume.
        :param raw: a dictionary of raw outputs from the NeRF models of shape
            [batch_size, *shape, n_coarse_samples, 1]. Should at least contain

            ts: earlier samples from the coarse rendering step
            weights: discretized version of density * transmittance
        :param blur_pool: if true, use the 2-tap max + 2-tap blur filter from
            mip-NeRF.
        :param alpha: small value to add to weights.
        """
        self.volume_range = volume_range
        self.ts = raw.ts.clone().detach()
        self.weights = raw.weights.clone().detach()
        self.blur_pool = blur_pool
        self.alpha = alpha

    @torch.no_grad()
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        lower, upper, _ = self.volume_range.partition(self.ts)

        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        weights = self.weights
        if self.blur_pool:
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        weights = weights + self.alpha
        pmf = weights / weights.sum(dim=-2, keepdim=True)
        inds = sample_pmf(pmf, n_samples)
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        t_rand = torch.rand(inds.shape, device=inds.device)
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        ts = lower_ + (upper_ - lower_) * t_rand
        ts = torch.sort(ts, dim=-2).values
        return ts
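
# Importance-sampling sketch: given coarse weights of roughly
# [0.0, 0.8, 0.1, 0.1] over four bins, the pmf (after adding alpha) keeps
# those proportions, so `sample_pmf` draws most fine samples from bin 1. Each
# drawn index is converted to a t by jittering uniformly inside that bin's
# [lower, upper] range, and the result is re-sorted along the sample axis.
# blur_pool=True first widens peaks with mip-NeRF's 2-tap max / 2-tap blur so
# that the neighbors of a strong peak also receive samples.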