a fork of shap-e for gc

from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample

DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0
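

# uncond_guide_model below implements classifier-free guidance in eps-space:
# the wrapped model is fed a doubled batch of [conditional, unconditional]
# inputs and its two noise predictions are blended as
#
#     eps = eps_uncond + scale * (eps_cond - eps_uncond)
#
# Only the first 3 output channels are treated as the eps prediction and
# blended; any remaining channels are passed through unchanged.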
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    return model_fn

def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    """Sample a batch of latents of shape (batch_size, model.d_latent)."""
    sample_shape = (batch_size, model.d_latent)

    if device is None:
        device = next(model.parameters()).device

    # Let the model pre-compute its conditioning tensors once per batch (if it
    # supports caching) instead of recomputing them on every diffusion step.
    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # For classifier-free guidance, append an all-zero (unconditional) copy of
    # every conditioning tensor so that conditional and unconditional
    # predictions can be computed in a single doubled batch.
    if guidance_scale != 1.0 and guidance_scale != 0.0:
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)

    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # The Karras sampler handles guidance internally via guidance_scale.
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
        else:
            # The ancestral sampler needs the guidance wrapper and a doubled batch.
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
    return samples
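

# Example usage (a minimal sketch, assuming the upstream shap-e helpers
# load_model, load_config, and diffusion_from_config are unchanged in this fork):
#
#     import torch
#     from shap_e.diffusion.sample import sample_latents
#     from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
#     from shap_e.models.download import load_model, load_config
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = load_model("text300M", device=device)
#     diffusion = diffusion_from_config(load_config("diffusion"))
#
#     latents = sample_latents(
#         batch_size=4,
#         model=model,
#         diffusion=diffusion,
#         model_kwargs=dict(texts=["a red chair"] * 4),
#         guidance_scale=15.0,
#         clip_denoised=True,
#         use_fp16=True,
#         use_karras=True,
#         karras_steps=64,
#         sigma_min=1e-3,
#         sigma_max=160,
#         s_churn=0,
#     )
#     # latents has shape (4, model.d_latent) and can then be decoded, e.g.
#     # with the transmitter model (load_model("transmitter", device=device)).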