bluestyle97
9 months ago
4 changed files with 454 additions and 3 deletions
@ -0,0 +1,47 @@ |
|||
model: |
|||
base_learning_rate: 1.0e-05 |
|||
target: zero123plus.model.MVDiffusion |
|||
params: |
|||
drop_cond_prob: 0.1 |
|||
|
|||
stable_diffusion_config: |
|||
pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2 |
|||
custom_pipeline: ./zero123plus |
|||
|
|||
data: |
|||
target: src.data.objaverse_zero123plus.DataModuleFromConfig |
|||
params: |
|||
batch_size: 6 |
|||
num_workers: 8 |
|||
train: |
|||
target: src.data.objaverse_zero123plus.ObjaverseData |
|||
params: |
|||
root_dir: data/objaverse |
|||
meta_fname: lvis-annotations.json |
|||
image_dir: rendering_zero123plus |
|||
validation: false |
|||
validation: |
|||
target: src.data.objaverse_zero123plus.ObjaverseData |
|||
params: |
|||
root_dir: data/objaverse |
|||
meta_fname: lvis-annotations.json |
|||
image_dir: rendering_zero123plus |
|||
validation: true |
|||
|
|||
|
|||
lightning: |
|||
modelcheckpoint: |
|||
params: |
|||
every_n_train_steps: 1000 |
|||
save_top_k: -1 |
|||
save_last: true |
|||
callbacks: {} |
|||
|
|||
trainer: |
|||
benchmark: true |
|||
max_epochs: -1 |
|||
gradient_clip_val: 1.0 |
|||
val_check_interval: 1000 |
|||
num_sanity_val_steps: 0 |
|||
accumulate_grad_batches: 1 |
|||
check_val_every_n_epoch: null # if not set this, validation does not run |
@ -0,0 +1,124 @@ |
|||
import os |
|||
import json |
|||
import numpy as np |
|||
import webdataset as wds |
|||
import pytorch_lightning as pl |
|||
import torch |
|||
from torch.utils.data import Dataset |
|||
from torch.utils.data.distributed import DistributedSampler |
|||
from PIL import Image |
|||
from pathlib import Path |
|||
|
|||
from src.utils.train_util import instantiate_from_config |
|||
|
|||
|
|||
class DataModuleFromConfig(pl.LightningDataModule): |
|||
def __init__( |
|||
self, |
|||
batch_size=8, |
|||
num_workers=4, |
|||
train=None, |
|||
validation=None, |
|||
test=None, |
|||
**kwargs, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.batch_size = batch_size |
|||
self.num_workers = num_workers |
|||
|
|||
self.dataset_configs = dict() |
|||
if train is not None: |
|||
self.dataset_configs['train'] = train |
|||
if validation is not None: |
|||
self.dataset_configs['validation'] = validation |
|||
if test is not None: |
|||
self.dataset_configs['test'] = test |
|||
|
|||
def setup(self, stage): |
|||
|
|||
if stage in ['fit']: |
|||
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) |
|||
else: |
|||
raise NotImplementedError |
|||
|
|||
def train_dataloader(self): |
|||
|
|||
sampler = DistributedSampler(self.datasets['train']) |
|||
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) |
|||
|
|||
def val_dataloader(self): |
|||
|
|||
sampler = DistributedSampler(self.datasets['validation']) |
|||
return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler) |
|||
|
|||
def test_dataloader(self): |
|||
|
|||
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) |
|||
|
|||
|
|||
class ObjaverseData(Dataset): |
|||
def __init__(self, |
|||
root_dir='objaverse/', |
|||
meta_fname='valid_paths.json', |
|||
image_dir='rendering_zero123plus', |
|||
validation=False, |
|||
): |
|||
self.root_dir = Path(root_dir) |
|||
self.image_dir = image_dir |
|||
|
|||
with open(os.path.join(root_dir, meta_fname)) as f: |
|||
lvis_dict = json.load(f) |
|||
paths = [] |
|||
for k in lvis_dict.keys(): |
|||
paths.extend(lvis_dict[k]) |
|||
self.paths = paths |
|||
|
|||
total_objects = len(self.paths) |
|||
if validation: |
|||
self.paths = self.paths[-16:] # used last 16 as validation |
|||
else: |
|||
self.paths = self.paths[:-16] |
|||
print('============= length of dataset %d =============' % len(self.paths)) |
|||
|
|||
def __len__(self): |
|||
return len(self.paths) |
|||
|
|||
def load_im(self, path, color): |
|||
pil_img = Image.open(path) |
|||
|
|||
image = np.asarray(pil_img, dtype=np.float32) / 255. |
|||
alpha = image[:, :, 3:] |
|||
image = image[:, :, :3] * alpha + color * (1 - alpha) |
|||
|
|||
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() |
|||
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() |
|||
return image, alpha |
|||
|
|||
def __getitem__(self, index): |
|||
while True: |
|||
image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index]) |
|||
|
|||
'''background color, default: white''' |
|||
bkg_color = [1., 1., 1.] |
|||
|
|||
img_list = [] |
|||
try: |
|||
for idx in range(7): |
|||
img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color) |
|||
img_list.append(img) |
|||
|
|||
except Exception as e: |
|||
print(e) |
|||
index = np.random.randint(0, len(self.paths)) |
|||
continue |
|||
|
|||
break |
|||
|
|||
imgs = torch.stack(img_list, dim=0).float() |
|||
|
|||
data = { |
|||
'cond_imgs': imgs[0], # (3, H, W) |
|||
'target_imgs': imgs[1:], # (6, 3, H, W) |
|||
} |
|||
return data |
@ -0,0 +1,272 @@ |
|||
import os |
|||
import numpy as np |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
import pytorch_lightning as pl |
|||
from tqdm import tqdm |
|||
from torchvision.transforms import v2 |
|||
from torchvision.utils import make_grid, save_image |
|||
from einops import rearrange |
|||
|
|||
from src.utils.train_util import instantiate_from_config |
|||
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel |
|||
from .pipeline import RefOnlyNoisedUNet |
|||
|
|||
|
|||
def scale_latents(latents): |
|||
latents = (latents - 0.22) * 0.75 |
|||
return latents |
|||
|
|||
|
|||
def unscale_latents(latents): |
|||
latents = latents / 0.75 + 0.22 |
|||
return latents |
|||
|
|||
|
|||
def scale_image(image): |
|||
image = image * 0.5 / 0.8 |
|||
return image |
|||
|
|||
|
|||
def unscale_image(image): |
|||
image = image / 0.5 * 0.8 |
|||
return image |
|||
|
|||
|
|||
def extract_into_tensor(a, t, x_shape): |
|||
b, *_ = t.shape |
|||
out = a.gather(-1, t) |
|||
return out.reshape(b, *((1,) * (len(x_shape) - 1))) |
|||
|
|||
|
|||
class MVDiffusion(pl.LightningModule): |
|||
def __init__( |
|||
self, |
|||
stable_diffusion_config, |
|||
drop_cond_prob=0.1, |
|||
): |
|||
super(MVDiffusion, self).__init__() |
|||
|
|||
self.drop_cond_prob = drop_cond_prob |
|||
|
|||
self.register_schedule() |
|||
|
|||
# init modules |
|||
pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config) |
|||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|||
pipeline.scheduler.config, timestep_spacing='trailing' |
|||
) |
|||
self.pipeline = pipeline |
|||
|
|||
train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config) |
|||
if isinstance(self.pipeline.unet, UNet2DConditionModel): |
|||
self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler) |
|||
|
|||
self.train_scheduler = train_sched # use ddpm scheduler during training |
|||
|
|||
self.unet = pipeline.unet |
|||
|
|||
# validation output buffer |
|||
self.validation_step_outputs = [] |
|||
|
|||
def register_schedule(self): |
|||
self.num_timesteps = 1000 |
|||
|
|||
# replace scaled_linear schedule with linear schedule as Zero123++ |
|||
beta_start = 0.00085 |
|||
beta_end = 0.0120 |
|||
betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32) |
|||
|
|||
alphas = 1. - betas |
|||
alphas_cumprod = torch.cumprod(alphas, dim=0) |
|||
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) |
|||
|
|||
self.register_buffer('betas', betas.float()) |
|||
self.register_buffer('alphas_cumprod', alphas_cumprod.float()) |
|||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float()) |
|||
|
|||
# calculations for diffusion q(x_t | x_{t-1}) and others |
|||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float()) |
|||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float()) |
|||
|
|||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float()) |
|||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float()) |
|||
|
|||
def on_fit_start(self): |
|||
device = torch.device(f'cuda:{self.global_rank}') |
|||
self.pipeline.to(device) |
|||
if self.global_rank == 0: |
|||
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) |
|||
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) |
|||
|
|||
def prepare_batch_data(self, batch): |
|||
# prepare stable diffusion input |
|||
cond_imgs = batch['cond_imgs'] # (B, C, H, W) |
|||
cond_imgs = cond_imgs.to(self.device) |
|||
|
|||
# random resize the condition image |
|||
cond_size = np.random.randint(128, 513) |
|||
cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1) |
|||
|
|||
target_imgs = batch['target_imgs'] # (B, 6, C, H, W) |
|||
target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1) |
|||
target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W) |
|||
target_imgs = target_imgs.to(self.device) |
|||
|
|||
return cond_imgs, target_imgs |
|||
|
|||
@torch.no_grad() |
|||
def forward_vision_encoder(self, images): |
|||
dtype = next(self.pipeline.vision_encoder.parameters()).dtype |
|||
image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] |
|||
image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values |
|||
image_pt = image_pt.to(device=self.device, dtype=dtype) |
|||
global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds |
|||
global_embeds = global_embeds.unsqueeze(-2) |
|||
|
|||
encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0] |
|||
ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1) |
|||
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp |
|||
|
|||
return encoder_hidden_states |
|||
|
|||
@torch.no_grad() |
|||
def encode_condition_image(self, images): |
|||
dtype = next(self.pipeline.vae.parameters()).dtype |
|||
image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] |
|||
image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values |
|||
image_pt = image_pt.to(device=self.device, dtype=dtype) |
|||
latents = self.pipeline.vae.encode(image_pt).latent_dist.sample() |
|||
return latents |
|||
|
|||
@torch.no_grad() |
|||
def encode_target_images(self, images): |
|||
dtype = next(self.pipeline.vae.parameters()).dtype |
|||
# equals to scaling images to [-1, 1] first and then call scale_image |
|||
images = (images - 0.5) / 0.8 # [-0.625, 0.625] |
|||
posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist |
|||
latents = posterior.sample() * self.pipeline.vae.config.scaling_factor |
|||
latents = scale_latents(latents) |
|||
return latents |
|||
|
|||
def forward_unet(self, latents, t, prompt_embeds, cond_latents): |
|||
dtype = next(self.pipeline.unet.parameters()).dtype |
|||
latents = latents.to(dtype) |
|||
prompt_embeds = prompt_embeds.to(dtype) |
|||
cond_latents = cond_latents.to(dtype) |
|||
cross_attention_kwargs = dict(cond_lat=cond_latents) |
|||
pred_noise = self.pipeline.unet( |
|||
latents, |
|||
t, |
|||
encoder_hidden_states=prompt_embeds, |
|||
cross_attention_kwargs=cross_attention_kwargs, |
|||
return_dict=False, |
|||
)[0] |
|||
return pred_noise |
|||
|
|||
def predict_start_from_z_and_v(self, x_t, t, v): |
|||
return ( |
|||
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - |
|||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v |
|||
) |
|||
|
|||
def get_v(self, x, noise, t): |
|||
return ( |
|||
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - |
|||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x |
|||
) |
|||
|
|||
def training_step(self, batch, batch_idx): |
|||
# get input |
|||
cond_imgs, target_imgs = self.prepare_batch_data(batch) |
|||
|
|||
# sample random timestep |
|||
B = cond_imgs.shape[0] |
|||
|
|||
t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device) |
|||
|
|||
# classifier-free guidance |
|||
if np.random.rand() < self.drop_cond_prob: |
|||
prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False) |
|||
cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs)) |
|||
else: |
|||
prompt_embeds = self.forward_vision_encoder(cond_imgs) |
|||
cond_latents = self.encode_condition_image(cond_imgs) |
|||
|
|||
latents = self.encode_target_images(target_imgs) |
|||
noise = torch.randn_like(latents) |
|||
latents_noisy = self.train_scheduler.add_noise(latents, noise, t) |
|||
|
|||
v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents) |
|||
v_target = self.get_v(latents, noise, t) |
|||
|
|||
loss, loss_dict = self.compute_loss(v_pred, v_target) |
|||
|
|||
# logging |
|||
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) |
|||
self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) |
|||
lr = self.optimizers().param_groups[0]['lr'] |
|||
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) |
|||
|
|||
if self.global_step % 500 == 0 and self.global_rank == 0: |
|||
with torch.no_grad(): |
|||
latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred) |
|||
|
|||
latents = unscale_latents(latents_pred) |
|||
images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] |
|||
images = (images * 0.5 + 0.5).clamp(0, 1) |
|||
images = torch.cat([target_imgs, images], dim=-2) |
|||
|
|||
grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1)) |
|||
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) |
|||
|
|||
return loss |
|||
|
|||
def compute_loss(self, noise_pred, noise_gt): |
|||
loss = F.mse_loss(noise_pred, noise_gt) |
|||
|
|||
prefix = 'train' |
|||
loss_dict = {} |
|||
loss_dict.update({f'{prefix}/loss': loss}) |
|||
|
|||
return loss, loss_dict |
|||
|
|||
@torch.no_grad() |
|||
def validation_step(self, batch, batch_idx): |
|||
# get input |
|||
cond_imgs, target_imgs = self.prepare_batch_data(batch) |
|||
|
|||
images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])] |
|||
|
|||
outputs = [] |
|||
for cond_img in images_pil: |
|||
latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images |
|||
image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] |
|||
image = (image * 0.5 + 0.5).clamp(0, 1) |
|||
outputs.append(image) |
|||
outputs = torch.cat(outputs, dim=0).to(self.device) |
|||
images = torch.cat([target_imgs, outputs], dim=-2) |
|||
|
|||
self.validation_step_outputs.append(images) |
|||
|
|||
@torch.no_grad() |
|||
def on_validation_epoch_end(self): |
|||
images = torch.cat(self.validation_step_outputs, dim=0) |
|||
|
|||
all_images = self.all_gather(images) |
|||
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') |
|||
|
|||
if self.global_rank == 0: |
|||
grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1)) |
|||
save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')) |
|||
|
|||
self.validation_step_outputs.clear() # free memory |
|||
|
|||
def configure_optimizers(self): |
|||
lr = self.learning_rate |
|||
|
|||
optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr) |
|||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) |
|||
|
|||
return {'optimizer': optimizer, 'lr_scheduler': scheduler} |
Loading…
Reference in new issue