# a fork of shap-e for gc

from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict


class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer.
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: Meta parameters.
        :param options: Optional hyperparameters.
        :return: An AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """


class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Implements the default empty space model where all queries are rendered as
    background.
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        background = nn.Parameter(
            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
            / channel_scale
        )
        if trainable:
            self.register_parameter("background", background)
        else:
            self.register_buffer("background", background)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        default_bg = self.background[None]
        background = options.get("background", default_bg) if options is not None else default_bg

        # Broadcast the [1 x n_channels] background to match the query's
        # batch shape, i.e. [*shape, n_channels].
        shape = query.position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(
            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
        )

        return background
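

# Usage sketch (hypothetical helper, not part of shap-e): render a plain white
# background on CPU. Background values are given on a 0-255 scale and divided
# by channel_scale, so [255.0] * 3 yields unit-valued RGB. The CPU device is
# an assumption for illustration; the model defaults to CUDA.
def _example_void_model() -> torch.Tensor:
    model = VoidNeRFModel(background=[255.0, 255.0, 255.0], device=torch.device("cpu"))
    query = Query(position=torch.zeros(2, 8, 3))  # 2 rays x 8 points each
    return model(query)  # -> [2, 8, 3] tensor of ones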


class MLPNeRFModel(MetaModule, NeRFModel):
    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters

        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )

        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query):
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()
        h_position = self.encode_position(query)

        # Density MLP, optionally with meta parameters and gradient checkpointing.
        if self.meta_parameters:
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()

        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        # View-direction features; zeros when query.direction is None.
        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()

        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        # The first hidden unit of h_density is the density logit; the full
        # hidden vector also conditions the channel MLP above.
        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        if options.return_h_density:
            res.h_density = h_density

        return res
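

# Usage sketch (hypothetical helper, not part of shap-e): query an
# MLPNeRFModel on CPU. The CPU device is an assumption for illustration; the
# model defaults to CUDA. Passing a direction is optional: when
# query.direction is None, maybe_get_spherical_harmonics_basis (below)
# substitutes zeros for the spherical-harmonics features.
def _example_mlp_model() -> AttrDict:
    model = MLPNeRFModel(device=torch.device("cpu"))
    position = torch.randn(2, 8, 3)
    direction = torch.nn.functional.normalize(torch.randn(2, 8, 3), dim=-1)
    out = model(Query(position=position, direction=direction))
    # out.density: [2, 8, 1]; out.channels: [2, 8, 3] in [0, 1] after sigmoid
    return out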


def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    :param sh_degree: Spherical harmonics degree
    :param coords_shape: [*shape, 3]
    :param coords: optional coordinate tensor of coords_shape
    """
    if coords is None:
        return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)
    return spherical_harmonics_basis(coords, sh_degree)
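

# Sketch (hypothetical helper, not part of shap-e) of both branches of
# maybe_get_spherical_harmonics_basis: without directions it returns a zero
# feature block of width sh_degree**2; with (assumed unit-norm) directions it
# returns the evaluated spherical-harmonics basis.
def _example_sh_basis() -> Tuple[torch.Tensor, torch.Tensor]:
    directions = torch.nn.functional.normalize(torch.randn(4, 3), dim=-1)
    zeros = maybe_get_spherical_harmonics_basis(
        sh_degree=4, coords_shape=(4, 3), coords=None, device=torch.device("cpu")
    )  # -> [4, 16], all zeros
    basis = maybe_get_spherical_harmonics_basis(
        sh_degree=4, coords_shape=(4, 3), coords=directions
    )  # -> [4, 16]
    return zeros, basis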