a fork of shap-e for gc
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn as nn

from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.util.collections import AttrDict

class LatentBottleneck(nn.Module, ABC):
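    """Base class for bottlenecks that post-process a latent tensor, e.g. by
    clamping it and injecting noise during training."""
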
    def __init__(self, *, device: torch.device, d_latent: int):
        super().__init__()
        self.device = device
        self.d_latent = d_latent

    @abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        pass

class LatentWarp(nn.Module, ABC):
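    """Base class for invertible elementwise warps applied to latents:
    ``warp`` maps latents into the warped space and ``unwarp`` inverts it."""
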
    def __init__(self, *, device: torch.device):
        super().__init__()
        self.device = device

    @abstractmethod
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        pass

    @abstractmethod
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        pass

class IdentityLatentWarp(LatentWarp):
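    """Warp that leaves latents unchanged."""
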
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        return x

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        return x

class Tan2LatentWarp(LatentWarp):
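    """Double-tangent warp: ``warp(x) = tan(coeff1 * tan(x)) / scale`` with
    ``scale = tan(coeff1 * tan(1))``, so that ``warp(1) == 1``. ``unwarp``
    applies the exact inverse, ``arctan(arctan(x * scale) / coeff1)``, on the
    principal branch."""
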
    def __init__(self, *, coeff1: float = 1.0, device: torch.device):
        super().__init__(device=device)
        self.coeff1 = coeff1
        self.scale = np.tan(np.tan(1.0) * coeff1)

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        return ((x.float().tan() * self.coeff1).tan() / self.scale).to(x.dtype)

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        return ((x.float() * self.scale).arctan() / self.coeff1).arctan().to(x.dtype)

class IdentityLatentBottleneck(LatentBottleneck):
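    """Bottleneck that passes latents through unchanged."""
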
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        return x

class ClampNoiseBottleneck(LatentBottleneck):
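    """Clamps latents to (-1, 1) with tanh and, during training only, adds
    Gaussian noise with standard deviation ``noise_scale``."""
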
    def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):
        super().__init__(device=device, d_latent=d_latent)
        self.noise_scale = noise_scale

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        x = x.tanh()
        if not self.training:
            return x
        return x + torch.randn_like(x) * self.noise_scale

class ClampDiffusionNoiseBottleneck(LatentBottleneck):
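    """Clamps latents to (-1, 1) with tanh and, during training only, noises
    them with the forward diffusion process ``q_sample`` at a uniformly random
    timestep; with probability ``1 - diffusion_prob`` a sample's timestep is
    set to 0 instead (the least-noised timestep)."""
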
    def __init__(
        self,
        *,
        device: torch.device,
        d_latent: int,
        diffusion: Dict[str, Any],
        diffusion_prob: float = 1.0,
    ):
        super().__init__(device=device, d_latent=d_latent)
        self.diffusion = diffusion_from_config(diffusion)
        self.diffusion_prob = diffusion_prob

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        x = x.tanh()
        if not self.training:
            return x
        t = torch.randint(low=0, high=self.diffusion.num_timesteps, size=(len(x),), device=x.device)
        t = torch.where(
            torch.rand(len(x), device=x.device) < self.diffusion_prob, t, torch.zeros_like(t)
        )
        return self.diffusion.q_sample(x, t)

def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):
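    """Build a bottleneck from a config dict. Note that ``name`` is popped
    from ``config``, so the dict is mutated in place."""
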
    name = config.pop("name")
    if name == "clamp_noise":
        return ClampNoiseBottleneck(**config, device=device, d_latent=d_latent)
    elif name == "identity":
        return IdentityLatentBottleneck(**config, device=device, d_latent=d_latent)
    elif name == "clamp_diffusion_noise":
        return ClampDiffusionNoiseBottleneck(**config, device=device, d_latent=d_latent)
    else:
        raise ValueError(f"unknown latent bottleneck: {name}")

def latent_warp_from_config(config: Dict[str, Any], device: torch.device):
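    """Build a warp from a config dict; pops ``name``, mutating ``config``."""
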
    name = config.pop("name")
    if name == "identity":
        return IdentityLatentWarp(**config, device=device)
    elif name == "tan2":
        return Tan2LatentWarp(**config, device=device)
    else:
        raise ValueError(f"unknown latent warping function: {name}")