You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
174 lines
6.1 KiB
174 lines
6.1 KiB
2 years ago
|
from typing import Any, Dict, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from shap_e.models.nn.ops import get_act
|
||
|
from shap_e.models.query import Query
|
||
|
from shap_e.models.stf.mlp import MLPModel
|
||
|
from shap_e.util.collections import AttrDict
|
||
|
|
||
|
|
||
|
class MLPDensitySDFModel(MLPModel):
    """
    MLP field with exactly two raw outputs: a signed distance and a density.

    The base ``MLPModel`` is built with ``n_output=2`` and an identity output
    activation; the per-quantity activations are applied in ``forward``.
    """

    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        super().__init__(n_output=2, output_activation="identity", **kwargs)
        # Output column 0 is the SDF (see forward's split), so the first bias
        # entry of the last element of self.mlp sets its starting value.
        last_bias = self.mlp[-1].bias
        last_bias[0].data.fill_(initial_bias)
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the field at ``query.position``.

        :return: an AttrDict with ``density`` and ``signed_distance``, each the
                 activated version of one raw output column.
        """
        # query.direction is None typically for SDF models and training
        raw, _ = self._mlp(query.position, query.direction, params=params, options=options)
        raw_sdf, raw_density = raw.split(1, dim=-1)
        return AttrDict(
            density=self.density_activation(raw_density),
            signed_distance=self.sdf_activation(raw_sdf),
        )
|
||
|
|
||
|
|
||
|
class MLPNeRSTFModel(MLPModel):
    """
    MLP field producing a signed distance, densities and color channels usable
    for both STF and NeRF rendering.

    Which output columns hold which quantity is decided by
    ``indices_for_output_mode``. The result is stored as ``h_map`` (columns of
    the MLP's ``h`` output) and ``h_directionless_map`` (columns of its
    ``h_directionless`` output).
    """

    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        direction_dependent_shape: bool = True,  # To be able to load old models. Set this to be False in future models.
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        mapping, directionless_mapping = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        # The MLP must be wide enough to hold the largest column range in h.
        super().__init__(
            n_output=index_mapping_max(mapping),
            output_activation="identity",
            **kwargs,
        )
        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)
        self.h_map = mapping
        self.h_directionless_map = directionless_mapping

        self.mlp[-1].bias.data.zero_()
        # Both mappings place sdf in column 0 and density in column 1. In the
        # directionless mode those columns come from self.mlp[insert_direction_at]
        # (presumably the layer emitting h_directionless — confirm in MLPModel).
        if direction_dependent_shape:
            bias_layer = self.mlp[-1]
        else:
            bias_layer = self.mlp[self.insert_direction_at]
        bias_layer.bias[0].data.fill_(initial_sdf_bias)
        bias_layer.bias[1].data.fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the field at ``query``.

        ``options.nerf_level == "coarse"`` selects the coarse density (and the
        coarse NeRF colors); anything else selects the fine ones. Colors come
        from the NeRF columns when ``options.rendering_mode == "nerf"`` and from
        the STF columns otherwise (the default is ``"stf"``).

        :return: an AttrDict with activated ``density``, ``signed_distance`` and
                 ``channels``.
        """
        options = AttrDict() if options is None else AttrDict(options)
        h, h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        named = map_indices_to_keys(self.h_map, h)
        named.update(map_indices_to_keys(self.h_directionless_map, h_directionless))

        coarse = options.nerf_level == "coarse"
        raw_density = named.density_coarse if coarse else named.density_fine

        if options.get("rendering_mode", "stf") == "nerf":
            raw_channels = named.nerf_coarse if coarse else named.nerf_fine
        else:
            raw_channels = named.stf

        return AttrDict(
            density=self.density_activation(raw_density),
            signed_distance=self.sdf_activation(named.sdf),
            channels=self.channel_activation(raw_channels),
        )
|
||
|
|
||
|
|
||
|
# Maps an output name (e.g. "sdf", "density_coarse") to a (start, end) pair;
# map_indices_to_keys slices data[..., start:end] with each pair, so the
# ranges are half-open column ranges of an MLP output tensor.
IndexMapping = AttrDict[str, Tuple[int, int]]
|
||
|
|
||
|
|
||
|
def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Get output mappings for (h, h_directionless).

    Each mapping assigns output names (``sdf``, ``density_coarse``,
    ``density_fine``, ``stf``, ``nerf_coarse``, ``nerf_fine``) to half-open
    ``(start, end)`` column ranges of the corresponding MLP output.

    :param direction_dependent_shape: if True, sdf and densities are placed in
        ``h`` and ``h_directionless`` gets no entries. If False, sdf and
        densities are placed in ``h_directionless`` and ``h`` holds only color
        columns.
    :param separate_nerf_channels: give the NeRF colors their own columns
        instead of sharing the STF columns.
    :param separate_coarse_channels: give coarse and fine densities and NeRF
        colors separate columns. Requires ``separate_nerf_channels``.
    :return: the tuple ``(h_map, h_directionless_map)``.
    :raises ValueError: if ``separate_coarse_channels`` is set without
        ``separate_nerf_channels``.
    """
    # An explicit raise (rather than ``assert``) keeps this check active under
    # ``python -O``, so an inconsistent flag combination is always rejected.
    if separate_coarse_channels and not separate_nerf_channels:
        raise ValueError(
            "separate_coarse_channels=True requires separate_nerf_channels=True"
        )

    h_map = AttrDict()
    h_directionless_map = AttrDict()
    if direction_dependent_shape:
        h_map.sdf = (0, 1)
        if separate_coarse_channels:
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (2, 3)
            h_map.stf = (3, 6)
            h_map.nerf_coarse = (6, 9)
            h_map.nerf_fine = (9, 12)
        else:
            # Coarse and fine share one density column.
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (1, 2)
            h_map.stf = (2, 5)
            if separate_nerf_channels:
                h_map.nerf_coarse = (5, 8)
                h_map.nerf_fine = (5, 8)
            else:
                # NeRF colors reuse the STF columns.
                h_map.nerf_coarse = (2, 5)
                h_map.nerf_fine = (2, 5)
    else:
        h_directionless_map.sdf = (0, 1)
        h_directionless_map.density_coarse = (1, 2)
        if separate_coarse_channels:
            h_directionless_map.density_fine = (2, 3)
        else:
            h_directionless_map.density_fine = h_directionless_map.density_coarse
        h_map.stf = (0, 3)
        if separate_coarse_channels:
            h_map.nerf_coarse = (3, 6)
            h_map.nerf_fine = (6, 9)
        else:
            # Without separate NeRF columns the colors reuse the STF columns.
            h_map.nerf_coarse = (3, 6) if separate_nerf_channels else (0, 3)
            h_map.nerf_fine = h_map.nerf_coarse
    return h_map, h_directionless_map
|
||
|
|
||
|
|
||
|
def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    """
    Slice ``data`` along its last dimension according to ``mapping``.

    :param mapping: name -> (start, end) column range.
    :param data: tensor whose last dimension the ranges index.
    :return: an AttrDict of name -> ``data[..., start:end]``.
    """
    slices = {}
    for name, (first, stop) in mapping.items():
        slices[name] = data[..., first:stop]
    return AttrDict(slices)
|
||
|
|
||
|
|
||
|
def index_mapping_max(mapping: IndexMapping) -> int:
    """
    Return the largest ``end`` bound among the ranges in ``mapping``.

    This is the number of output columns needed to hold every range.
    ``max`` raises ValueError when ``mapping`` is empty.
    """
    ends = [stop for _name, (_start, stop) in mapping.items()]
    return max(ends)
|