From 34c193cc96eebd46deb7c48a76613753ad777122 Mon Sep 17 00:00:00 2001 From: bluestyle97 Date: Tue, 7 May 2024 16:22:49 +0800 Subject: [PATCH] release zero123++ fine-tuning code --- README.md | 14 +- configs/zero123plus-finetune.yaml | 47 ++++++ src/data/objaverse_zero123plus.py | 124 ++++++++++++++ zero123plus/model.py | 272 ++++++++++++++++++++++++++++++ 4 files changed, 454 insertions(+), 3 deletions(-) create mode 100644 configs/zero123plus-finetune.yaml create mode 100644 src/data/objaverse_zero123plus.py create mode 100644 zero123plus/model.py diff --git a/README.md b/README.md index 215e588..eebebdd 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,13 @@ This repo is the official implementation of InstantMesh, a feed-forward framewor https://github.com/TencentARC/InstantMesh/assets/20635237/dab3511e-e7c6-4c0b-bab7-15772045c47d -# 🚩 Todo List - +# 🚩 Features and Todo List +- [x] 🔥🔥 Release Zero123++ fine-tuning code. +- [x] 🔥🔥 Support for running gradio demo on two GPUs to save memory. +- [x] 🔥🔥 Support for running demo with docker. Please refer to the [docker](docker/) directory. - [x] Release inference and training code. - [x] Release model weights. - [x] Release huggingface gradio demo. Please try it at [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link. -- [x] Add support for running gradio demo on two GPUs to save memory. - [ ] Add support for more multi-view diffusion models. # ⚙️ Dependencies and Installation @@ -76,6 +77,8 @@ If you have multiple GPUs in your machine, the demo app will run on two GPUs aut CUDA_VISIBLE_DEVICES=0 python app.py ``` +Alternatively, you can run the demo with docker. Please follow the instructions in the [docker](docker/) directory. + ## Running with command line To generate 3D meshes from images via command line, simply run: @@ -112,6 +115,11 @@ python train.py --base configs/instant-nerf-large-train.yaml --gpus 0,1,2,3,4,5, python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1 ``` +We also provide our Zero123++ fine-tuning code since it is frequently requested. 
The running command is: +```bash +python train.py --base configs/zero123plus-finetune.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1 +``` + # :books: Citation If you find our work useful for your research or applications, please cite using this BibTeX: diff --git a/configs/zero123plus-finetune.yaml b/configs/zero123plus-finetune.yaml new file mode 100644 index 0000000..52b3394 --- /dev/null +++ b/configs/zero123plus-finetune.yaml @@ -0,0 +1,47 @@ +model: + base_learning_rate: 1.0e-05 + target: zero123plus.model.MVDiffusion + params: + drop_cond_prob: 0.1 + + stable_diffusion_config: + pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2 + custom_pipeline: ./zero123plus + +data: + target: src.data.objaverse_zero123plus.DataModuleFromConfig + params: + batch_size: 6 + num_workers: 8 + train: + target: src.data.objaverse_zero123plus.ObjaverseData + params: + root_dir: data/objaverse + meta_fname: lvis-annotations.json + image_dir: rendering_zero123plus + validation: false + validation: + target: src.data.objaverse_zero123plus.ObjaverseData + params: + root_dir: data/objaverse + meta_fname: lvis-annotations.json + image_dir: rendering_zero123plus + validation: true + + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 1000 + save_top_k: -1 + save_last: true + callbacks: {} + + trainer: + benchmark: true + max_epochs: -1 + gradient_clip_val: 1.0 + val_check_interval: 1000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + check_val_every_n_epoch: null # if not set this, validation does not run diff --git a/src/data/objaverse_zero123plus.py b/src/data/objaverse_zero123plus.py new file mode 100644 index 0000000..1d02612 --- /dev/null +++ b/src/data/objaverse_zero123plus.py @@ -0,0 +1,124 @@ +import os +import json +import numpy as np +import webdataset as wds +import pytorch_lightning as pl +import torch +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from pathlib import Path + +from src.utils.train_util import instantiate_from_config + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size=8, + num_workers=4, + train=None, + validation=None, + test=None, + **kwargs, + ): + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + + self.dataset_configs = dict() + if train is not None: + self.dataset_configs['train'] = train + if validation is not None: + self.dataset_configs['validation'] = validation + if test is not None: + self.dataset_configs['test'] = test + + def setup(self, stage): + + if stage in ['fit']: + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) + else: + raise NotImplementedError + + def train_dataloader(self): + + sampler = DistributedSampler(self.datasets['train']) + return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + + sampler = DistributedSampler(self.datasets['validation']) + return wds.WebLoader(self.datasets['validation'], batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def test_dataloader(self): + + return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + +class ObjaverseData(Dataset): + def __init__(self, + root_dir='objaverse/', + meta_fname='valid_paths.json', + image_dir='rendering_zero123plus', + validation=False, + ): + self.root_dir = 
Path(root_dir) + self.image_dir = image_dir + + with open(os.path.join(root_dir, meta_fname)) as f: + lvis_dict = json.load(f) + paths = [] + for k in lvis_dict.keys(): + paths.extend(lvis_dict[k]) + self.paths = paths + + total_objects = len(self.paths) + if validation: + self.paths = self.paths[-16:] # used last 16 as validation + else: + self.paths = self.paths[:-16] + print('============= length of dataset %d =============' % len(self.paths)) + + def __len__(self): + return len(self.paths) + + def load_im(self, path, color): + pil_img = Image.open(path) + + image = np.asarray(pil_img, dtype=np.float32) / 255. + alpha = image[:, :, 3:] + image = image[:, :, :3] * alpha + color * (1 - alpha) + + image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() + alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() + return image, alpha + + def __getitem__(self, index): + while True: + image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index]) + + '''background color, default: white''' + bkg_color = [1., 1., 1.] + + img_list = [] + try: + for idx in range(7): + img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color) + img_list.append(img) + + except Exception as e: + print(e) + index = np.random.randint(0, len(self.paths)) + continue + + break + + imgs = torch.stack(img_list, dim=0).float() + + data = { + 'cond_imgs': imgs[0], # (3, H, W) + 'target_imgs': imgs[1:], # (6, 3, H, W) + } + return data diff --git a/zero123plus/model.py b/zero123plus/model.py new file mode 100644 index 0000000..1655c45 --- /dev/null +++ b/zero123plus/model.py @@ -0,0 +1,272 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from tqdm import tqdm +from torchvision.transforms import v2 +from torchvision.utils import make_grid, save_image +from einops import rearrange + +from src.utils.train_util import instantiate_from_config +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel +from .pipeline import RefOnlyNoisedUNet + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class MVDiffusion(pl.LightningModule): + def __init__( + self, + stable_diffusion_config, + drop_cond_prob=0.1, + ): + super(MVDiffusion, self).__init__() + + self.drop_cond_prob = drop_cond_prob + + self.register_schedule() + + # init modules + pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config) + pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipeline.scheduler.config, timestep_spacing='trailing' + ) + self.pipeline = pipeline + + train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config) + if isinstance(self.pipeline.unet, UNet2DConditionModel): + self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler) + + self.train_scheduler = train_sched # use ddpm scheduler during training + + self.unet = pipeline.unet + + # validation output buffer + self.validation_step_outputs = [] + + def register_schedule(self): + self.num_timesteps = 1000 + + # replace 
scaled_linear schedule with linear schedule as Zero123++ + beta_start = 0.00085 + beta_end = 0.0120 + betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) + + self.register_buffer('betas', betas.float()) + self.register_buffer('alphas_cumprod', alphas_cumprod.float()) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float()) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float()) + self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float()) + + self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float()) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float()) + + def on_fit_start(self): + device = torch.device(f'cuda:{self.global_rank}') + self.pipeline.to(device) + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + + def prepare_batch_data(self, batch): + # prepare stable diffusion input + cond_imgs = batch['cond_imgs'] # (B, C, H, W) + cond_imgs = cond_imgs.to(self.device) + + # random resize the condition image + cond_size = np.random.randint(128, 513) + cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1) + + target_imgs = batch['target_imgs'] # (B, 6, C, H, W) + target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1) + target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W) + target_imgs = target_imgs.to(self.device) + + return cond_imgs, target_imgs + + @torch.no_grad() + def forward_vision_encoder(self, images): + dtype = next(self.pipeline.vision_encoder.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0] + ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + + return encoder_hidden_states + + @torch.no_grad() + def encode_condition_image(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + latents = self.pipeline.vae.encode(image_pt).latent_dist.sample() + return latents + + @torch.no_grad() + def encode_target_images(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + # equals to scaling images to [-1, 1] first and then call scale_image + images = (images - 0.5) / 0.8 # [-0.625, 0.625] + posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist + latents = posterior.sample() * self.pipeline.vae.config.scaling_factor + 
latents = scale_latents(latents) + return latents + + def forward_unet(self, latents, t, prompt_embeds, cond_latents): + dtype = next(self.pipeline.unet.parameters()).dtype + latents = latents.to(dtype) + prompt_embeds = prompt_embeds.to(dtype) + cond_latents = cond_latents.to(dtype) + cross_attention_kwargs = dict(cond_lat=cond_latents) + pred_noise = self.pipeline.unet( + latents, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + return pred_noise + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def training_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + # sample random timestep + B = cond_imgs.shape[0] + + t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device) + + # classifier-free guidance + if np.random.rand() < self.drop_cond_prob: + prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False) + cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs)) + else: + prompt_embeds = self.forward_vision_encoder(cond_imgs) + cond_latents = self.encode_condition_image(cond_imgs) + + latents = self.encode_target_images(target_imgs) + noise = torch.randn_like(latents) + latents_noisy = self.train_scheduler.add_noise(latents, noise, t) + + v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents) + v_target = self.get_v(latents, noise, t) + + loss, loss_dict = self.compute_loss(v_pred, v_target) + + # logging + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.global_step % 500 == 0 and self.global_rank == 0: + with torch.no_grad(): + latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred) + + latents = unscale_latents(latents_pred) + images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + images = (images * 0.5 + 0.5).clamp(0, 1) + images = torch.cat([target_imgs, images], dim=-2) + + grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) + + return loss + + def compute_loss(self, noise_pred, noise_gt): + loss = F.mse_loss(noise_pred, noise_gt) + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])] + + outputs = [] + for cond_img in images_pil: + latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images + image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + image = (image * 
0.5 + 0.5).clamp(0, 1) + outputs.append(image) + outputs = torch.cat(outputs, dim=0).to(self.device) + images = torch.cat([target_imgs, outputs], dim=-2) + + self.validation_step_outputs.append(images) + + @torch.no_grad() + def on_validation_epoch_end(self): + images = torch.cat(self.validation_step_outputs, dim=0) + + all_images = self.all_gather(images) + all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') + + if self.global_rank == 0: + grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')) + + self.validation_step_outputs.clear() # free memory + + def configure_optimizers(self): + lr = self.learning_rate + + optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) + + return {'optimizer': optimizer, 'lr_scheduler': scheduler}
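
A note on the data layout assumed by the new `src/data/objaverse_zero123plus.py`: for every object path listed in `lvis-annotations.json` (a dict whose values are lists of per-object relative paths), `ObjaverseData` loads seven RGBA renderings `000.png`–`006.png`, with `000.png` used as the condition view and `001.png`–`006.png` as the six target views that training tiles into a 3×2 grid; the last 16 objects are held out for validation. Under the paths given in `configs/zero123plus-finetune.yaml`, that corresponds roughly to the sketch below (the per-object folder names are whatever relative paths the JSON contains):

```
data/objaverse/
├── lvis-annotations.json            # {category: [relative object paths, ...], ...}
└── rendering_zero123plus/
    └── <object_path_from_json>/
        ├── 000.png                  # condition view (RGBA)
        ├── 001.png
        ├── ...
        └── 006.png                  # six target views, tiled 3x2 during training
```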
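
The fine-tuning objective in `MVDiffusion.training_step` is a v-prediction MSE loss rather than plain noise prediction. With `alphas_cumprod` written as $\bar{\alpha}_t$, the target built by `get_v` and the latent reconstruction used for the training visualizations in `predict_start_from_z_and_v` are:

$$
v_t = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1-\bar{\alpha}_t}\,x_0,
\qquad
\hat{x}_0 = \sqrt{\bar{\alpha}_t}\,x_t - \sqrt{1-\bar{\alpha}_t}\,v_t .
$$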
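
Finally, a minimal inference sketch for sanity-checking a fine-tuned model. It mirrors how `MVDiffusion.__init__` and `validation_step` construct and call the pipeline (base model `sudo-ai/zero123plus-v1.2`, the repo's `./zero123plus` custom pipeline, Euler-ancestral scheduler with trailing timestep spacing, 75 denoising steps); the fp16 cast, the input/output paths, and the checkpoint-loading step are assumptions to adapt to your setup, not part of the released code:

```python
import torch
from PIL import Image
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

# Build the Zero123++ pipeline the same way MVDiffusion does in zero123plus/model.py.
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="./zero123plus",
    torch_dtype=torch.float16,  # assumption: fp16 is convenient for inference
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing="trailing"
)
pipeline.to("cuda")

# Optionally load fine-tuned UNet weights extracted from the Lightning checkpoint here.

cond_image = Image.open("examples/input.png")  # hypothetical input image
# Zero123++ v1.2 returns the six views tiled into one 960x640 image (3 rows x 2 columns).
grid = pipeline(cond_image, num_inference_steps=75).images[0]
grid.save("multiview.png")
```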