import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from einops import rearrange

from src.utils.train_util import instantiate_from_config
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel
from .pipeline import RefOnlyNoisedUNet


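# Latent/image normalization helpers. The constants below appear to mirror the
# shifted and scaled latent space used by the Zero123++ pipeline:
# scale_latents/unscale_latents map VAE latents to/from that space, and
# scale_image/unscale_image apply the matching affine rescaling to images
# (see encode_target_images below).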
def scale_latents(latents):
    latents = (latents - 0.22) * 0.75
    return latents


def unscale_latents(latents):
    latents = latents / 0.75 + 0.22
    return latents


def scale_image(image):
    image = image * 0.5 / 0.8
    return image


def unscale_image(image):
    image = image / 0.5 * 0.8
    return image


def extract_into_tensor(a, t, x_shape):
    # gather per-sample schedule values at timesteps t and reshape them to
    # broadcast over a tensor of shape x_shape
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


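# LightningModule that fine-tunes a Zero123++-style multi-view diffusion pipeline:
# a single condition image is encoded (CLIP vision encoder + VAE reference latents)
# and the UNet is trained with a v-prediction objective to denoise a 3x2 grid of
# six target views.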
class MVDiffusion(pl.LightningModule):
    def __init__(
        self,
        stable_diffusion_config,
        drop_cond_prob=0.1,
    ):
        super(MVDiffusion, self).__init__()

        self.drop_cond_prob = drop_cond_prob

        self.register_schedule()

        # init modules
        pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config)
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing='trailing'
        )
        self.pipeline = pipeline

        train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config)
        if isinstance(self.pipeline.unet, UNet2DConditionModel):
            self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler)

        self.train_scheduler = train_sched    # use ddpm scheduler during training

        self.unet = pipeline.unet

        # validation output buffer
        self.validation_step_outputs = []

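    # Noise schedule: register the DDPM alpha/beta buffers needed to build
    # v-prediction targets (get_v) and to recover x0 from (x_t, v)
    # (predict_start_from_z_and_v).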
    def register_schedule(self):
        self.num_timesteps = 1000

        # replace the scaled_linear schedule with a linear schedule, as in Zero123++
        beta_start = 0.00085
        beta_end = 0.0120
        betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)

        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())

        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float())

    def on_fit_start(self):
        device = torch.device(f'cuda:{self.global_rank}')
        self.pipeline.to(device)
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)

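    # Data preparation: the condition image is randomly resized (presumably a
    # scale augmentation), and the six target views are tiled into a single
    # 3x2 image grid, matching the Zero123++ output layout.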
    def prepare_batch_data(self, batch):
        # prepare stable diffusion input
        cond_imgs = batch['cond_imgs']      # (B, C, H, W)
        cond_imgs = cond_imgs.to(self.device)

        # randomly resize the condition image
        cond_size = np.random.randint(128, 513)
        cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1)

        target_imgs = batch['target_imgs']  # (B, 6, C, H, W)
        target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1)
        target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2)    # (B, C, 3H, 2W)
        target_imgs = target_imgs.to(self.device)

        return cond_imgs, target_imgs

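    # Global image conditioning: encode the condition image with the CLIP vision
    # encoder and add it, weighted by the pipeline's per-token ramping coefficients,
    # onto the empty-prompt text embeddings (the Zero123++ conditioning scheme).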
    @torch.no_grad()
    def forward_vision_encoder(self, images):
        dtype = next(self.pipeline.vision_encoder.parameters()).dtype
        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
        image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values
        image_pt = image_pt.to(device=self.device, dtype=dtype)
        global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0]
        ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp

        return encoder_hidden_states

    @torch.no_grad()
    def encode_condition_image(self, images):
        dtype = next(self.pipeline.vae.parameters()).dtype
        image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
        image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values
        image_pt = image_pt.to(device=self.device, dtype=dtype)
        latents = self.pipeline.vae.encode(image_pt).latent_dist.sample()
        return latents

    @torch.no_grad()
    def encode_target_images(self, images):
        dtype = next(self.pipeline.vae.parameters()).dtype
        # equivalent to scaling images to [-1, 1] first and then calling scale_image:
        # (x - 0.5) / 0.8 == ((2x - 1) * 0.5) / 0.8
        images = (images - 0.5) / 0.8   # [-0.625, 0.625]
        posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist
        latents = posterior.sample() * self.pipeline.vae.config.scaling_factor
        latents = scale_latents(latents)
        return latents

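    # UNet forward pass: the condition latents are routed through
    # cross_attention_kwargs as `cond_lat`, which RefOnlyNoisedUNet consumes for
    # reference-only conditioning; training_step uses the returned tensor as a
    # v-prediction despite the `pred_noise` name.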
    def forward_unet(self, latents, t, prompt_embeds, cond_latents):
        dtype = next(self.pipeline.unet.parameters()).dtype
        latents = latents.to(dtype)
        prompt_embeds = prompt_embeds.to(dtype)
        cond_latents = cond_latents.to(dtype)
        cross_attention_kwargs = dict(cond_lat=cond_latents)
        pred_noise = self.pipeline.unet(
            latents,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
        return pred_noise

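    # v-parameterization (Salimans & Ho, 2022): with a_t = sqrt(alphas_cumprod[t])
    # and s_t = sqrt(1 - alphas_cumprod[t]),
    #   v      = a_t * noise - s_t * x0    (get_v)
    #   x0_hat = a_t * x_t   - s_t * v     (predict_start_from_z_and_v)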
    def predict_start_from_z_and_v(self, x_t, t, v):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def get_v(self, x, noise, t):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def training_step(self, batch, batch_idx):
        # get input
        cond_imgs, target_imgs = self.prepare_batch_data(batch)

        # sample random timestep
        B = cond_imgs.shape[0]

        t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device)

        # classifier-free guidance: randomly drop the image condition
        if np.random.rand() < self.drop_cond_prob:
            prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False)
            cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs))
        else:
            prompt_embeds = self.forward_vision_encoder(cond_imgs)
            cond_latents = self.encode_condition_image(cond_imgs)

        latents = self.encode_target_images(target_imgs)
        noise = torch.randn_like(latents)
        latents_noisy = self.train_scheduler.add_noise(latents, noise, t)

        v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents)
        v_target = self.get_v(latents, noise, t)

        loss, loss_dict = self.compute_loss(v_pred, v_target)

        # logging
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.global_step % 500 == 0 and self.global_rank == 0:
            with torch.no_grad():
                latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred)

                latents = unscale_latents(latents_pred)
                images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])   # [-1, 1]
                images = (images * 0.5 + 0.5).clamp(0, 1)
                images = torch.cat([target_imgs, images], dim=-2)

                grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1))
                save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))

        return loss

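    # Note: the parameters are named noise_pred/noise_gt, but training_step passes
    # v-predictions and v-targets; the loss is a plain MSE either way.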
    def compute_loss(self, noise_pred, noise_gt):
        loss = F.mse_loss(noise_pred, noise_gt)

        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        # get input
        cond_imgs, target_imgs = self.prepare_batch_data(batch)

        images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])]

        outputs = []
        for cond_img in images_pil:
            latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images
            image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0])   # [-1, 1]
            image = (image * 0.5 + 0.5).clamp(0, 1)
            outputs.append(image)
        outputs = torch.cat(outputs, dim=0).to(self.device)
        images = torch.cat([target_imgs, outputs], dim=-2)

        self.validation_step_outputs.append(images)

    @torch.no_grad()
    def on_validation_epoch_end(self):
        images = torch.cat(self.validation_step_outputs, dim=0)

        all_images = self.all_gather(images)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1))
            save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png'))

        self.validation_step_outputs.clear()    # free memory

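    # `self.learning_rate` is not set in __init__; it is expected to be assigned
    # externally (e.g. by the training script) before the optimizer is configured.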
    def configure_optimizers(self):
        lr = self.learning_rate

        optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
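
# Minimal construction sketch (hypothetical, kept as a comment so importing this
# module stays side-effect free). The checkpoint name and custom-pipeline path
# below are assumptions for illustration, not values taken from this file:
#
#   model = MVDiffusion(
#       stable_diffusion_config=dict(
#           pretrained_model_name_or_path='sudo-ai/zero123plus-v1.2',
#           custom_pipeline='./zero123plus',
#       ),
#       drop_cond_prob=0.1,
#   )
#   model.learning_rate = 1e-5   # read by configure_optimizers()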