from functools import partial
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict

from .base import Model


class MLPModel(MetaModule, Model):
    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction-related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        if register_freqs:
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding: run a dummy input through the encoders to infer
        # the encoded feature widths for positions and (optionally) directions.
        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well, in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            input_widths[self.insert_direction_at] += d_posenc_dir

        linear_cls = lambda meta: (
            partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            if meta
            else nn.Linear
        )

        # Resolve the (inclusive) range of layers that receive meta parameters.
        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the query; `query.position` has shape [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        """
        # query.direction is typically None for SDF models and during training.
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)
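    # The two helpers below split the forward pass: `_mlp` wraps `_run_mlp`
    # with gradient checkpointing, while `_run_mlp` does the actual compute.
    # For layers instantiated as MetaLinear, the weights are looked up in
    # `params` under the "mlp.{i}" prefix via `subdict`, so an external
    # hypernetwork can supply per-sample weights for those layers; plain
    # nn.Linear layers fall back to their own trained weights.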
    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                # Save the directionless activations, then append the encoded
                # view direction to this layer's input.
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)

        # Collect every tensor the checkpointed function depends on, so that
        # gradient checkpointing can route gradients through them.
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(list(subdict(params, f"mlp.{i}").values()))
            else:
                parameters.extend(layer.parameters())

        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )
        return h_final, h_directionless


class MLPSDFModel(MLPModel):
    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        # Bias the final layer so untrained SDF predictions start out slightly
        # negative.
        self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        signed_distance = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=signed_distance)


class MLPTextureFieldModel(MLPModel):
    def __init__(
        self,
        n_channels: int = 3,
        **kwargs,
    ):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        channels = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=channels)
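
if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # Query, as defined in shap_e.models.query, is a container exposing a
    # `position` tensor of shape [batch_size x ... x 3] and an optional
    # `direction`, and that the default params/options code paths work without
    # explicit meta parameters.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    query = Query(position=torch.randn(2, 128, 3, device=device))

    # SDF field: one signed-distance value per query point.
    sdf = MLPSDFModel(device=device)
    print(sdf(query).signed_distance.shape)  # expected: torch.Size([2, 128, 1])

    # Texture field: RGB channels in [0, 1] via the sigmoid output activation.
    tex = MLPTextureFieldModel(device=device)
    print(tex(query).channels.shape)  # expected: torch.Size([2, 128, 3])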