a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

174 lines
6.1 KiB

2 years ago
from typing import Any, Dict, Optional, Tuple
import torch
from shap_e.models.nn.ops import get_act
from shap_e.models.query import Query
from shap_e.models.stf.mlp import MLPModel
from shap_e.util.collections import AttrDict
class MLPDensitySDFModel(MLPModel):
    """
    MLP with two raw outputs per query that are exposed as a signed distance
    (output channel 0) and a density (output channel 1).
    """

    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        """
        :param initial_bias: value written into element 0 of the last layer's
            bias, i.e. the channel that forward() treats as the raw SDF.
        :param sdf_activation: activation name resolved through get_act and
            applied to the SDF channel.
        :param density_activation: activation name resolved through get_act
            and applied to the density channel.
        :param kwargs: forwarded unchanged to MLPModel.
        """
        # The parent is asked for raw (identity) outputs; the per-channel
        # activations are applied in forward() instead.
        super().__init__(n_output=2, output_activation="identity", **kwargs)

        final_layer = self.mlp[-1]
        final_layer.bias[0].data.fill_(initial_bias)

        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the MLP for a query and activate its two output channels.

        :param query: supplies the position and direction passed to the MLP.
        :param params: forwarded to the underlying MLP call.
        :param options: forwarded to the underlying MLP call.
        :return: an AttrDict with the keys "density" and "signed_distance".
        """
        # query.direction is None typically for SDF models and training
        raw, _unused_directionless = self._mlp(
            query.position,
            query.direction,
            params=params,
            options=options,
        )

        # Width-1 chunks along the last axis: channel 0 is SDF, channel 1 density.
        raw_sdf, raw_density = raw.split(1, dim=-1)

        density = self.density_activation(raw_density)
        signed_distance = self.sdf_activation(raw_sdf)
        return AttrDict(density=density, signed_distance=signed_distance)
class MLPNeRSTFModel(MLPModel):
    """
    MLP whose raw outputs are sliced by name into a signed distance, a
    density and a block of channels, according to the layout returned by
    indices_for_output_mode.
    """

    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        # Defaults to True so that old models can still be loaded; future
        # models should set this to False.
        direction_dependent_shape: bool = True,
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        """
        :param sdf_activation: activation name (resolved via get_act) for the
            "sdf" slice.
        :param density_activation: activation name for the selected density
            slice.
        :param channel_activation: activation name for the selected channel
            slice.
        :param direction_dependent_shape: layout flag; also decides which
            layer receives the initial SDF/density biases.
        :param separate_nerf_channels: layout flag passed to
            indices_for_output_mode.
        :param separate_coarse_channels: layout flag passed to
            indices_for_output_mode.
        :param initial_density_bias: written into bias element 1 of the
            chosen layer.
        :param initial_sdf_bias: written into bias element 0 of the chosen
            layer.
        :param kwargs: forwarded unchanged to MLPModel.
        """
        # The layout must be known first: it fixes the width of the output.
        h_map, h_directionless_map = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        super().__init__(
            n_output=index_mapping_max(h_map),
            output_activation="identity",
            **kwargs,
        )

        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels

        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)

        self.h_map = h_map
        self.h_directionless_map = h_directionless_map

        # Clear the final bias before seeding; when the shape outputs live in
        # the final layer the two fills below overwrite elements 0 and 1.
        self.mlp[-1].bias.data.zero_()

        if self.direction_dependent_shape:
            shape_layer_index = -1
        else:
            shape_layer_index = self.insert_direction_at
        shape_bias = self.mlp[shape_layer_index].bias
        shape_bias[0].data.fill_(initial_sdf_bias)
        shape_bias[1].data.fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the MLP and return the activated outputs for this query.

        :param query: supplies the position and direction passed to the MLP.
        :param params: forwarded to the underlying MLP call.
        :param options: optional settings. nerf_level == "coarse" selects the
            coarse slices, any other value the fine ones; rendering_mode
            (default "stf") selects the nerf slices when equal to "nerf" and
            the "stf" slice otherwise.
        :return: an AttrDict with the keys "density", "signed_distance" and
            "channels".
        """
        opts = AttrDict(options) if options is not None else AttrDict()

        raw, raw_directionless = self._mlp(
            query.position,
            query.direction,
            params=params,
            options=opts,
        )

        # Name every slice of both raw outputs in a single lookup table.
        slices = map_indices_to_keys(self.h_map, raw)
        slices.update(map_indices_to_keys(self.h_directionless_map, raw_directionless))

        use_coarse = opts.nerf_level == "coarse"
        raw_density = slices.density_coarse if use_coarse else slices.density_fine

        if opts.get("rendering_mode", "stf") != "nerf":
            raw_channels = slices.stf
        elif use_coarse:
            raw_channels = slices.nerf_coarse
        else:
            raw_channels = slices.nerf_fine

        density = self.density_activation(raw_density)
        signed_distance = self.sdf_activation(slices.sdf)
        channels = self.channel_activation(raw_channels)
        return AttrDict(
            density=density,
            signed_distance=signed_distance,
            channels=channels,
        )
# Maps an output name to a (start, end) pair of channel indices; the helpers
# in this file use a pair to slice a tensor as data[..., start:end].
IndexMapping = AttrDict[str, Tuple[int, int]]
def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Build the named channel layouts for the two MLP outputs
    (h, h_directionless).

    Every entry is a (start, end) index pair. The shape entries are "sdf",
    "density_coarse" and "density_fine"; the channel entries are "stf",
    "nerf_coarse" and "nerf_fine", each three indices wide.

    :param direction_dependent_shape: if True the shape entries are placed at
        the front of the h layout and the channel entries follow them;
        otherwise the shape entries go into the h_directionless layout and
        the channel entries start at index 0 of h.
    :param separate_nerf_channels: give the nerf entries their own range
        after "stf" instead of sharing the "stf" range.
    :param separate_coarse_channels: give coarse and fine entries distinct
        ranges; asserted to be combined with separate_nerf_channels.
    :return: the tuple (h_map, h_directionless_map).
    """
    # Distinct coarse/fine ranges are only laid out with separate nerf ranges.
    if separate_coarse_channels:
        assert separate_nerf_channels

    h_map = AttrDict()
    h_directionless_map = AttrDict()

    # Shape entries always occupy the front of whichever layout holds them.
    shape_map = h_map if direction_dependent_shape else h_directionless_map
    shape_map.sdf = (0, 1)
    shape_map.density_coarse = (1, 2)
    shape_map.density_fine = (2, 3) if separate_coarse_channels else (1, 2)

    # First index of the channel entries inside h: directly after the shape
    # entries when they share h, otherwise at the very start.
    if not direction_dependent_shape:
        base = 0
    elif separate_coarse_channels:
        base = 3
    else:
        base = 2

    stf = (base, base + 3)
    if separate_coarse_channels:
        nerf_coarse = (base + 3, base + 6)
        nerf_fine = (base + 6, base + 9)
    elif separate_nerf_channels:
        nerf_coarse = nerf_fine = (base + 3, base + 6)
    else:
        nerf_coarse = nerf_fine = stf

    h_map.stf = stf
    h_map.nerf_coarse = nerf_coarse
    h_map.nerf_fine = nerf_fine
    return h_map, h_directionless_map
def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    """
    Slice the last axis of data once per mapping entry.

    :param mapping: names paired with (start, end) index bounds.
    :param data: tensor indexed as data[..., start:end] for each entry.
    :return: an AttrDict holding one slice per name in the mapping.
    """
    named_slices = {}
    for name, bounds in mapping.items():
        lower, upper = bounds
        named_slices[name] = data[..., lower:upper]
    return AttrDict(named_slices)
def index_mapping_max(mapping: IndexMapping) -> int:
    """
    Return the largest end index found among the (start, end) pairs of the
    mapping's values.

    :param mapping: names paired with (start, end) index bounds.
    :return: the maximum end bound.
    """
    end_bounds = []
    for _name, (_start, end) in mapping.items():
        end_bounds.append(end)
    return max(end_bounds)