A fork of shap-e for gc.
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict

class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer.
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: meta parameters.
        :param options: optional hyperparameters.
        :return: an AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """


class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Implements the default empty space model where all queries are rendered as
    background.
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        background = nn.Parameter(
            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
            / channel_scale
        )
        if trainable:
            self.register_parameter("background", background)
        else:
            self.register_buffer("background", background)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        default_bg = self.background[None]
        background = options.get("background", default_bg) if options is not None else default_bg

        shape = query.position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(
            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
        )
        return background
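

# Usage sketch (assumptions: CPU device and a white background; in shap-e the
# renderer normally constructs and calls this model). The background color is
# broadcast to one value per queried point.
def _void_model_example() -> None:
    model = VoidNeRFModel(background=[255.0, 255.0, 255.0], device=torch.device("cpu"))
    query = Query(position=torch.zeros(2, 4, 3))  # [batch_size x n_points x 3]
    channels = model(query)
    assert channels.shape == (2, 4, 3)  # one RGB value per query point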


class MLPNeRFModel(MetaModule, NeRFModel):
    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters

        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )

        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query):
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()

        h_position = self.encode_position(query)

        # The density MLP sees only the encoded positions. With meta
        # parameters enabled, the per-sample weights come from `params`.
        if self.meta_parameters:
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()

        # Optionally gradient-checkpoint the MLP, gated by
        # options.checkpoint_nerf_mlp.
        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        # Optional view-direction conditioning via a spherical-harmonics basis
        # (zeros when no directions are supplied).
        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()

        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        if options.return_h_density:
            res.h_density = h_density

        return res
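

# Usage sketch (assumption: CPU device; all other arguments left at their
# defaults above). Directions are optional; when given, they feed the
# spherical-harmonics basis that conditions the channel MLP.
def _mlp_model_example() -> None:
    model = MLPNeRFModel(device=torch.device("cpu"))
    direction = torch.nn.functional.normalize(torch.randn(2, 8, 3), dim=-1)
    query = Query(position=torch.randn(2, 8, 3), direction=direction)
    out = model(query)
    assert out.density.shape == (2, 8, 1)
    assert out.channels.shape == (2, 8, 3)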


def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int, ...],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    :param sh_degree: spherical harmonics degree.
    :param coords_shape: [*shape, 3]
    :param coords: optional coordinate tensor of shape coords_shape; when it
        is None, a zero basis of the matching shape is returned instead.
    """
    if coords is None:
        return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)

    return spherical_harmonics_basis(coords, sh_degree)
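

# Shape sketch: a degree-d basis has d**2 coefficients per direction, so with
# the default sh_degree=4 the channel MLP sees 16 extra input features. When no
# directions are supplied, the zero basis keeps that input width fixed.
def _sh_basis_example() -> None:
    basis = maybe_get_spherical_harmonics_basis(
        sh_degree=4, coords_shape=(2, 8, 3), device=torch.device("cpu")
    )
    assert basis.shape == (2, 8, 16)  # zeros, since coords is None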