import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from shap_e.util.collections import AttrDict

from .meta import MetaModule, subdict
from .pointnet2_utils import sample_and_group, sample_and_group_all


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)


def torch_gelu(x):
    return torch.nn.functional.gelu(x)


def geglu(x):
    # Split the last dimension in half: one half carries values, the other gates.
    v, gates = x.chunk(2, dim=-1)
    return v * gelu(gates)


class SirenSin:
    def __init__(self, w0=30.0):
        self.w0 = w0

    def __call__(self, x):
        return torch.sin(self.w0 * x)


def get_act(name):
    return {
        "relu": torch.nn.functional.relu,
        "leaky_relu": torch.nn.functional.leaky_relu,
        "swish": swish,
        "tanh": torch.tanh,
        "gelu": gelu,
        "quick_gelu": quick_gelu,
        "torch_gelu": torch_gelu,
        "gelu2": quick_gelu,
        "geglu": geglu,
        "sigmoid": torch.sigmoid,
        "sin": torch.sin,
        "sin30": SirenSin(w0=30.0),
        "softplus": F.softplus,
        "exp": torch.exp,
        "identity": lambda x: x,
    }[name]


def zero_init(affine):
    nn.init.constant_(affine.weight, 0.0)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_first_layer(affine, init_scale: float = 1.0):
    n_input = affine.weight.shape[1]
    u = init_scale / n_input
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
    n_input = affine.weight.shape[1]
    u = init_scale * np.sqrt(6.0 / n_input) / coeff
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_30(affine, init_scale: float = 1.0):
    siren_init(affine, coeff=30.0, init_scale=init_scale)


def std_init(affine, init_scale: float = 1.0):
    n_in = affine.weight.shape[1]
    stddev = init_scale / math.sqrt(n_in)
    nn.init.normal_(affine.weight, std=stddev)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    # Note: use a separate name for the per-layer init function so we don't
    # shadow the `init` argument inside the loops.
    if init == "siren30":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init_30
            init_fn(affine, init_scale=init_scale)
    elif init == "siren":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init
            init_fn(affine, init_scale=init_scale)
    elif init is None:
        for affine in affines:
            std_init(affine, init_scale=init_scale)
    else:
        raise NotImplementedError(init)
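

# --- Hedged usage sketch (added for illustration; not part of the original
# module). Shows how the SIREN-style inits above pair with a stack of affine
# layers: "siren30" gives the first layer the uniform(-1/n_in, 1/n_in) init
# and each later layer the bound init_scale * sqrt(6 / n_in) / 30, matching
# the sin(30 * x) activation. All layer sizes here are assumptions chosen
# for the example only.
def _example_siren_stack():
    affines = [nn.Linear(3, 256), nn.Linear(256, 256), nn.Linear(256, 1)]
    mlp_init(affines, init="siren30", init_scale=1.0)
    act = get_act("sin30")  # SirenSin(w0=30.0)
    x = torch.randn(4, 3)
    for affine in affines[:-1]:
        x = act(affine(x))
    return affines[-1](x)  # [4, 1]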


class MetaLinear(MetaModule):
    def __init__(
        self,
        n_in,
        n_out,
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        **kwargs,
    ):
        super().__init__()
        register_meta_fn = (
            self.register_meta_parameter if trainable_meta else self.register_meta_buffer
        )
        if meta_scale:
            register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs)))
        if meta_shift:
            register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs)))

        register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn
        register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))

        if not bias:
            self.register_parameter("bias", None)
        else:
            register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn
            register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs)))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # From https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _bcast(self, op, left, right):
        if right.ndim == 2:
            # Has dimension [batch x d_output]
            right = right.unsqueeze(1)
        return op(left, right)

    def forward(self, x, params=None):
        params = self.update(params)

        batch_size, *shape, d_in = x.shape
        x = x.view(batch_size, -1, d_in)

        if params.weight.ndim == 2:
            # Shared projection: [d_out, d_in]
            h = torch.einsum("bni,oi->bno", x, params.weight)
        elif params.weight.ndim == 3:
            # Per-batch projection: [batch, d_out, d_in]
            h = torch.einsum("bni,boi->bno", x, params.weight)
        else:
            raise ValueError(f"unsupported weight shape: {tuple(params.weight.shape)}")
        if params.bias is not None:
            h = self._bcast(torch.add, h, params.bias)
        if params.scale is not None:
            h = self._bcast(torch.mul, h, params.scale)
        if params.shift is not None:
            h = self._bcast(torch.add, h, params.shift)

        h = h.view(batch_size, *shape, -1)
        return h


def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs):
    cls = {
        1: nn.Conv1d,
        2: nn.Conv2d,
        3: nn.Conv3d,
    }[n_dim]
    return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs)


def flatten(x):
    # [batch, *spatial, channels] -> [batch, prod(spatial), channels]
    batch_size, *shape, n_channels = x.shape
    n_ctx = np.prod(shape)
    return x.view(batch_size, n_ctx, n_channels), AttrDict(
        shape=shape, n_ctx=n_ctx, n_channels=n_channels
    )


def unflatten(x, info):
    # Inverse of flatten, given the AttrDict it returned.
    batch_size = x.shape[0]
    return x.view(batch_size, *info.shape, info.n_channels)


def torchify(x):
    # Channels-last -> channels-first: [B, *spatial, C] -> [B, C, *spatial]
    extent = list(range(1, x.ndim - 1))
    return x.permute([0, x.ndim - 1, *extent])


def untorchify(x):
    # Channels-first -> channels-last: [B, C, *spatial] -> [B, *spatial, C]
    extent = list(range(2, x.ndim))
    return x.permute([0, *extent, 1])
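

# --- Hedged usage sketch (added for illustration; not part of the original
# module). MetaLinear.forward supports two weight layouts: the registered
# shared weight [d_out, d_in] (the "oi" einsum), or a per-batch projection
# [batch, d_out, d_in] (the "boi" einsum) supplied through `params` -- e.g.
# weights produced by a hypernetwork. The shapes are assumptions for the
# example, and the override path assumes MetaModule.update merges
# caller-supplied entries over the registered parameters and meta buffers.
def _example_meta_linear():
    layer = MetaLinear(8, 16)
    x = torch.randn(4, 32, 8)  # [batch, n_ctx, d_in]
    h_shared = layer(x)  # shared-weight path -> [4, 32, 16]
    per_batch = AttrDict(weight=torch.randn(4, 16, 8))
    h_batched = layer(x, params=per_batch)  # per-batch path -> [4, 32, 16]
    return h_shared, h_batched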


class MLP(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        """
        Required: d_input, d_hidden, d_output
        Optional: act_name, bias
        """
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(self.affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        *hid, out = self.affines
        for f in hid:
            h = self.act(f(h))
        h = out(h)
        return h


class MetaMLP(MetaModule):
    def __init__(
        self,
        d_input: int,
        d_hidden: List[int],
        d_output: int,
        act_name: str = "quick_gelu",
        bias: bool = True,
        meta_scale: bool = True,
        meta_shift: bool = True,
        meta_proj: bool = False,
        meta_bias: bool = False,
        trainable_meta: bool = False,
        init: Optional[str] = None,
        init_scale: float = 1.0,
        zero_out: bool = False,
    ):
        super().__init__()
        ds = [d_input] + d_hidden + [d_output]
        affines = [
            MetaLinear(
                d_in,
                d_out,
                bias=bias,
                meta_scale=meta_scale,
                meta_shift=meta_shift,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            for d_in, d_out in zip(ds[:-1], ds[1:])
        ]
        self.d = ds
        self.affines = nn.ModuleList(affines)
        self.act = get_act(act_name)
        mlp_init(affines, init=init, init_scale=init_scale)
        if zero_out:
            zero_init(affines[-1])

    def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""):
        options = AttrDict() if options is None else AttrDict(options)
        params = self.update(params)
        *hid, out = self.affines
        for i, layer in enumerate(hid):
            h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}")))
        last = len(self.affines) - 1
        h = out(h, params=subdict(params, f"{log_prefix}affines.{last}"))
        return h


class LayerNorm(nn.LayerNorm):
    def __init__(
        self,
        norm_shape: Union[int, Tuple[int]],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
    ):
        super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(norm_shape)
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Normalize in float32 and cast back; for very large inputs, take the
        # functional path instead of the built-in forward.
        if input.numel() > self.max_numel:
            return F.layer_norm(
                input.float(), self.normalized_shape, self.weight, self.bias, self.eps
            ).type_as(input)
        else:
            return super().forward(input.float()).type_as(input)


class PointSetEmbedding(nn.Module):
    def __init__(
        self,
        *,
        radius: float,
        n_point: int,
        n_sample: int,
        d_input: int,
        d_hidden: List[int],
        patch_size: int = 1,
        stride: int = 1,
        activation: str = "swish",
        group_all: bool = False,
        padding_mode: str = "zeros",
        fps_method: str = "fps",
        **kwargs,
    ):
        super().__init__()
        self.n_point = n_point
        self.radius = radius
        self.n_sample = n_sample
        self.mlp_convs = nn.ModuleList()
        self.act = get_act(activation)
        self.patch_size = patch_size
        self.stride = stride
        last_channel = d_input + 3
        for out_channel in d_hidden:
            self.mlp_convs.append(
                nn.Conv2d(
                    last_channel,
                    out_channel,
                    kernel_size=(patch_size, 1),
                    stride=(stride, 1),
                    padding=(patch_size // 2, 0),
                    padding_mode=padding_mode,
                    **kwargs,
                )
            )
            last_channel = out_channel
        self.group_all = group_all
        self.fps_method = fps_method

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_points: sample points feature data, [B, d_hidden[-1], n_point]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(
                self.n_point,
                self.radius,
                self.n_sample,
                xyz,
                points,
                deterministic=not self.training,
                fps_method=self.fps_method,
            )
        # new_xyz: sampled points position data, [B, n_point, C]
        # new_points: sampled points data, [B, n_point, n_sample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, n_sample, n_point]
        for conv in self.mlp_convs:
            new_points = self.act(self.apply_conv(new_points, conv))
        new_points = new_points.mean(dim=2)
        return new_points

    def apply_conv(self, points: torch.Tensor, conv: nn.Module):
        batch, channels, n_samples, _ = points.shape
        # Shuffle the representations
        if self.patch_size > 1:
            # TODO shuffle deterministically when not self.training
            _, indices = torch.rand(
                batch, channels, n_samples, 1, device=points.device
            ).sort(dim=2)
            points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape))
        return conv(points)
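

# --- Hedged smoke-test sketch (added for illustration; not part of the
# original module). Wires an MLP through the flatten/unflatten helpers above;
# every size below is an assumption chosen for the example.
def _example_mlp_roundtrip():
    mlp = MLP(d_input=3, d_hidden=[64, 64], d_output=4)
    x = torch.randn(2, 8, 8, 3)  # [batch, *spatial, n_channels]
    flat, info = flatten(x)  # [2, 64, 3], plus AttrDict(shape=[8, 8], ...)
    h = mlp(flat)  # [2, 64, 4]
    # Re-expand with the new channel count:
    out = unflatten(h, AttrDict(shape=info.shape, n_ctx=info.n_ctx, n_channels=4))
    return out  # [2, 8, 8, 4]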