from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict


class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: An AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """


class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Implements the default empty space model where all queries are rendered as
    background.
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        background = nn.Parameter(
            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
            / channel_scale
        )
        if trainable:
            self.register_parameter("background", background)
        else:
            self.register_buffer("background", background)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        default_bg = self.background[None]
        background = options.get("background", default_bg) if options is not None else default_bg

        shape = query.position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(
            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
        )

        return background


class MLPNeRFModel(MetaModule, NeRFModel):
    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters

        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )

        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query):
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()

        h_position = self.encode_position(query)

        if self.meta_parameters:
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()
        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()
        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        if options.return_h_density:
            res.h_density = h_density

        return res


def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    :param sh_degree: Spherical harmonics degree
    :param coords_shape: [*shape, 3]
    :param coords: optional coordinate tensor of coords_shape
    """
    if coords is None:
        return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)
    return spherical_harmonics_basis(coords, sh_degree)
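

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the upstream module). It builds an
# MLPNeRFModel on CPU and queries it at random positions and view directions.
# It assumes `Query` accepts `position` and optional `direction` tensors of
# shape [batch_size, ..., 3]; adjust if the actual Query signature differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MLPNeRFModel(device=torch.device("cpu"))
    query = Query(
        position=torch.randn(2, 128, 3),
        direction=torch.randn(2, 128, 3),
    )
    out = model(
        query,
        options=AttrDict(checkpoint_nerf_mlp=False, return_h_density=False),
    )
    # Expected shapes: density [2, 128, 1], channels [2, 128, 3].
    print(out.density.shape, out.channels.shape)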