"""A fork of shap-e for gc."""
from functools import partial
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict
from .base import Model
class MLPModel(MetaModule, Model):
    """
    Coordinate MLP over positionally-encoded 3D positions (and, optionally,
    view directions), where a contiguous range of layers can be
    meta-conditioned: those layers are ``MetaLinear`` modules whose weights
    may be supplied externally via ``params`` at call time.
    """

    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param n_output: channel count of the last linear layer.
        :param output_activation: activation name applied to the final
            output in ``forward`` (resolved via ``get_act``).
        :param posenc_version: scheme passed to ``encode_position`` /
            ``maybe_encode_direction``.
        :param insert_direction_at: layer index whose input receives the
            encoded view direction concatenated onto it; ``None`` disables
            direction input.
        :param d_hidden: width of each hidden layer.
        :param n_hidden_layers: number of hidden layers; the network has
            ``n_hidden_layers + 1`` linear layers in total.
        :param meta_parameters: if True, layers with index in the inclusive
            range [meta_start, meta_stop] are built as ``MetaLinear``.
        :param meta_stop: inclusive index of the last meta layer; when
            ``None`` it is derived from ``n_meta_layers`` or defaults to
            covering every layer.
        :param register_freqs: if True, register a ``freqs`` buffer holding
            2**0 .. 2**9 with shape [1, 10].
        """
        super().__init__()
        if register_freqs:
            # Powers of two 2^0 .. 2^9, shape [1, 10].
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding
        self.posenc_version = posenc_version

        # Probe the encoders with a dummy position batch to discover the
        # encoded channel counts instead of hard-coding them.
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            input_widths[self.insert_direction_at] += d_posenc_dir

        # Per-layer factory: a meta-conditioned linear (MetaLinear with the
        # configured flags) or a plain nn.Linear.
        linear_cls = lambda meta: (
            partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            if meta
            else nn.Linear
        )

        # Resolve the inclusive index of the last meta layer.
        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: carries ``position`` [batch_size x ... x 3] and an
            optional ``direction``.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: final activations with ``output_activation`` applied.
        """
        # query.direction is None typically for SDF models and training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: the final and directionless activations at the given query
        """
        # h_preact always holds the latest pre-activation layer output.
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                # Snapshot the direction-independent activations before the
                # encoded direction is concatenated in.
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                # Meta layers read their weights from params, namespaced
                # under "mlp.<i>".
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            if i < len(self.mlp) - 1:
                # No nonlinearity after the last layer; the caller applies
                # output_activation instead.
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            # Direction was never inserted; fall back to the final
            # pre-activation output.
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        # self.update comes from MetaModule; presumably merges the supplied
        # meta params with this module's own — TODO confirm.
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)

        # Gather the tensors the checkpointed call depends on: external
        # meta params for MetaLinear layers, module parameters otherwise.
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(list(subdict(params, f"mlp.{i}").values()))
            else:
                parameters.extend(layer.parameters())

        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )
        return h_final, h_directionless
class MLPSDFModel(MLPModel):
    """
    Signed-distance field head on top of :class:`MLPModel`.

    Fixes the base model to a single output channel with an identity output
    activation, and seeds the last layer's bias with a constant so the field
    starts out at a uniform value.
    """

    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        # Overwrite the final bias in place so the initial prediction is a
        # constant signed distance everywhere.
        self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        distance = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=distance)
class MLPTextureFieldModel(MLPModel):
    """
    Texture (color) field head on top of :class:`MLPModel`.

    Configures the base model to emit ``n_channels`` values squashed through
    a sigmoid, so each predicted channel lies in (0, 1).
    """

    def __init__(
        self,
        n_channels: int = 3,
        **kwargs,
    ):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        colors = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=colors)