You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.8 KiB
91 lines
2.8 KiB
2 years ago
|
from typing import Any, Callable, Dict, Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from .gaussian_diffusion import GaussianDiffusion
|
||
|
from .k_diffusion import karras_sample
|
||
|
|
||
|
# Default settings for Karras-style sampling.
# NOTE(review): these are not referenced in the visible part of this file;
# presumably callers pass them as the `karras_steps`, `sigma_min`, `sigma_max`
# and `s_churn` arguments of `sample_latents` -- confirm against call sites.

# Number of sampler steps.
DEFAULT_KARRAS_STEPS = 64

# Lower and upper bounds of the sigma range.
DEFAULT_KARRAS_SIGMA_MIN = 0.001
DEFAULT_KARRAS_SIGMA_MAX = 160

# Churn setting; 0.0 by default.
DEFAULT_KARRAS_S_CHURN = 0.0
|
||
|
|
||
|
|
||
|
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` so its leading output channels are guidance-mixed.

    The returned callable expects a batch whose two halves correspond to the
    same samples. It keeps only the first half of ``x_t``, stacks it twice
    along dim 0, and evaluates ``model`` on that doubled batch. The first
    three channels (dim 1) of the output are split into a first-half
    ("cond") and second-half ("uncond") prediction and combined as
    ``uncond + scale * (cond - uncond)``; the mixed result is duplicated
    for both halves. All remaining channels are passed through untouched.

    NOTE(review): which half is actually conditional depends on how the
    caller lays out ``**kwargs`` along the batch dimension -- presumably real
    conditioning first, null conditioning second; confirm against the caller.

    Args:
        model: Callable invoked as ``model(x, ts, **kwargs)`` returning a
            tensor with at least two dimensions.
        scale: Guidance weight applied to ``cond - uncond``.

    Returns:
        A callable with signature ``(x_t, ts, **kwargs)`` producing a tensor
        with the same layout as the output of ``model``.
    """

    def guided(x_t, ts, **kwargs):
        # Only the first half of the incoming batch is used; the second half
        # is replaced by a copy so both halves see identical inputs.
        n_half = len(x_t) // 2
        first = x_t[:n_half]
        out = model(torch.cat([first, first], dim=0), ts, **kwargs)

        # Channels [0, 3) are mixed, channels [3, ...) are forwarded as-is.
        cond, uncond = torch.chunk(out[:, :3], 2, dim=0)
        mixed = uncond + scale * (cond - uncond)
        doubled = torch.cat([mixed, mixed], dim=0)
        return torch.cat([doubled, out[:, 3:]], dim=1)

    return guided
|
||
|
|
||
|
|
||
|
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    """Draw latent samples from ``model`` using ``diffusion``.

    All arguments are keyword-only.

    Args:
        batch_size: Number of latents requested; the sample shape is
            ``(batch_size, model.d_latent)``.
        model: Network to sample from. Must expose ``d_latent`` and
            ``parameters()``. If it defines ``cached_model_kwargs``, that
            method is called as ``cached_model_kwargs(batch_size,
            model_kwargs)`` and its result replaces ``model_kwargs``.
        diffusion: Passed to ``karras_sample`` when ``use_karras`` is true,
            otherwise its ``p_sample_loop`` method drives sampling.
        model_kwargs: Conditioning passed through to the sampler. When
            guidance is active every value must be a tensor, because each is
            concatenated with a zero tensor of the same shape along dim 0.
            The mapping passed in is never modified.
        guidance_scale: Guidance weight. The values ``1.0`` and ``0.0`` both
            disable guidance in every code path of this function.
        clip_denoised: Forwarded to the sampler.
        use_fp16: Enables ``torch.autocast`` for the sampling call.
        use_karras: Selects ``karras_sample`` instead of
            ``diffusion.p_sample_loop``.
        karras_steps: Forwarded to ``karras_sample`` as ``steps``.
        sigma_min: Forwarded to ``karras_sample``.
        sigma_max: Forwarded to ``karras_sample``.
        s_churn: Forwarded to ``karras_sample``.
        device: Device to sample on; defaults to the device of the first
            parameter of ``model``.
        progress: Forwarded to the sampler.

    Returns:
        The tensor returned by the selected sampler.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # One predicate governs both the conditioning duplication below and the
    # model wrapping in the non-Karras branch, so the batch is never doubled
    # without the conditioning being doubled as well.
    use_guidance = guidance_scale != 1.0 and guidance_scale != 0.0
    if use_guidance:
        # Build a fresh dict instead of assigning into the existing one, so
        # neither the caller's mapping nor a dict cached by the model is
        # altered as a side effect. The zero tensors form the second half of
        # the batch (the null-conditioning half used for guidance).
        model_kwargs = {
            key: torch.cat([value, torch.zeros_like(value)], dim=0)
            for key, value in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
        else:
            internal_batch_size = batch_size
            if use_guidance:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            # NOTE(review): with guidance active the requested shape has
            # 2 * batch_size rows and the result is returned as-is;
            # presumably p_sample_loop returns the requested shape, so
            # callers may receive twice as many rows as in the Karras path
            # -- confirm whether they slice to the first batch_size rows.
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples
|