import contextlib
import os
import shutil
import subprocess
import time

# NOTE: third-party import order is kept as in the original file; torch and
# rembg (onnxruntime) can be sensitive to import order on some platforms.
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

from ObjSender import ObjSender
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf
from src.utils.infer_util import remove_background, resize_foreground, save_video


class MeshRenderingPipeline:
    """Image-to-mesh pipeline: multi-view diffusion followed by mesh reconstruction.

    For each input image the pipeline optionally removes the background,
    generates a grid of novel views with the ``sudo-ai/zero123plus-v1.2``
    diffusion pipeline (with a UNet checkpoint from ``TencentARC/InstantMesh``),
    reconstructs a mesh from those views, and optionally renders a turntable
    video, converts the OBJ to glTF with the external ``obj2gltf`` CLI, sends
    the glTF through an ``ObjSender``, and deletes intermediate files.
    """

    # Extensions (compared case-insensitively) picked up from an input directory.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75, seed=42,
                 scale=1.0, distance=4.5, view=4, no_rembg=False, export_texmap=False, save_video=False,
                 gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/', local=False):
        """Store options, load the YAML config and load all models onto the GPU.

        Args:
            config_file: Path to an OmegaConf YAML with ``model_config`` and
                ``infer_config`` sections.
            input_path: A single image file, or a directory of images.
            output_path: Root directory for generated images, meshes and videos.
            diffusion_steps: Number of diffusion sampling steps.
            seed: Seed for ``seed_everything`` and for the diffusion generator.
            scale: Multiplier on the radius (4.0) of the input cameras.
            distance: Radius of the camera ring used for the turntable video.
            view: Number of generated views fed to reconstruction; 4 selects a
                subset of the 6 generated views, any other value uses all 6.
            no_rembg: Skip background removal when True.
            export_texmap: Export a textured mesh (obj + mtl + texture) instead
                of a vertex-colored obj.
            save_video: Render and save a turntable video of each result.
            gltf: Convert each OBJ to glTF with the ``obj2gltf`` CLI.
            remove: Delete the obj/mtl/texture/multi-view image after processing.
            gltf_path: Directory the glTF files are written to.
            local: When False, successfully converted glTF files are sent via
                ``ObjSender``.
        """
        self.config_file = config_file
        self.input_path = input_path
        self.output_path = output_path
        self.diffusion_steps = diffusion_steps
        self.seed = seed
        self.scale = scale
        self.distance = distance
        self.view = view
        self.no_rembg = no_rembg
        self.export_texmap = export_texmap
        self.save_video = save_video
        self.gltf = gltf
        self.remove = remove
        self.gltf_path = gltf_path
        self.sender = ObjSender("localhost", "3000")
        self.local = local

        # Parse configuration and setup
        self._parse_config()
        self._setup()

    def _parse_config(self):
        """Load the YAML config and derive the model / inference sub-configs."""
        self.config = OmegaConf.load(self.config_file)
        self.config_name = os.path.basename(self.config_file).replace('.yaml', '')
        self.model_config = self.config.model_config
        self.infer_config = self.config.infer_config
        # 'instant-mesh*' configs use the FlexiCubes geometry path (see _setup,
        # _render_frames and _get_render_cameras).
        self.IS_FLEXICUBES = self.config_name.startswith('instant-mesh')

    def _setup(self):
        """Seed RNGs, load the diffusion and reconstruction models, create output dirs."""
        seed_everything(self.seed)

        self.device = torch.device('cuda')

        # Load diffusion model
        print('Loading diffusion model ...')
        self.pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="zero123plus",
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipeline.scheduler.config, timestep_spacing='trailing'
        )

        # Load custom white-background UNet (local path if present, else Hugging Face Hub)
        print('Loading custom white-background unet ...')
        if os.path.exists(self.infer_config.unet_path):
            unet_ckpt_path = self.infer_config.unet_path
        else:
            unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh",
                                             filename="diffusion_pytorch_model.bin",
                                             repo_type="model")
        state_dict = torch.load(unet_ckpt_path, map_location='cpu')
        self.pipeline.unet.load_state_dict(state_dict, strict=True)
        self.pipeline = self.pipeline.to(self.device)

        # Load reconstruction model (local path if present, else Hugging Face Hub)
        print('Loading reconstruction model ...')
        self.model = instantiate_from_config(self.model_config)
        if os.path.exists(self.infer_config.model_path):
            model_ckpt_path = self.infer_config.model_path
        else:
            model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh",
                                              filename=f"{self.config_name.replace('-', '_')}.ckpt",
                                              repo_type="model")
        state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
        # Keep only the generator weights and strip their 'lrm_generator.' prefix.
        prefix = 'lrm_generator.'
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        self.model.load_state_dict(state_dict, strict=True)

        self.model = self.model.to(self.device)
        if self.IS_FLEXICUBES:
            self.model.init_flexicubes_geometry(self.device, fovy=30.0)
        self.model = self.model.eval()

        # Make output directories
        self.image_path = os.path.join(self.output_path, self.config_name, 'images')
        self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes')
        self.video_path = os.path.join(self.output_path, self.config_name, 'videos')
        os.makedirs(self.image_path, exist_ok=True)
        os.makedirs(self.mesh_path, exist_ok=True)
        os.makedirs(self.video_path, exist_ok=True)

    def process_image(self, image_file, idx, total_num_files, rembg_session):
        """Run the full pipeline for one input image.

        Args:
            image_file: Path of the input image.
            idx: 0-based index of this image (used only for progress output).
            total_num_files: Total number of images (used only for progress output).
            rembg_session: A rembg session to reuse; when None and background
                removal is enabled, a new session is created.
        """
        start_time = time.time()
        if rembg_session is None and not self.no_rembg:
            rembg_session = rembg.new_session()

        # splitext keeps dots inside the stem (e.g. 'chair.v2.png' -> 'chair.v2').
        name = os.path.splitext(os.path.basename(image_file))[0]
        print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...')

        output_image, img_path = self._generate_multiview(image_file, name, rembg_session)
        mesh_path_idx, mtl_path_idx, texmap_path_idx = self._reconstruct_mesh(output_image, name)

        if self.gltf:
            self._convert_to_gltf(mesh_path_idx, name)

        if self.remove:
            print(f'Removing files (mtl, obj, diffusion, texmap) for {name}')
            self._remove_files(mesh_path_idx, mtl_path_idx, img_path, texmap_path_idx)

        elapsed_time = time.time() - start_time
        print(f'Total Time: {elapsed_time}')

    def _generate_multiview(self, image_file, name, rembg_session):
        """Generate and save the multi-view grid for one image.

        Returns:
            (output_image, img_path): the generated PIL image and where it was saved.
        """
        input_image = Image.open(image_file)
        if not self.no_rembg:
            input_image = remove_background(input_image, rembg_session)
            input_image = resize_foreground(input_image, 0.85)

        # A freshly constructed torch.Generator is not affected by
        # seed_everything, so seed it explicitly to make `seed` control sampling.
        generator = torch.Generator(device=self.device).manual_seed(self.seed)

        # The pipeline is moved to the CPU after sampling (presumably to free GPU
        # memory for reconstruction), so move it back to the GPU first.
        self.pipeline = self.pipeline.to(self.device)
        output_image = self.pipeline(
            input_image,
            num_inference_steps=self.diffusion_steps,
            generator=generator,
        ).images[0]
        self.pipeline = self.pipeline.to("cpu")
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        img_path = os.path.join(self.image_path, f'{name}.png')
        output_image.save(img_path)
        print(f"Image of viewpoints saved to {img_path}")
        return output_image, img_path

    def _reconstruct_mesh(self, output_image, name):
        """Reconstruct and save a mesh (and optionally a video) from the multi-view grid.

        Returns:
            (obj_path, mtl_path, texture_path) for ``name`` inside ``self.mesh_path``.
            Only the obj path is passed to the save helpers; the mtl/texture
            paths are assumed to be written next to it by ``save_obj_with_mtl``
            when ``export_texmap`` is set — TODO confirm.
        """
        images = np.asarray(output_image, dtype=np.float32) / 255.0
        images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
        # Split the 3x2 grid into individual views.
        images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320)

        input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device)
        chunk_size = 20 if self.IS_FLEXICUBES else 1

        print(f'Creating {name} ...')
        images = images.unsqueeze(0).to(self.device)
        images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)

        if self.view == 4:
            indices = torch.tensor([0, 2, 4, 5]).long().to(self.device)
            images = images[:, indices]
            input_cameras = input_cameras[:, indices]

        mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj')
        mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl')
        texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png')

        with torch.no_grad():
            # Get triplane
            planes = self.model.forward_planes(images, input_cameras)

            # Get mesh
            mesh_out = self.model.extract_mesh(
                planes,
                use_texture_map=self.export_texmap,
                **self.infer_config,
            )
            if self.export_texmap:
                vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
                save_obj_with_mtl(
                    vertices.data.cpu().numpy(),
                    uvs.data.cpu().numpy(),
                    faces.data.cpu().numpy(),
                    mesh_tex_idx.data.cpu().numpy(),
                    tex_map.permute(1, 2, 0).data.cpu().numpy(),
                    mesh_path_idx,
                )
            else:
                vertices, faces, vertex_colors = mesh_out
                save_obj(vertices, faces, vertex_colors, mesh_path_idx)
            print(f"Mesh saved to {mesh_path_idx}")

            # Get video
            if self.save_video:
                video_path_idx = os.path.join(self.video_path, f'{name}.mp4')
                render_size = self.infer_config.render_resolution
                render_cameras = self._get_render_cameras()
                frames = self._render_frames(
                    planes,
                    render_cameras=render_cameras,
                    render_size=render_size,
                    chunk_size=chunk_size,
                )
                # Module-level save_video helper (not the self.save_video flag).
                save_video(
                    frames,
                    video_path_idx,
                    fps=30,
                )
                print(f"Video saved to {video_path_idx}")

        return mesh_path_idx, mtl_path_idx, texmap_path_idx

    def _convert_to_gltf(self, mesh_path_idx, name):
        """Convert the OBJ to ``<gltf_path>/<name>.gltf`` with ``obj2gltf``; send it unless local."""
        os.makedirs(self.gltf_path, exist_ok=True)
        output_path = os.path.join(self.gltf_path, f'{name}.gltf')
        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed through verbatim. shutil.which applies
        # PATHEXT on Windows, so wrapper scripts such as an `obj2gltf.cmd`
        # are found even though they are not `.exe` files.
        command = [shutil.which('obj2gltf') or 'obj2gltf', '-i', mesh_path_idx, '-o', output_path]
        try:
            process = subprocess.run(command, capture_output=True, text=True)
        except OSError as err:
            # e.g. obj2gltf not installed / not on PATH
            print(f'Error converting {mesh_path_idx}: {err}')
            return

        if process.returncode == 0:
            print(f'Successfully converted {mesh_path_idx} to {output_path}')
            if not self.local:
                self.sender.send_file(output_path)
        else:
            print(f'Error converting {mesh_path_idx}: {process.stderr}')

    @staticmethod
    def _remove_files(*paths):
        """Delete each path, skipping files that were never written.

        The mtl/texture files do not exist when ``export_texmap`` is False, so a
        missing file is not an error here.
        """
        for path in paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

    def run_pipeline_sequence(self, stop_event=None):
        """Process every supported image in ``input_path`` (or the single file it names).

        Args:
            stop_event: Optional object with ``is_set()`` (e.g. ``threading.Event``);
                when set, processing stops before the next image.
        """
        if os.path.isdir(self.input_path):
            input_files = [
                os.path.join(self.input_path, file)
                for file in sorted(os.listdir(self.input_path))
                if file.lower().endswith(self.IMAGE_EXTENSIONS)
            ]
        else:
            input_files = [self.input_path]
        print(f'\nTotal number of input images: {len(input_files)}')

        # One background-removal session shared by all images.
        rembg_session = None if self.no_rembg else rembg.new_session()
        for idx, image_file in enumerate(input_files):
            if stop_event is not None and stop_event.is_set():
                print("\nStopping pipeline sequence.")
                break
            self.process_image(image_file, idx, len(input_files), rembg_session)

    def _get_render_cameras(self, batch_size=1, M=120, radius=None, elevation=20.0):
        """Build ``M`` camera poses on a circle around the object for video rendering.

        Args:
            batch_size: Number of times the camera set is repeated along dim 0.
            M: Number of cameras (video frames).
            radius: Ring radius; defaults to ``self.distance``.
            elevation: Camera elevation passed to ``get_circular_camera_poses``.

        Returns:
            Cameras on ``self.device``: inverted camera-to-world matrices for
            FlexiCubes models, otherwise flattened extrinsics concatenated with
            flattened 30-degree-FOV intrinsics. Assumes the (c2w / intrinsics)
            layout used by InstantMesh's reference inference script — TODO
            confirm against src.utils.camera_util.
        """
        if radius is None:
            radius = self.distance
        c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
        if self.IS_FLEXICUBES:
            cameras = torch.linalg.inv(c2ws)
            cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        else:
            extrinsics = c2ws.flatten(-2)
            # 30.0 matches the fovy used by init_flexicubes_geometry in _setup.
            intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
            cameras = torch.cat([extrinsics, intrinsics], dim=-1)
            cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
        return cameras.to(self.device)

    def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1):
        """Render frames from triplanes, ``chunk_size`` cameras at a time.

        Returns the concatenated frames of the first (and assumed only) batch item.
        """
        frames = []
        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
            if self.IS_FLEXICUBES:
                frame = self.model.forward_geometry(
                    planes,
                    render_cameras[:, i:i + chunk_size],
                    render_size=render_size,
                )['img']
            else:
                frame = self.model.forward_synthesizer(
                    planes,
                    render_cameras[:, i:i + chunk_size],
                    render_size=render_size,
                )['images_rgb']
            frames.append(frame)

        frames = torch.cat(frames, dim=1)[0]  # we suppose batch size is always 1
        return frames