from functools import partial
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict

from .base import Model


class MLPModel(MetaModule, Model):
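    """
    A coordinate MLP over 3D query points. Positions are positionally encoded
    (per `posenc_version`); when `insert_direction_at` is set, an encoded view
    direction is concatenated to the features entering that hidden layer. With
    `meta_parameters=True`, a configurable range of layers becomes MetaLinear
    so their weights can be supplied externally through `params`.
    """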
    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        if register_freqs:
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding
        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            input_widths[self.insert_direction_at] += d_posenc_dir
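
        # Layers whose index falls in [meta_start, meta_stop] (when
        # meta_parameters is enabled) are built as MetaLinear so their weights
        # can be provided per-sample via `params`; all others are nn.Linear.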
        linear_cls = lambda meta: (
            partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            if meta
            else nn.Linear
        )

        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: query with position [batch_size x ... x 3] and an
            optional direction of the same shape
        :param params: Meta parameters
        :param options: Optional hyperparameters
        """

        # query.direction is None typically for SDF models and training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
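        # Walk through the MLP; just before layer `insert_direction_at`, record
        # the current (pre-direction) activation as the directionless branch and
        # concatenate the encoded view direction onto the features.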
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)
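        # Gather the concrete weight tensors the checkpointed call depends on:
        # MetaLinear layers read theirs from `params`, plain layers from their
        # own parameters. Passing them explicitly keeps gradients flowing to
        # them when gradient checkpointing is enabled.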
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(list(subdict(params, f"mlp.{i}").values()))
            else:
                parameters.extend(layer.parameters())

        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )

        return h_final, h_directionless


class MLPSDFModel(MLPModel):
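    """
    MLPModel specialization that predicts a single signed distance per query
    point (identity output activation); the last layer's bias is initialized
    to `initial_bias`.
    """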
    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        signed_distance = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=signed_distance)


class MLPTextureFieldModel(MLPModel):
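    """
    MLPModel specialization that predicts `n_channels` color values per query
    point, squashed to [0, 1] by a sigmoid output activation.
    """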
    def __init__(
        self,
        n_channels: int = 3,
        **kwargs,
    ):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        channels = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=channels)
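

# The block below is not part of the original module: it is a minimal usage
# sketch showing how the texture field can be queried directly, assuming the
# shap_e package is importable and a CUDA device is available (matching the
# default `device` argument above).
if __name__ == "__main__":
    # Build a texture field with default hyperparameters and query it with a
    # batch of random positions in [-1, 1]^3, with gradient checkpointing off.
    model = MLPTextureFieldModel(device=torch.device("cuda"))
    query = Query(position=torch.rand(2, 1024, 3, device=model.device) * 2 - 1)
    channels = model(query, options=dict(checkpoint_stf_model=False)).channels
    print(channels.shape)  # expected: torch.Size([2, 1024, 3])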