You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
197 lines
6.7 KiB
197 lines
6.7 KiB
import math
|
|
from abc import ABC, abstractmethod
|
|
from collections import OrderedDict
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
from torch import torch
|
|
|
|
from shap_e.util.collections import AttrDict
|
|
|
|
|
|
def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):
|
|
flat_shapes = OrderedDict(
|
|
(name, (int(np.prod(shape)) // shape[-1], shape[-1]))
|
|
for name, shape in param_shapes.items()
|
|
)
|
|
return flat_shapes
|
|
|
|
|
|
class ParamsProj(nn.Module, ABC):
|
|
def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int):
|
|
super().__init__()
|
|
self.device = device
|
|
self.param_shapes = param_shapes
|
|
self.d_latent = d_latent
|
|
|
|
@abstractmethod
|
|
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
|
|
pass
|
|
|
|
|
|
class LinearParamsProj(ParamsProj):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
device: torch.device,
|
|
param_shapes: Dict[str, Tuple[int]],
|
|
d_latent: int,
|
|
init_scale: Optional[float] = None,
|
|
):
|
|
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
|
|
self.param_shapes = param_shapes
|
|
self.projections = nn.ModuleDict({})
|
|
for k, v in param_shapes.items():
|
|
self.projections[_sanitize_name(k)] = nn.Linear(
|
|
d_latent, int(np.prod(v)), device=device
|
|
)
|
|
if init_scale is not None:
|
|
scale = init_scale / math.sqrt(d_latent)
|
|
mod = self.projections[_sanitize_name(k)]
|
|
nn.init.normal_(mod.weight, std=scale)
|
|
nn.init.zeros_(mod.bias)
|
|
|
|
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
|
|
out = AttrDict()
|
|
for k in self.param_shapes.keys():
|
|
proj = self.projections[_sanitize_name(k)]
|
|
out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
|
|
return out
|
|
|
|
|
|
class MLPParamsProj(ParamsProj):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
device: torch.device,
|
|
param_shapes: Dict[str, Tuple[int]],
|
|
d_latent: int,
|
|
hidden_size: Optional[int] = None,
|
|
):
|
|
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
|
|
if hidden_size is None:
|
|
hidden_size = d_latent
|
|
self.param_shapes = param_shapes
|
|
self.projections = nn.ModuleDict({})
|
|
for k, v in param_shapes.items():
|
|
self.projections[_sanitize_name(k)] = nn.Sequential(
|
|
nn.Linear(d_latent, hidden_size, device=device),
|
|
nn.GELU(),
|
|
nn.Linear(hidden_size, int(np.prod(v)), device=device),
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
|
|
out = AttrDict()
|
|
for k in self.param_shapes.keys():
|
|
proj = self.projections[_sanitize_name(k)]
|
|
out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
|
|
return out
|
|
|
|
|
|
class ChannelsProj(nn.Module):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
device: torch.device,
|
|
vectors: int,
|
|
channels: int,
|
|
d_latent: int,
|
|
init_scale: float = 1.0,
|
|
learned_scale: Optional[float] = None,
|
|
use_ln: bool = False,
|
|
):
|
|
super().__init__()
|
|
self.proj = nn.Linear(d_latent, vectors * channels, device=device)
|
|
self.use_ln = use_ln
|
|
self.learned_scale = learned_scale
|
|
if use_ln:
|
|
self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device)
|
|
if learned_scale is not None:
|
|
self.norm.weight.data.fill_(learned_scale)
|
|
scale = init_scale / math.sqrt(d_latent)
|
|
elif learned_scale is not None:
|
|
gain = torch.ones((channels,), device=device) * learned_scale
|
|
self.register_parameter("gain", nn.Parameter(gain))
|
|
scale = init_scale / math.sqrt(d_latent)
|
|
else:
|
|
scale = init_scale / math.sqrt(d_latent * channels)
|
|
nn.init.normal_(self.proj.weight, std=scale)
|
|
nn.init.zeros_(self.proj.bias)
|
|
self.d_latent = d_latent
|
|
self.vectors = vectors
|
|
self.channels = channels
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x_bvd = x
|
|
w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
|
|
b_vc = self.proj.bias.view(1, self.vectors, self.channels)
|
|
h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
|
|
if self.use_ln:
|
|
h = self.norm(h)
|
|
elif self.learned_scale is not None:
|
|
h = h * self.gain.view(1, 1, -1)
|
|
h = h + b_vc
|
|
return h
|
|
|
|
|
|
class ChannelsParamsProj(ParamsProj):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
device: torch.device,
|
|
param_shapes: Dict[str, Tuple[int]],
|
|
d_latent: int,
|
|
init_scale: float = 1.0,
|
|
learned_scale: Optional[float] = None,
|
|
use_ln: bool = False,
|
|
):
|
|
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
|
|
self.param_shapes = param_shapes
|
|
self.projections = nn.ModuleDict({})
|
|
self.flat_shapes = flatten_param_shapes(param_shapes)
|
|
self.learned_scale = learned_scale
|
|
self.use_ln = use_ln
|
|
for k, (vectors, channels) in self.flat_shapes.items():
|
|
self.projections[_sanitize_name(k)] = ChannelsProj(
|
|
device=device,
|
|
vectors=vectors,
|
|
channels=channels,
|
|
d_latent=d_latent,
|
|
init_scale=init_scale,
|
|
learned_scale=learned_scale,
|
|
use_ln=use_ln,
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
|
|
out = AttrDict()
|
|
start = 0
|
|
for k, shape in self.param_shapes.items():
|
|
vectors, _ = self.flat_shapes[k]
|
|
end = start + vectors
|
|
x_bvd = x[:, start:end]
|
|
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
|
|
start = end
|
|
return out
|
|
|
|
|
|
def params_proj_from_config(
|
|
config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int
|
|
):
|
|
name = config.pop("name")
|
|
if name == "linear":
|
|
return LinearParamsProj(
|
|
**config, device=device, param_shapes=param_shapes, d_latent=d_latent
|
|
)
|
|
elif name == "mlp":
|
|
return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent)
|
|
elif name == "channels":
|
|
return ChannelsParamsProj(
|
|
**config, device=device, param_shapes=param_shapes, d_latent=d_latent
|
|
)
|
|
else:
|
|
raise ValueError(f"unknown params proj: {name}")
|
|
|
|
|
|
def _sanitize_name(x: str) -> str:
|
|
return x.replace(".", "__")
|
|
|