import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding


def init_linear(l: nn.Linear, stddev: float) -> None:
    """
    Re-initialize a linear layer in place.

    Weights are drawn from N(0, stddev**2); the bias, if present, is zeroed.

    :param l: a module with a ``weight`` parameter and an optional ``bias``
        (e.g. ``nn.Linear``).
    :param stddev: standard deviation of the weight initialization.
    """
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


class MultiheadAttention(nn.Module):
    """
    Self-attention layer: a fused QKV projection, multi-head attention over the
    projected tensor, and an output projection back to ``width`` channels.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        """
        :param x: an [N x T x width] tensor.
        :return: an [N x T x width] tensor.
        """
        x = self.c_qkv(x)
        # NOTE(review): presumably runs the attention core under gradient
        # checkpointing (fn, inputs, params, enable flag); confirm against
        # shap_e.models.nn.checkpoint.
        x = checkpoint(self.attention, (x,), (), True)
        x = self.c_proj(x)
        return x


class MLP(nn.Module):
    """
    Position-wise feed-forward network: width -> 4*width -> GELU -> width.
    """

    def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()
        init_linear(self.c_fc, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class QKVMultiheadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention over an already-projected QKV tensor.

    Per head, the channel axis is laid out as [q | k | v], each of
    ``attn_ch = C // heads // 3`` channels, where C is the input's last dim.
    """

    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        """
        :param qkv: an [N x T x C] tensor of fused queries/keys/values.
        :return: an [N x T x C/3] tensor (heads * attn_ch channels).
        """
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        # q and k are each scaled by attn_ch**-0.25 (product attn_ch**-0.5);
        # more stable with f16 than dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)
        weight = torch.einsum("bthc,bshc->bhts", q * scale, k * scale)
        wdtype = weight.dtype
        # Softmax in float32, then cast back to the working dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


class ResidualAttentionBlock(nn.Module):
    """
    Pre-LayerNorm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
    ):
        super().__init__()

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """
    A stack of ``layers`` ResidualAttentionBlocks operating on [N x T x width].

    ``init_scale`` is multiplied by sqrt(1 / width) before being used as the
    weight stddev of every linear layer in the stack.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class PointDiffusionTransformer(nn.Module):
    """
    Transformer mapping a point cloud [N x C x T] and timesteps [N] to a
    per-point output [N x C' x T].

    Conditioning vectors are either added to every point token or prepended
    as extra tokens; see ``_forward_with_cond``.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        time_token_cond: bool = False,
        use_pos_emb: bool = False,
        pos_emb_init_scale: float = 1.0,
        pos_emb_n_ctx: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.n_ctx = n_ctx
        self.time_token_cond = time_token_cond
        self.use_pos_emb = use_pos_emb
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            # One extra context slot when the timestep is fed as a token.
            n_ctx=n_ctx + int(time_token_cond),
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        # Zero-init the output projection so the untrained model outputs zeros.
        with torch.no_grad():
            self.output_proj.weight.zero_()
            self.output_proj.bias.zero_()
        if self.use_pos_emb:
            # Added to the point tokens only (before any conditioning tokens
            # are prepended), so it must have one row per point token.
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    pos_emb_init_scale
                    * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
                ),
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        # NOTE(review): timestep_embedding presumably maps [N] -> [N x width]; see .util.
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])

    def _forward_with_cond(
        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
    ) -> torch.Tensor:
        """
        Run the backbone on ``x`` with a list of conditioning embeddings.

        :param x: an [N x C x T] tensor of points.
        :param cond_as_token: (embedding, as_token) pairs. When ``as_token`` is
            False the embedding (expected [N x width]) is broadcast-added to
            every point token. When True it is prepended as extra token(s):
            a 2-D [N x width] embedding becomes one token, a 3-D
            [N x L x width] embedding becomes L tokens. Extra tokens are
            prepended in list order and stripped again after the backbone.
        :return: an [N x C' x T] tensor.
        """
        h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC
        for emb, as_token in cond_as_token:
            if not as_token:
                h = h + emb[:, None]
        if self.use_pos_emb:
            h = h + self.pos_emb
        extra_tokens = [
            (emb[:, None] if len(emb.shape) == 2 else emb)
            for emb, as_token in cond_as_token
            if as_token
        ]
        if extra_tokens:
            h = torch.cat(extra_tokens + [h], dim=1)

        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        if extra_tokens:
            # Drop the prepended conditioning tokens, keep only point tokens.
            n_extra = sum(tok.shape[1] for tok in extra_tokens)
            h = h[:, n_extra:]
        h = self.output_proj(h)
        return h.permute(0, 2, 1)


class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
    """
    PointDiffusionTransformer conditioned on a single CLIP embedding per sample
    (from images, texts, or precomputed embeddings).
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        token_cond: bool = False,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        super().__init__(
            device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
        )
        # The parent was built with room for the CLIP token; restore the
        # number of point tokens that forward() expects.
        self.n_ctx = n_ctx
        self.token_cond = token_cond
        self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        self.clip_embed = nn.Linear(
            self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Precompute CLIP embeddings once so repeated forward() calls can pass
        ``embeddings=`` instead of re-encoding the raw conditioning inputs.
        """
        with torch.no_grad():
            return dict(embeddings=self.clip(batch_size, **model_kwargs))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param texts: a batch of texts to condition on.
        :param embeddings: a batch of CLIP embeddings to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
        assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]

        if self.training:
            # Zero the conditioning for a random cond_drop_prob fraction of
            # samples (presumably to support classifier-free guidance).
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None].to(clip_out)

        # Rescale the features to have unit variance
        # (assumes clip_out rows are unit-norm — TODO confirm in pretrained_clip).
        clip_out = math.sqrt(clip_out.shape[1]) * clip_out

        clip_embed = self.clip_embed(clip_out)

        cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
        return self._forward_with_cond(x, cond)


class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    """
    PointDiffusionTransformer conditioned on a CLIP latent grid: the
    ``grid_size**2`` grid features are prepended as tokens.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + clip.grid_size**2,
            pos_emb_n_ctx=n_ctx,
            **kwargs,
        )
        self.n_ctx = n_ctx
        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Precompute the CLIP latent grids for ``model_kwargs["images"]``."""
        _ = batch_size
        with torch.no_grad():
            return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            # Per-sample conditioning dropout (zeroes the whole grid).
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
        return self._forward_with_cond(x, cond)


class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    """
    PointDiffusionTransformer that generates ``n_ctx`` new points conditioned on
    a low-resolution cloud of ``cond_ctx`` points, which is projected to
    ``width`` and prepended as tokens.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        cond_input_channels: Optional[int] = None,
        cond_ctx: int = 1024,
        n_ctx: int = 4096 - 1024,
        channel_scales: Optional[Sequence[float]] = None,
        channel_biases: Optional[Sequence[float]] = None,
        **kwargs,
    ):
        # pos_emb is added to the n_ctx point tokens before the cond_ctx
        # low-res tokens are prepended, so size it to n_ctx by default rather
        # than the backbone context (n_ctx + cond_ctx), which cannot broadcast.
        kwargs.setdefault("pos_emb_n_ctx", n_ctx)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
        self.n_ctx = n_ctx
        self.cond_input_channels = cond_input_channels or self.input_channels
        self.cond_point_proj = nn.Linear(
            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
        )

        self.register_buffer(
            "channel_scales",
            torch.tensor(channel_scales, dtype=dtype, device=device)
            if channel_scales is not None
            else None,
        )
        self.register_buffer(
            "channel_biases",
            torch.tensor(channel_biases, dtype=dtype, device=device)
            if channel_biases is not None
            else None,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)
        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)

    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize (x * scale + bias, per channel) and project conditioning points.

        :param x: an [N x C2 x T'] tensor.
        :return: an [N x T' x width] tensor of conditioning tokens.
        """
        if self.channel_scales is not None:
            x = x * self.channel_scales[None, :, None]
        if self.channel_biases is not None:
            x = x + self.channel_biases[None, :, None]
        return self.cond_point_proj(x.permute(0, 2, 1))


class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
    """
    Upsampler additionally conditioned on a CLIP latent grid: both the grid
    tokens and the low-res point tokens are prepended.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096 - 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        # Positional embeddings cover only the n_ctx generated points, not the
        # grid_size**2 CLIP tokens added to the parent's context below.
        kwargs.setdefault("pos_emb_n_ctx", n_ctx)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
        self.n_ctx = n_ctx

        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Precompute CLIP latent grids and pass ``low_res`` through unchanged."""
        _ = batch_size
        with torch.no_grad():
            return dict(
                embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
                low_res=model_kwargs["low_res"],
            )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,
        low_res: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
            If both images and embeddings are given, images take precedence.
        :return: an [N x C3 x T] tensor.
        """
        # Fail clearly instead of with an AttributeError on None.permute below.
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            # Per-sample conditioning dropout (zeroes the whole grid).
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)