import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding


def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        x = self.c_qkv(x)
        x = checkpoint(self.attention, (x,), (), True)
        x = self.c_proj(x)
        return x


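# Added note (not in the original module): the `checkpoint(self.attention, (x,), (), True)`
# call above routes attention through the package's gradient-checkpointing helper, so the
# attention activations are recomputed during the backward pass rather than stored. The
# empty tuple and trailing flag follow the conventional (func, inputs, params, flag)
# signature used by similar diffusion codebases; treat the exact semantics of
# shap_e.models.nn.checkpoint as an assumption here.

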
class MLP(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()
        init_linear(self.c_fc, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class QKVMultiheadAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)
        weight = torch.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        wdtype = weight.dtype
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


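# Added note (not in the original module): in QKVMultiheadAttention.forward, the factor
# `scale = 1 / sqrt(sqrt(attn_ch))` is applied to both q and k before the einsum, so the
# logits still carry the standard 1 / sqrt(attn_ch) softmax temperature while each operand
# stays small enough to avoid fp16 overflow. Roughly (illustrative identity only):
#
#     (q * scale) @ (k * scale).transpose(-1, -2) == (q @ k.transpose(-1, -2)) / math.sqrt(attn_ch)

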
class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
    ):
        super().__init__()

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class PointDiffusionTransformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        time_token_cond: bool = False,
        use_pos_emb: bool = False,
        pos_emb_init_scale: float = 1.0,
        pos_emb_n_ctx: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.n_ctx = n_ctx
        self.time_token_cond = time_token_cond
        self.use_pos_emb = use_pos_emb
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + int(time_token_cond),
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        with torch.no_grad():
            self.output_proj.weight.zero_()
            self.output_proj.bias.zero_()
        if self.use_pos_emb:
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    pos_emb_init_scale
                    * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
                ),
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])

    def _forward_with_cond(
        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
    ) -> torch.Tensor:
        h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC
        for emb, as_token in cond_as_token:
            if not as_token:
                h = h + emb[:, None]
        if self.use_pos_emb:
            h = h + self.pos_emb
        extra_tokens = [
            (emb[:, None] if len(emb.shape) == 2 else emb)
            for emb, as_token in cond_as_token
            if as_token
        ]
        if len(extra_tokens):
            h = torch.cat(extra_tokens + [h], dim=1)

        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        if len(extra_tokens):
            h = h[:, sum(h.shape[1] for h in extra_tokens) :]
        h = self.output_proj(h)
        return h.permute(0, 2, 1)


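# Added shape summary for _forward_with_cond (descriptive comment, not in the original
# module): x is [N, C, T] and is projected to h of shape [N, T, width]. Conditioning
# tensors with as_token=False are broadcast-added to every token; those with
# as_token=True are prepended as extra tokens, run through the backbone, and sliced off
# again before output_proj, so the result is always [N, C', T] regardless of how many
# conditioning tokens were prepended.

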
class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        token_cond: bool = False,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        super().__init__(
            device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
        )
        self.n_ctx = n_ctx
        self.token_cond = token_cond
        self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        self.clip_embed = nn.Linear(
            self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            return dict(embeddings=self.clip(batch_size, **model_kwargs))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param texts: a batch of texts to condition on.
        :param embeddings: a batch of CLIP embeddings to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
        assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None].to(clip_out)

        # Rescale the features to have unit variance
        clip_out = math.sqrt(clip_out.shape[1]) * clip_out

        clip_embed = self.clip_embed(clip_out)

        cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
        return self._forward_with_cond(x, cond)


class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + clip.grid_size**2,
            pos_emb_n_ctx=n_ctx,
            **kwargs,
        )
        self.n_ctx = n_ctx
        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
        return self._forward_with_cond(x, cond)


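# Added usage sketch (illustrative, not in the original module): the grid-conditioned model
# takes either raw `images` or precomputed CLIP latent grids via `embeddings`, never both.
# `cached_model_kwargs` lets a sampling loop embed the images once and reuse them at every
# denoising step, e.g. (hypothetical variable names):
#
#     model = CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32)
#     kwargs = model.cached_model_kwargs(len(images), dict(images=images))
#     eps = model(x_t, t, **kwargs)  # equivalent to model(x_t, t, embeddings=...)
#
# During training, `cond_drop_prob` zeroes each sample's CLIP grid with that probability,
# the usual recipe for enabling classifier-free guidance at sampling time.

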
class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        cond_input_channels: Optional[int] = None,
        cond_ctx: int = 1024,
        n_ctx: int = 4096 - 1024,
        channel_scales: Optional[Sequence[float]] = None,
        channel_biases: Optional[Sequence[float]] = None,
        **kwargs,
    ):
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
        self.n_ctx = n_ctx
        self.cond_input_channels = cond_input_channels or self.input_channels
        self.cond_point_proj = nn.Linear(
            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
        )

        self.register_buffer(
            "channel_scales",
            torch.tensor(channel_scales, dtype=dtype, device=device)
            if channel_scales is not None
            else None,
        )
        self.register_buffer(
            "channel_biases",
            torch.tensor(channel_biases, dtype=dtype, device=device)
            if channel_biases is not None
            else None,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)
        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)

    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
        if self.channel_scales is not None:
            x = x * self.channel_scales[None, :, None]
        if self.channel_biases is not None:
            x = x + self.channel_biases[None, :, None]
        return self.cond_point_proj(x.permute(0, 2, 1))


class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096 - 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
        self.n_ctx = n_ctx

        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(
                embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
                low_res=model_kwargs["low_res"],
            )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,
        low_res: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)
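

# Added example (illustrative smoke test, not part of the original module): build a small
# unconditional PointDiffusionTransformer on CPU and check the expected output shape.
# This only exercises code defined in this file; it assumes the shap_e package imports
# above (checkpoint, timestep_embedding) resolve as usual.
if __name__ == "__main__":
    device = torch.device("cpu")
    model = PointDiffusionTransformer(
        device=device,
        dtype=torch.float32,
        input_channels=3,
        output_channels=6,
        n_ctx=16,
        width=64,
        layers=2,
        heads=4,
        time_token_cond=True,
    )
    x = torch.randn(2, 3, 16)  # [N x C x T] noisy point batch
    t = torch.randint(0, 1000, (2,))  # [N] diffusion timesteps
    out = model(x, t)
    assert out.shape == (2, 6, 16)  # [N x C' x T]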