a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

213 lines
7.3 KiB

from functools import partial
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict
from .base import Model
class MLPModel(MetaModule, Model):
    """
    A coordinate-based MLP field: encodes 3D query positions (and optionally
    view directions), runs them through a stack of linear layers, and applies
    an output activation.

    Any contiguous range of layers may be "meta" layers (``MetaLinear``),
    whose weights can be supplied externally via ``params`` instead of being
    owned by this module — used when the MLP's weights are themselves
    predicted by another model.
    """

    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param n_output: number of output channels of the final layer.
        :param output_activation: name of the activation applied by ``forward``
            to the final layer's output (resolved via ``get_act``).
        :param posenc_version: which positional-encoding scheme to use.
        :param insert_direction_at: if not None, index of the layer whose input
            gets the encoded view direction concatenated onto it.
        :param d_hidden: width of each hidden layer.
        :param n_hidden_layers: number of hidden layers (total layers is
            ``n_hidden_layers + 1`` including the output layer).
        :param meta_parameters: if True, layers in ``[meta_start, meta_stop]``
            become ``MetaLinear`` layers.
        :param meta_stop: last meta layer index; derived from ``n_meta_layers``
            or defaulted to ``n_hidden_layers`` when None.
        :param register_freqs: if True, register a ``freqs`` buffer of
            2^0..2^9.  # NOTE(review): buffer is never read in this file —
            # presumably kept for state-dict compatibility; confirm.
        """
        super().__init__()

        if register_freqs:
            # Buffer of frequencies [2^0, ..., 2^9], shape (1, 10).
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding
        self.posenc_version = posenc_version

        # Probe the encoders with a dummy (1, 3) coordinate to discover the
        # encoded feature widths without hard-coding them per version.
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP: layer i maps input_widths[i] -> output_widths[i].
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            # Widen that layer's input to accept the encoded direction.
            input_widths[self.insert_direction_at] += d_posenc_dir

        # Factory: MetaLinear for meta layers, plain nn.Linear otherwise.
        linear_cls = lambda meta: (
            partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            if meta
            else nn.Linear
        )

        # Resolve the inclusive meta-layer range [meta_start, meta_stop].
        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        if meta_parameters:
            # One flag per layer (n_hidden_layers + 1 layers in total).
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)
        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        Evaluate the field at the query points.

        :param query: carries ``position`` ([batch_size x ... x 3]) and
            optionally ``direction``.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: output activation applied to the final layer's features.
        """
        # query.direction is None typically for SDF models and training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self, position: torch.Tensor, direction: torch.Tensor, params: AttrDict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the layer stack on encoded positions.

        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                # Snapshot the pre-activation features *before* mixing in the
                # direction — this is the "directionless" output.
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                # Meta layers read their weights from params under "mlp.{i}".
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            # No nonlinearity after the last (output) layer.
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            # No direction-insertion layer: fall back to the final pre-activation.
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Wrapper around :meth:`_run_mlp` that merges meta params and optionally
        gradient-checkpoints the layer stack.

        :param position: [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        # Merge externally-supplied meta parameters with this module's own.
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)

        # Collect every tensor the closure uses so `checkpoint` can track
        # gradients through them — meta weights come from `params`, not from
        # the layers' own parameters, so they must be gathered explicitly.
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(list(subdict(params, f"mlp.{i}").values()))
            else:
                parameters.extend(layer.parameters())

        # NOTE(review): AttrDict presumably yields a falsy value for the
        # missing "checkpoint_stf_model" key, disabling checkpointing — confirm.
        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )
        return h_final, h_directionless
class MLPSDFModel(MLPModel):
    """
    An MLP field that predicts one signed-distance value per query point.

    Configures the base model with a single output channel and an identity
    output activation, then fills the final layer's bias with
    ``initial_bias`` so the freshly initialized network starts from a
    constant offset.
    """

    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        # Overwrite the output layer's bias right after base initialization.
        self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """Return an AttrDict with the predicted ``signed_distance``."""
        distance = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=distance)
class MLPTextureFieldModel(MLPModel):
    """
    An MLP field that predicts per-point texture channels (3 by default),
    with a sigmoid output activation so each channel lies in (0, 1).
    """

    def __init__(self, n_channels: int = 3, **kwargs):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """Return an AttrDict with the predicted ``channels``."""
        predicted = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=predicted)