a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

411 lines
12 KiB

2 years ago
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from shap_e.util.collections import AttrDict
from .meta import MetaModule, subdict
from .pointnet2_utils import sample_and_group, sample_and_group_all
def gelu(x):
    """
    Tanh approximation of the GELU activation, applied elementwise.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """
    inner = x + 0.044715 * torch.pow(x, 3)
    gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * inner)
    return 0.5 * x * gate
def swish(x):
    """Swish activation: `x` multiplied elementwise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
def quick_gelu(x):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x), elementwise."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate
def torch_gelu(x):
    """Apply PyTorch's built-in GELU (default arguments) to `x`."""
    return F.gelu(x)
def geglu(x):
    """
    GELU-gated linear unit.

    Splits `x` into two chunks along its last dimension and multiplies the
    first chunk by `gelu` of the second, so the output's last dimension is
    half that of the input.
    """
    value, gate = x.chunk(2, dim=-1)
    activated_gate = gelu(gate)
    return value * activated_gate
class SirenSin:
    """Callable sine activation with a fixed frequency multiplier `w0`."""

    def __init__(self, w0=30.0):
        # Factor applied to the input before taking the sine.
        self.w0 = w0

    def __call__(self, x):
        """Return sin(w0 * x), elementwise."""
        scaled = self.w0 * x
        return torch.sin(scaled)
def get_act(name):
    """
    Look up an activation callable by its string name.

    A fresh lookup table is built on every call, so "sin30" yields a new
    `SirenSin` instance and "identity" a new lambda each time.

    Raises:
        KeyError: if `name` is not one of the registered activation names.
    """
    registry = {
        "relu": F.relu,
        "leaky_relu": F.leaky_relu,
        "swish": swish,
        "tanh": torch.tanh,
        "gelu": gelu,
        "quick_gelu": quick_gelu,
        "torch_gelu": torch_gelu,
        # "gelu2" is an alias for quick_gelu.
        "gelu2": quick_gelu,
        "geglu": geglu,
        "sigmoid": torch.sigmoid,
        "sin": torch.sin,
        "sin30": SirenSin(w0=30.0),
        "softplus": F.softplus,
        "exp": torch.exp,
        "identity": lambda x: x,
    }
    return registry[name]
def zero_init(affine):
    """Fill `affine.weight` (and `affine.bias`, when present) with zeros, in place."""
    nn.init.constant_(affine.weight, 0.0)
    bias = affine.bias
    if bias is None:
        return
    nn.init.constant_(bias, 0.0)
def siren_init_first_layer(affine, init_scale: float = 1.0):
    """
    In-place initialization used for the first layer of a SIREN MLP.

    Weights are drawn uniformly from [-init_scale / n_input, init_scale / n_input],
    where n_input is the second dimension of `affine.weight`. The bias, if
    present, is set to zero.
    """
    fan_in = affine.weight.shape[1]
    bound = init_scale / fan_in
    nn.init.uniform_(affine.weight, -bound, bound)
    bias = affine.bias
    if bias is not None:
        nn.init.constant_(bias, 0.0)
def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
    """
    In-place SIREN-style uniform initialization.

    Weights are drawn uniformly from [-u, u] with
    u = init_scale * sqrt(6 / n_input) / coeff, where n_input is the second
    dimension of `affine.weight`. The bias, if present, is set to zero.
    """
    fan_in = affine.weight.shape[1]
    bound = init_scale * np.sqrt(6.0 / fan_in) / coeff
    nn.init.uniform_(affine.weight, -bound, bound)
    bias = affine.bias
    if bias is not None:
        nn.init.constant_(bias, 0.0)
def siren_init_30(affine, init_scale: float = 1.0):
    """Run `siren_init` on `affine` with the coefficient fixed at 30."""
    siren_init(affine, init_scale=init_scale, coeff=30.0)
def std_init(affine, init_scale: float = 1.0):
    """
    In-place normal initialization scaled by fan-in.

    Weights are drawn from a zero-mean normal distribution with standard
    deviation init_scale / sqrt(n_in), where n_in is the second dimension of
    `affine.weight`. The bias, if present, is set to zero.
    """
    fan_in = affine.weight.shape[1]
    nn.init.normal_(affine.weight, std=init_scale / math.sqrt(fan_in))
    bias = affine.bias
    if bias is not None:
        nn.init.constant_(bias, 0.0)
def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    """
    Initialize a sequence of affine layers in place.

    Args:
        affines: iterable of layers exposing `weight` and `bias`.
        init: None for `std_init` on every layer; "siren" or "siren30" to use
            `siren_init_first_layer` on the first layer and `siren_init` /
            `siren_init_30` on the remaining ones.
        init_scale: scale forwarded to the per-layer initializer.

    Raises:
        NotImplementedError: if `init` is not one of the values above.
    """
    if init is None:
        for layer in affines:
            std_init(layer, init_scale=init_scale)
        return

    # Pick the initializer used for every layer after the first one.
    if init == "siren30":
        later_init = siren_init_30
    elif init == "siren":
        later_init = siren_init
    else:
        raise NotImplementedError(init)

    for position, layer in enumerate(affines):
        layer_init = siren_init_first_layer if position == 0 else later_init
        layer_init(layer, init_scale=init_scale)
class MetaLinear(MetaModule):
    """
    Linear layer whose parameters may be overridden per call through `params`.

    The forward pass computes `x @ weight^T`, adds `bias` when it is not None,
    then multiplies by `scale` and adds `shift` when those are not None.
    The weight may be shared (`[n_out, n_in]`) or per-batch-element
    (`[batch, n_out, n_in]`); bias / scale / shift may likewise carry a
    leading batch dimension.

    Args:
        n_in: size of the last dimension of the input.
        n_out: size of the last dimension of the output.
        bias: if False, `bias` is registered as None and no bias is added.
        meta_scale: register a `scale` vector of ones through the meta
            registration function.
        meta_shift: register a `shift` vector of zeros the same way.
        meta_proj: register `weight` through the meta registration function
            instead of `register_parameter`.
        meta_bias: same as `meta_proj`, but for `bias`.
        trainable_meta: use `register_meta_parameter` rather than
            `register_meta_buffer` as the meta registration function.
        **kwargs: forwarded to the tensor constructors (`torch.ones`,
            `torch.zeros`, `torch.empty`).
    """

    def __init__(
        self,
        n_in,
        n_out,
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        **kwargs,
    ):
        super().__init__()
        register_meta_fn = (
            self.register_meta_parameter if trainable_meta else self.register_meta_buffer
        )
        if meta_scale:
            register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs)))
        if meta_shift:
            register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs)))
        register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn
        register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))
        if not bias:
            self.register_parameter("bias", None)
        else:
            register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn
            register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize `weight` and `bias` the same way `torch.nn.Linear` does."""
        # from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _bcast(self, op, left, right):
        """
        Apply `op(left, right)`, first inserting a singleton axis at position 1
        when `right` is 2-D ([batch x d_output]) so that it broadcasts over the
        middle dimension of `left`.
        """
        if right.ndim == 2:
            # Has dimension [batch x d_output]
            right = right.unsqueeze(1)
        return op(left, right)

    def forward(self, x, params=None):
        """
        Apply the layer to `x` of shape [batch, ..., n_in].

        Args:
            x: input tensor; all dimensions between the first and last are
                flattened for the matmul and restored afterwards.
            params: optional overrides, passed through `self.update`.
                NOTE(review): this relies on `self.update` returning an object
                whose absent entries (e.g. `scale` when `meta_scale=False`)
                read as None -- confirm against MetaModule.

        Returns:
            Tensor of shape [batch, ..., n_out].

        Raises:
            ValueError: if the resolved weight is neither 2-D nor 3-D.
        """
        params = self.update(params)
        batch_size, *shape, d_in = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        x = x.reshape(batch_size, -1, d_in)
        weight = params.weight
        if weight.ndim == 2:
            h = torch.einsum("bni,oi->bno", x, weight)
        elif weight.ndim == 3:
            h = torch.einsum("bni,boi->bno", x, weight)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `h`.
            raise ValueError(
                f"weight must have 2 or 3 dimensions, got {weight.ndim}"
            )
        if params.bias is not None:
            h = self._bcast(torch.add, h, params.bias)
        if params.scale is not None:
            h = self._bcast(torch.mul, h, params.scale)
        if params.shift is not None:
            h = self._bcast(torch.add, h, params.shift)
        return h.reshape(batch_size, *shape, -1)
def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs):
    """
    Build an `nn.Conv1d`, `nn.Conv2d` or `nn.Conv3d` according to `n_dim`.

    All remaining arguments are forwarded to the selected constructor.

    Raises:
        KeyError: if `n_dim` is not 1, 2 or 3.
    """
    conv_by_rank = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    conv_cls = conv_by_rank[n_dim]
    return conv_cls(
        d_in,
        d_out,
        kernel,
        stride=stride,
        padding=padding,
        dilation=dilation,
        **kwargs,
    )
def flatten(x):
    """
    Collapse every dimension between the first and the last into one.

    Args:
        x: tensor of shape [batch, *shape, n_channels]; `shape` may be empty.

    Returns:
        A pair `(flat, info)` where `flat` is `x` viewed as
        [batch, n_ctx, n_channels] and `info` is an AttrDict holding `shape`
        (the collapsed dimensions as a list), `n_ctx` (their product, a plain
        int) and `n_channels`, which is what `unflatten` needs to undo this.
    """
    batch_size, *shape, n_channels = x.shape
    # np.prod returns a numpy scalar, and the *float* 1.0 for an empty list;
    # cast to int so that view() also works for 2-D inputs and n_ctx is always
    # a plain Python int.
    n_ctx = int(np.prod(shape))
    flat = x.view(batch_size, n_ctx, n_channels)
    return flat, AttrDict(shape=shape, n_ctx=n_ctx, n_channels=n_channels)
def unflatten(x, info):
    """
    View `x` as [batch, *info.shape, info.n_channels].

    `info` is the metadata object produced by `flatten`; the batch size is
    taken from `x` itself.
    """
    target_shape = (x.shape[0], *info.shape, info.n_channels)
    return x.view(*target_shape)
def torchify(x):
    """
    Permute `x` so its last dimension moves to position 1.

    [batch, d1, ..., dk, channels] becomes [batch, channels, d1, ..., dk].
    """
    last = x.ndim - 1
    order = [0, last] + list(range(1, last))
    return x.permute(order)
def untorchify(x):
    """
    Permute `x` so the dimension at position 1 moves to the end.

    [batch, channels, d1, ..., dk] becomes [batch, d1, ..., dk, channels].
    """
    order = [0] + list(range(2, x.ndim)) + [1]
    return x.permute(order)
class MLP(nn.Module):
    """
    Multi-layer perceptron built from `nn.Linear` layers.

    The activation is applied after every layer except the last one.
    """

    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        """
        Required: d_input, d_hidden, d_output
        Optional: act_name, bias

        `init` / `init_scale` are forwarded to `mlp_init`; when `zero_out` is
        set, the final layer is zeroed afterwards with `zero_init`.
        """
        super().__init__()
        widths = [d_input] + d_hidden + [d_output]
        layers = [
            nn.Linear(n_in, n_out, bias=bias)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        self.d = widths
        self.affines = nn.ModuleList(layers)
        self.act = get_act(act_name)
        mlp_init(self.affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(layers[-1])

    def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""):
        """
        Run `h` through all layers and return the final layer's output.

        `options` is normalized into an AttrDict but not otherwise used here,
        and `log_prefix` is unused.
        """
        options = AttrDict() if options is None else AttrDict(options)
        layers = list(self.affines)
        for hidden_layer in layers[:-1]:
            h = self.act(hidden_layer(h))
        return layers[-1](h)
class MetaMLP(MetaModule):
    """
    Multi-layer perceptron built from `MetaLinear` layers.

    The activation is applied after every layer except the last one. Each
    layer receives the slice of `params` selected by `subdict` under the key
    "<log_prefix>affines.<index>".
    """

    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        """
        Build one `MetaLinear` per consecutive pair of widths in
        [d_input, *d_hidden, d_output].

        The `bias` and `meta_*` / `trainable_meta` flags are passed to every
        `MetaLinear`. `init` / `init_scale` are forwarded to `mlp_init`, and
        when `zero_out` is set the final layer is zeroed with `zero_init`.
        """
        super().__init__()
        widths = [d_input] + d_hidden + [d_output]
        # Options shared by every layer.
        layer_kwargs = dict(
            bias=bias,
            meta_scale=meta_scale,
            meta_shift=meta_shift,
            meta_proj=meta_proj,
            meta_bias=meta_bias,
            trainable_meta=trainable_meta,
        )
        layers = [
            MetaLinear(n_in, n_out, **layer_kwargs)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        self.d = widths
        self.affines = nn.ModuleList(layers)
        self.act = get_act(act_name)
        mlp_init(layers, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(layers[-1])

    def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""):
        """
        Run `h` through all layers and return the final layer's output.

        `params` is passed through `self.update` before being sliced per
        layer. `options` is normalized into an AttrDict but not otherwise
        used here.
        """
        options = AttrDict() if options is None else AttrDict(options)
        params = self.update(params)
        final_idx = len(self.affines) - 1
        for idx, layer in enumerate(self.affines):
            h = layer(h, params=subdict(params, f"{log_prefix}affines.{idx}"))
            # No activation after the output layer.
            if idx != final_idx:
                h = self.act(h)
        return h
class LayerNorm(nn.LayerNorm):
    """
    Layer normalization that casts its input with `.float()` before
    normalizing and casts the result back to the input's type.
    """

    def __init__(
        self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True
    ):
        super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)
        # Number of elements covered by one normalization group.
        self.width = np.prod(norm_shape)
        # Inputs with more elements than this take the explicit
        # F.layer_norm path in forward().
        self.max_numel = 65535 * self.width

    def forward(self, input):
        """Normalize `input` over the trailing `normalized_shape` dimensions."""
        as_float = input.float()
        if input.numel() > self.max_numel:
            normed = F.layer_norm(
                as_float, self.normalized_shape, self.weight, self.bias, self.eps
            )
        else:
            normed = super().forward(as_float)
        return normed.type_as(input)
class PointSetEmbedding(nn.Module):
    """
    Embed a point set by grouping points around sampled centers, running a
    stack of (patch_size x 1) 2-D convolutions over each group, and averaging
    over the sample dimension.
    """

    def __init__(
        self,
        *,
        radius: float,
        n_point: int,
        n_sample: int,
        d_input: int,
        d_hidden: List[int],
        patch_size: int = 1,
        stride: int = 1,
        activation: str = "swish",
        group_all: bool = False,
        padding_mode: str = "zeros",
        fps_method: str = "fps",
        **kwargs,
    ):
        """
        Args:
            radius, n_point, n_sample, fps_method: forwarded to
                `sample_and_group` in `forward`.
            d_input: feature channels per point; the first convolution takes
                `d_input + 3` input channels.
            d_hidden: output channels of each successive convolution.
            patch_size: kernel extent along the sample dimension; padding is
                `patch_size // 2` on that dimension.
            stride: stride along the sample dimension.
            activation: name resolved through `get_act`.
            group_all: use `sample_and_group_all` instead of
                `sample_and_group`.
            padding_mode: forwarded to `nn.Conv2d`.
            **kwargs: extra keyword arguments for every `nn.Conv2d`.
        """
        super().__init__()
        self.n_point = n_point
        self.radius = radius
        self.n_sample = n_sample
        self.mlp_convs = nn.ModuleList()
        self.act = get_act(activation)
        self.patch_size = patch_size
        self.stride = stride
        # Channel widths of consecutive convolutions, starting from the
        # grouped input (features plus 3 extra channels).
        widths = [d_input + 3, *d_hidden]
        for c_in, c_out in zip(widths, widths[1:]):
            conv = nn.Conv2d(
                c_in,
                c_out,
                kernel_size=(patch_size, 1),
                stride=(stride, 1),
                padding=(patch_size // 2, 0),
                padding_mode=padding_mode,
                **kwargs,
            )
            self.mlp_convs.append(conv)
        self.group_all = group_all
        self.fps_method = fps_method

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_points: sample points feature data, [B, d_hidden[-1], n_point]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            _, grouped = sample_and_group_all(xyz, points)
        else:
            # Sampling is deterministic whenever the module is in eval mode.
            _, grouped = sample_and_group(
                self.n_point,
                self.radius,
                self.n_sample,
                xyz,
                points,
                deterministic=not self.training,
                fps_method=self.fps_method,
            )
        # grouped: sampled points data, [B, n_point, n_sample, C+D]
        features = grouped.permute(0, 3, 2, 1)  # [B, C+D, n_sample, n_point]
        for conv in self.mlp_convs:
            features = self.act(self.apply_conv(features, conv))
        # Average over the sample dimension.
        return features.mean(dim=2)

    def apply_conv(self, points: torch.Tensor, conv: nn.Module):
        """
        Apply `conv` to a 4-D `points` tensor, first shuffling it along
        dimension 2 when `patch_size > 1`.

        The shuffle order is drawn independently for every (batch, channel)
        pair and shared across the last dimension.
        """
        batch, channels, n_samples, _ = points.shape
        if self.patch_size <= 1:
            return conv(points)
        # Shuffle the representations
        # TODO shuffle deterministically when not self.training
        noise = torch.rand(batch, channels, n_samples, 1, device=points.device)
        _, order = noise.sort(dim=2)
        shuffled = torch.gather(points, 2, torch.broadcast_to(order, points.shape))
        return conv(shuffled)