# A fork of shap-e for gc.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from shap_e.util.collections import AttrDict

from .meta import MetaModule, subdict
from .pointnet2_utils import sample_and_group, sample_and_group_all


def gelu(x):
    # tanh approximation of GELU, as used in the original GPT-2 codebase
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def quick_gelu(x):
    # cheap sigmoid-based approximation of GELU
    return x * torch.sigmoid(1.702 * x)


def torch_gelu(x):
    return torch.nn.functional.gelu(x)


def geglu(x):
    # gated GELU: splits the last dimension in half, so the output width is
    # half the input width
    v, gates = x.chunk(2, dim=-1)
    return v * gelu(gates)


class SirenSin:
    # sine activation with tunable frequency w0, as in SIREN networks
    def __init__(self, w0=30.0):
        self.w0 = w0

    def __call__(self, x):
        return torch.sin(self.w0 * x)


def get_act(name):
    return {
        "relu": torch.nn.functional.relu,
        "leaky_relu": torch.nn.functional.leaky_relu,
        "swish": swish,
        "tanh": torch.tanh,
        "gelu": gelu,
        "quick_gelu": quick_gelu,
        "torch_gelu": torch_gelu,
        "gelu2": quick_gelu,
        "geglu": geglu,
        "sigmoid": torch.sigmoid,
        "sin": torch.sin,
        "sin30": SirenSin(w0=30.0),
        "softplus": F.softplus,
        "exp": torch.exp,
        "identity": lambda x: x,
    }[name]
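

# Example (illustrative sketch, not part of the original module): looking up an
# activation by name. Note that "geglu" halves the last dimension, so its input
# width must be even.
#
#     act = get_act("geglu")
#     y = act(torch.randn(2, 8))  # y.shape == (2, 4)
#     sin30 = get_act("sin30")    # SirenSin instance computing sin(30 * x)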


def zero_init(affine):
    nn.init.constant_(affine.weight, 0.0)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_first_layer(affine, init_scale: float = 1.0):
    # SIREN first-layer rule: uniform(-1/n_input, 1/n_input)
    n_input = affine.weight.shape[1]
    u = init_scale / n_input
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
    # SIREN hidden-layer rule: uniform(-sqrt(6/n_input)/w0, sqrt(6/n_input)/w0)
    n_input = affine.weight.shape[1]
    u = init_scale * np.sqrt(6.0 / n_input) / coeff
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_30(affine, init_scale: float = 1.0):
    siren_init(affine, coeff=30.0, init_scale=init_scale)


def std_init(affine, init_scale: float = 1.0):
    n_in = affine.weight.shape[1]
    stddev = init_scale / math.sqrt(n_in)
    nn.init.normal_(affine.weight, std=stddev)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    if init == "siren30":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init_30
            init_fn(affine, init_scale=init_scale)
    elif init == "siren":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init
            init_fn(affine, init_scale=init_scale)
    elif init is None:
        for affine in affines:
            std_init(affine, init_scale=init_scale)
    else:
        raise NotImplementedError(init)
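

# Example (illustrative sketch, not part of the original module): applying the
# SIREN initialization scheme to a small stack of linear layers; only the first
# layer gets the uniform(-1/n_input, 1/n_input) first-layer rule.
#
#     layers = [nn.Linear(3, 64), nn.Linear(64, 64), nn.Linear(64, 1)]
#     mlp_init(layers, init="siren30")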


class MetaLinear(MetaModule):
    def __init__(
        self,
        n_in,
        n_out,
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        **kwargs,
    ):
        super().__init__()
        # Meta values can be overridden per batch element at forward time;
        # trainable_meta controls whether they are also trained as ordinary
        # parameters or kept as buffers.
        register_meta_fn = (
            self.register_meta_parameter if trainable_meta else self.register_meta_buffer
        )
        if meta_scale:
            register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs)))
        if meta_shift:
            register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs)))

        register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn
        register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))

        if not bias:
            self.register_parameter("bias", None)
        else:
            register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn
            register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs)))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _bcast(self, op, left, right):
        if right.ndim == 2:
            # right has shape [batch x d_output]; insert a context axis so it
            # broadcasts against [batch x n_ctx x d_output]
            right = right.unsqueeze(1)
        return op(left, right)

    def forward(self, x, params=None):
        params = self.update(params)

        batch_size, *shape, d_in = x.shape
        x = x.view(batch_size, -1, d_in)
        if params.weight.ndim == 2:
            # one projection shared across the batch
            h = torch.einsum("bni,oi->bno", x, params.weight)
        elif params.weight.ndim == 3:
            # a separate projection per batch element
            h = torch.einsum("bni,boi->bno", x, params.weight)
        if params.bias is not None:
            h = self._bcast(torch.add, h, params.bias)
        if params.scale is not None:
            h = self._bcast(torch.mul, h, params.scale)
        if params.shift is not None:
            h = self._bcast(torch.add, h, params.shift)
        h = h.view(batch_size, *shape, -1)
        return h
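

# Example (illustrative sketch, not part of the original module): MetaLinear
# computes (x @ weight.T + bias) * scale + shift, where the meta values can be
# overridden per batch element. This assumes MetaModule.update merges the
# override dict with the registered defaults.
#
#     layer = MetaLinear(8, 16)
#     x = torch.randn(4, 32, 8)  # [batch, n_ctx, d_in]
#     y = layer(x)               # default scale=1, shift=0
#     y2 = layer(x, params=AttrDict(scale=torch.rand(4, 16)))  # per-sample scale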


def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs):
    cls = {
        1: nn.Conv1d,
        2: nn.Conv2d,
        3: nn.Conv3d,
    }[n_dim]
    return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs)
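

# Example (illustrative sketch, not part of the original module): the factory
# keeps call sites dimension-agnostic.
#
#     conv2 = Conv(2, 16, 32, 3, padding=1)  # an nn.Conv2d
#     conv3 = Conv(3, 16, 32, 3, padding=1)  # an nn.Conv3d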


def flatten(x):
    # collapse all spatial axes of a channels-last tensor into one context axis
    batch_size, *shape, n_channels = x.shape
    n_ctx = np.prod(shape)
    return x.view(batch_size, n_ctx, n_channels), AttrDict(
        shape=shape, n_ctx=n_ctx, n_channels=n_channels
    )


def unflatten(x, info):
    batch_size = x.shape[0]
    return x.view(batch_size, *info.shape, info.n_channels)


def torchify(x):
    # channels-last -> channels-first, e.g. [B, H, W, C] -> [B, C, H, W]
    extent = list(range(1, x.ndim - 1))
    return x.permute([0, x.ndim - 1, *extent])


def untorchify(x):
    # channels-first -> channels-last, e.g. [B, C, H, W] -> [B, H, W, C]
    extent = list(range(2, x.ndim))
    return x.permute([0, *extent, 1])
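

# Example (illustrative sketch, not part of the original module): flatten and
# unflatten round-trip a channels-last feature map through a flat
# [batch, n_ctx, channels] view; torchify/untorchify swap between channels-last
# and PyTorch's channels-first layout.
#
#     x = torch.randn(2, 8, 8, 32)  # [B, H, W, C]
#     flat, info = flatten(x)       # flat.shape == (2, 64, 32)
#     assert torch.equal(unflatten(flat, info), x)
#     assert torchify(x).shape == (2, 32, 8, 8)
#     assert torch.equal(untorchify(torchify(x)), x)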


class MLP(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        """
        Required: d_input, d_hidden, d_output
        Optional: act_name, bias, init, init_scale, zero_out
        """
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(self.affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        *hid, out = self.affines
        for i, f in enumerate(hid):
            h = self.act(f(h))
        # no activation after the final affine layer
        h = out(h)
        return h
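

# Example (illustrative sketch, not part of the original module): a
# 3 -> 64 -> 64 -> 4 network. zero_out=True zeroes the final layer so the
# module initially computes the zero function, a common trick for residual
# branches.
#
#     net = MLP(d_input=3, d_hidden=[64, 64], d_output=4, act_name="swish")
#     out = net(torch.randn(5, 3))  # out.shape == (5, 4)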


class MetaMLP(MetaModule):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [
            MetaLinear(
                d_in,
                d_out,
                bias=bias,
                meta_scale=meta_scale,
                meta_shift=meta_shift,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            for d_in, d_out in zip(ds[:-1], ds[1:])
        ]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        params = self.update(params)
        *hid, out = self.affines
        for i, layer in enumerate(hid):
            # route each layer's slice of the meta parameters to that layer
            h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}")))
        last = len(self.affines) - 1
        h = out(h, params=subdict(params, f"{log_prefix}affines.{last}"))
        return h
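

# Example (illustrative sketch, not part of the original module): each layer of
# a MetaMLP is a MetaLinear whose meta values can be overridden through a flat
# params dict keyed by "affines.<i>.<name>". This assumes MetaModule.update and
# subdict behave as sketched for MetaLinear above.
#
#     net = MetaMLP(d_input=3, d_hidden=[64], d_output=4)
#     h = net(torch.randn(2, 16, 3))  # registered defaults
#     h2 = net(torch.randn(2, 16, 3),
#              params=AttrDict({"affines.0.scale": torch.rand(2, 64)}))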


class LayerNorm(nn.LayerNorm):
    def __init__(
        self, norm_shape: Union[int, Tuple[int, ...]], eps: float = 1e-5, elementwise_affine: bool = True
    ):
        super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(norm_shape)
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Normalize in float32 for stability and cast back to the input dtype.
        # Inputs with more than 65535 rows bypass the nn.LayerNorm fast path
        # and go through the functional op instead.
        if input.numel() > self.max_numel:
            return F.layer_norm(
                input.float(), self.normalized_shape, self.weight, self.bias, self.eps
            ).type_as(input)
        else:
            return super(LayerNorm, self).forward(input.float()).type_as(input)
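

# Example (illustrative sketch, not part of the original module): the float32
# upcast makes this safe for half-precision activations, and the result is
# cast back to the input dtype.
#
#     ln = LayerNorm(32)
#     y = ln(torch.randn(4, 16, 32).half())  # y.dtype == torch.float16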


class PointSetEmbedding(nn.Module):
    def __init__(
        self,
        *,
        radius: float,
        n_point: int,
        n_sample: int,
        d_input: int,
        d_hidden: List[int],
        patch_size: int = 1,
        stride: int = 1,
        activation: str = "swish",
        group_all: bool = False,
        padding_mode: str = "zeros",
        fps_method: str = "fps",
        **kwargs,
    ):
        super().__init__()
        self.n_point = n_point
        self.radius = radius
        self.n_sample = n_sample
        self.mlp_convs = nn.ModuleList()
        self.act = get_act(activation)
        self.patch_size = patch_size
        self.stride = stride
        # +3 for the xyz coordinates concatenated onto the point features
        last_channel = d_input + 3
        for out_channel in d_hidden:
            self.mlp_convs.append(
                nn.Conv2d(
                    last_channel,
                    out_channel,
                    kernel_size=(patch_size, 1),
                    stride=(stride, 1),
                    padding=(patch_size // 2, 0),
                    padding_mode=padding_mode,
                    **kwargs,
                )
            )
            last_channel = out_channel
        self.group_all = group_all
        self.fps_method = fps_method

    def forward(self, xyz, points):
        """
        Input:
            xyz: input point position data, [B, C, N]
            points: input point feature data, [B, D, N]
        Return:
            new_points: sampled point feature data, [B, d_hidden[-1], n_point]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(
                self.n_point,
                self.radius,
                self.n_sample,
                xyz,
                points,
                deterministic=not self.training,
                fps_method=self.fps_method,
            )
        # new_xyz: sampled point position data, [B, n_point, C]
        # new_points: sampled point data, [B, n_point, n_sample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, n_sample, n_point]
        for i, conv in enumerate(self.mlp_convs):
            new_points = self.act(self.apply_conv(new_points, conv))
        # average-pool each local group over its n_sample axis
        new_points = new_points.mean(dim=2)
        return new_points

    def apply_conv(self, points: torch.Tensor, conv: nn.Module):
        batch, channels, n_samples, _ = points.shape
        # Shuffle the representations along the sample axis so that the
        # (patch_size x 1) convolution mixes samples in a random order.
        if self.patch_size > 1:
            # TODO shuffle deterministically when not self.training
            _, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2)
            points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape))
        return conv(points)
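

# Example (illustrative sketch, not part of the original module): embedding a
# cloud of N=1024 points with D=6 per-point features into n_point=256 local
# descriptors. Shapes follow the forward() docstring; sample_and_group comes
# from the pointnet2_utils import above.
#
#     embed = PointSetEmbedding(
#         radius=0.2, n_point=256, n_sample=32, d_input=6, d_hidden=[64, 128]
#     )
#     xyz = torch.randn(2, 3, 1024)    # [B, C=3, N]
#     feats = torch.randn(2, 6, 1024)  # [B, D=6, N]
#     out = embed(xyz, feats)          # out.shape == (2, 128, 256)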