5 changed files with 406 additions and 1 deletions
			
			
		| @ -0,0 +1,259 @@ | |||
| import os | |||
| import numpy as np | |||
| import torch | |||
| import rembg | |||
| from PIL import Image | |||
| from torchvision.transforms import v2 | |||
| from tqdm import tqdm | |||
| from huggingface_hub import hf_hub_download | |||
| from pytorch_lightning import seed_everything | |||
| from omegaconf import OmegaConf | |||
| from einops import rearrange | |||
| from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler | |||
| import subprocess | |||
| 
 | |||
| 
 | |||
| from src.utils.train_util import instantiate_from_config | |||
| from src.utils.camera_util import ( | |||
|     FOV_to_intrinsics, | |||
|     get_zero123plus_input_cameras, | |||
|     get_circular_camera_poses, | |||
| ) | |||
| from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf | |||
| from src.utils.infer_util import remove_background, resize_foreground, save_video | |||
| 
 | |||
| 
 | |||
| class MeshRenderingPipeline: | |||
|     def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75, | |||
|                  seed=42, scale=1.0, distance=4.5, view=6, no_rembg=False, export_texmap=False, | |||
|                  save_video=False, gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/'): | |||
|         self.config_file = config_file | |||
|         self.input_path = input_path | |||
|         self.output_path = output_path | |||
|         self.diffusion_steps = diffusion_steps | |||
|         self.seed = seed | |||
|         self.scale = scale | |||
|         self.distance = distance | |||
|         self.view = view | |||
|         self.no_rembg = no_rembg | |||
|         self.export_texmap = export_texmap | |||
|         self.save_video = save_video | |||
|         self.gltf = gltf | |||
|         self.remove = remove | |||
|         self.gltf_path = gltf_path | |||
| 
 | |||
|         # Parse configuration and setup | |||
|         self._parse_config() | |||
|         self._setup() | |||
| 
 | |||
|     def _parse_config(self): | |||
|         # Parse configuration file | |||
|         self.config = OmegaConf.load(self.config_file) | |||
|         self.config_name = os.path.basename(self.config_file).replace('.yaml', '') | |||
|         self.model_config = self.config.model_config | |||
|         self.infer_config = self.config.infer_config | |||
|         self.IS_FLEXICUBES = True if self.config_name.startswith('instant-mesh') else False | |||
| 
 | |||
|     def _setup(self): | |||
|         # Seed for reproducibility | |||
|         seed_everything(self.seed) | |||
| 
 | |||
|         # Device setup | |||
|         self.device = torch.device('cuda') | |||
| 
 | |||
|         # Load diffusion model | |||
|         print('Loading diffusion model ...') | |||
|         self.pipeline = DiffusionPipeline.from_pretrained( | |||
|             "sudo-ai/zero123plus-v1.2", | |||
|             custom_pipeline="zero123plus", | |||
|             torch_dtype=torch.float16, | |||
|         ) | |||
|         self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( | |||
|             self.pipeline.scheduler.config, timestep_spacing='trailing' | |||
|         ) | |||
| 
 | |||
|         # Load custom white-background UNet | |||
|         print('Loading custom white-background unet ...') | |||
|         if os.path.exists(self.infer_config.unet_path): | |||
|             unet_ckpt_path = self.infer_config.unet_path | |||
|         else: | |||
|             unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", | |||
|                                              repo_type="model") | |||
|         state_dict = torch.load(unet_ckpt_path, map_location='cpu') | |||
|         self.pipeline.unet.load_state_dict(state_dict, strict=True) | |||
| 
 | |||
|         self.pipeline = self.pipeline.to(self.device) | |||
| 
 | |||
|         # Load reconstruction model | |||
|         print('Loading reconstruction model ...') | |||
|         self.model = instantiate_from_config(self.model_config) | |||
|         if os.path.exists(self.infer_config.model_path): | |||
|             model_ckpt_path = self.infer_config.model_path | |||
|         else: | |||
|             model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", | |||
|                                               filename=f"{self.config_name.replace('-', '_')}.ckpt", repo_type="model") | |||
|         state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] | |||
|         state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} | |||
|         self.model.load_state_dict(state_dict, strict=True) | |||
| 
 | |||
|         self.model = self.model.to(self.device) | |||
|         if self.IS_FLEXICUBES: | |||
|             self.model.init_flexicubes_geometry(self.device, fovy=30.0) | |||
|         self.model = self.model.eval() | |||
| 
 | |||
|         # Make output directories | |||
|         self.image_path = os.path.join(self.output_path, self.config_name, 'images') | |||
|         self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes') | |||
|         self.video_path = os.path.join(self.output_path, self.config_name, 'videos') | |||
|         os.makedirs(self.image_path, exist_ok=True) | |||
|         os.makedirs(self.mesh_path, exist_ok=True) | |||
|         os.makedirs(self.video_path, exist_ok=True) | |||
| 
 | |||
|     def process_image(self, image_file, idx, total_num_files, rembg_session): | |||
|         if rembg_session == None: | |||
|             rembg_session = None if self.no_rembg else rembg.new_session() | |||
|          | |||
|         name = os.path.basename(image_file).split('.')[0] | |||
|         print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...') | |||
| 
 | |||
|         # Remove background optionally | |||
|         input_image = Image.open(image_file) | |||
|         if not self.no_rembg: | |||
|             input_image = remove_background(input_image, rembg_session) | |||
|             input_image = resize_foreground(input_image, 0.85) | |||
| 
 | |||
|         # Sampling | |||
|         output_image = self.pipeline( | |||
|             input_image, | |||
|             num_inference_steps=self.diffusion_steps, | |||
|         ).images[0] | |||
| 
 | |||
|         img_path = os.path.join(self.image_path, f'{name}.png') | |||
|         output_image.save(img_path) | |||
|         print(f"Image of viewpoints saved to {os.path.join(self.image_path, f'{name}.png')}") | |||
| 
 | |||
|         images = np.asarray(output_image, dtype=np.float32) / 255.0 | |||
|         images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640) | |||
|         images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320) | |||
| 
 | |||
|         input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device) | |||
|         chunk_size = 20 if self.IS_FLEXICUBES else 1 | |||
|         print(f'Creating {name} ...') | |||
|         images = images.unsqueeze(0).to(self.device) | |||
|         images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1) | |||
| 
 | |||
|         if self.view == 4: | |||
|             indices = torch.tensor([0, 2, 4, 5]).long().to(self.device) | |||
|             images = images[:, indices] | |||
|             input_cameras = input_cameras[:, indices] | |||
| 
 | |||
|         with torch.no_grad(): | |||
|             # Get triplane | |||
|             planes = self.model.forward_planes(images, input_cameras) | |||
| 
 | |||
|             # Get mesh | |||
|             mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj') | |||
|             mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl') | |||
|             texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png') | |||
| 
 | |||
|             mesh_out = self.model.extract_mesh( | |||
|                 planes, | |||
|                 use_texture_map=self.export_texmap, | |||
|                 **self.infer_config, | |||
|             ) | |||
| 
 | |||
| 
 | |||
|             if self.export_texmap: | |||
|                 vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out | |||
|                 save_obj_with_mtl( | |||
|                     vertices.data.cpu().numpy(), | |||
|                     uvs.data.cpu().numpy(), | |||
|                     faces.data.cpu().numpy(), | |||
|                     mesh_tex_idx.data.cpu().numpy(), | |||
|                     tex_map.permute(1, 2, 0).data.cpu().numpy(), | |||
|                     mesh_path_idx, | |||
|                 ) | |||
|             else: | |||
|                 vertices, faces, vertex_colors = mesh_out | |||
|                 save_obj(vertices, faces, vertex_colors, mesh_path_idx) | |||
|              | |||
|             print(f"Mesh saved to {mesh_path_idx}") | |||
| 
 | |||
|             # Get video | |||
|             if self.save_video: | |||
|                 video_path_idx = os.path.join(self.video_path, f'{name}.mp4') | |||
|                 render_size = self.infer_config.render_resolution | |||
|                 render_cameras = self._get_render_cameras() | |||
| 
 | |||
|                 frames = self._render_frames( | |||
|                     planes, | |||
|                     render_cameras=render_cameras, | |||
|                     render_size=render_size, | |||
|                     chunk_size=chunk_size, | |||
|                 ) | |||
| 
 | |||
|                 save_video( | |||
|                     frames, | |||
|                     video_path_idx, | |||
|                     fps=30, | |||
|                 ) | |||
|                 print(f"Video saved to {video_path_idx}") | |||
| 
 | |||
|             if self.gltf: | |||
|                 output_path = os.path.join(self.gltf_path, f'{name}.gltf') | |||
|                 command = f' obj2gltf -i {mesh_path_idx} -o {output_path}' | |||
|                 process = subprocess.run(command, shell=True, capture_output=True, text=True) | |||
|                 # Check if the process was successful | |||
|                 if process.returncode == 0: | |||
|                     print(f'Successfully converted {mesh_path_idx} to {output_path}') | |||
|                 else: | |||
|                     print(f'Error converting {mesh_path_idx}: {process.stderr}') | |||
| 
 | |||
|                 if self.remove: | |||
|                     print(f'Removing files (mtl, obj, diffusion, texmap) for {name}') | |||
|                     os.remove(mesh_path_idx) | |||
|                     os.remove(mtl_path_idx) | |||
|                     os.remove(img_path) | |||
|                     os.remove(texmap_path_idx) | |||
|         pass | |||
| 
 | |||
|     def run_pipeline_sequence(self, stop_event): | |||
|         # Process input files | |||
|         if os.path.isdir(self.input_path): | |||
|             input_files = [ | |||
|                 os.path.join(self.input_path, file) | |||
|                 for file in os.listdir(self.input_path) | |||
|                 if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp') | |||
|             ] | |||
|         else: | |||
|             input_files = [self.input_path] | |||
|         print(f'\nTotal number of input images: {len(input_files)}') | |||
| 
 | |||
|         rembg_session = None if self.no_rembg else rembg.new_session() | |||
| 
 | |||
|         for idx, image_file in enumerate(input_files): | |||
|             if stop_event.is_set():  # Check if stop event is set | |||
|                 print("\nStopping pipeline sequence.") | |||
|                 break | |||
|             self.process_image(image_file, idx, len(input_files), rembg_session) | |||
|                          | |||
|     def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1): | |||
|         # Render frames from triplanes | |||
|         frames = [] | |||
|         for i in tqdm(range(0, render_cameras.shape[1], chunk_size)): | |||
|             if self.IS_FLEXICUBES: | |||
|                 frame = self.model.forward_geometry( | |||
|                     planes, | |||
|                     render_cameras[:, i:i + chunk_size], | |||
|                     render_size=render_size, | |||
|                 )['img'] | |||
|             else: | |||
|                 frame = self.model.forward_synthesizer( | |||
|                     planes, | |||
|                     render_cameras[:, i:i + chunk_size], | |||
|                     render_size=render_size, | |||
|                 )['images_rgb'] | |||
|             frames.append(frame) | |||
| 
 | |||
|         frames = torch.cat(frames, dim=1)[0]  # we suppose batch size is always 1 | |||
|         return frames | |||
| @ -0,0 +1,127 @@ | |||
| import argparse | |||
| from MeshRenderingPipeline import MeshRenderingPipeline | |||
| from watchdog.observers import Observer | |||
| from watchdog.events import FileSystemEventHandler | |||
| import threading | |||
| import time | |||
| import random | |||
| 
 | |||
| class ImageHandler(FileSystemEventHandler): | |||
|     def __init__(self, mesh_pipeline): | |||
|         self.mesh_pipeline = mesh_pipeline | |||
| 
 | |||
|     def on_created(self, event): | |||
|         if event.is_directory: | |||
|             return | |||
|         if event.src_path.endswith(('.png', '.jpg', '.webp')): | |||
|             print(f"New image detected: {event.src_path}") | |||
|             self.mesh_pipeline.process_image(event.src_path, 0, 1, None) | |||
| 
 | |||
| class App: | |||
|     def __init__(self, args): | |||
|         self.args = args | |||
|         self.running = False | |||
|         self.stop_event = threading.Event() | |||
|         self.thread = None | |||
|         self.input_thread = None | |||
|         self.observer = None | |||
|         self.mesh_pipeline = MeshRenderingPipeline( | |||
|             args.config, | |||
|             args.input_path, | |||
|             args.output_path, | |||
|             args.diffusion_steps, | |||
|             args.seed, | |||
|             args.scale, | |||
|             args.distance, | |||
|             args.view, | |||
|             args.no_rembg, | |||
|             args.export_texmap, | |||
|             args.save_video, | |||
|             args.gltf, | |||
|             args.remove, | |||
|             args.gltf_path | |||
|         ) | |||
| 
 | |||
|     def start_generation(self): | |||
|         self.running = True | |||
|         self.stop_event.clear() | |||
|         self.thread = threading.Thread(target=self._generate_objects) | |||
|         self.thread.start() | |||
| 
 | |||
|     def stop_generation(self): | |||
|         self.stop_event.set() | |||
|         self.running = False | |||
|         if self.thread: | |||
|             self.thread.join() | |||
|         self.stop_observer() | |||
|      | |||
|     def monitor_new_images(self): | |||
|         event_handler = ImageHandler(self.mesh_pipeline) | |||
|         self.observer = Observer() | |||
|         self.observer.schedule(event_handler, self.args.input_path, recursive=False) | |||
|         self.observer.start() | |||
|         print("\nWaiting for new image..") | |||
|         while not self.stop_event.is_set(): | |||
|                 time.sleep(1) | |||
|         self.stop_observer()  # Ensure the observer is stopped on interruption | |||
| 
 | |||
| 
 | |||
|     def stop_observer(self): | |||
|         if self.observer: | |||
|             self.observer.stop() | |||
|             self.observer.join() | |||
|             self.observer = None | |||
|      | |||
|     def _generate_objects(self): | |||
|         while not self.stop_event.is_set(): | |||
|             self.run_pipeline() | |||
|             time.sleep(1) | |||
|      | |||
| 
 | |||
|     def run_pipeline(self): | |||
|         self.mesh_pipeline.run_pipeline_sequence(self.stop_event) | |||
|         self.monitor_new_images() | |||
| 
 | |||
|     def run(self): | |||
|         self.input_thread = threading.Thread(target=self._handle_input) | |||
|         self.input_thread.start() | |||
|         self.input_thread.join() | |||
| 
 | |||
|     def _handle_input(self): | |||
|         while True: | |||
|             command = input("\nEnter a command, <start> <stop> <exit>: ") | |||
|             if command.lower() == 'exit': | |||
|                 print("Exiting the program.") | |||
|                 self.stop_generation() | |||
|                 break | |||
|             elif command.lower() == 'start': | |||
|                 if not self.running: | |||
|                     print("\nStarting continuous generation.") | |||
|                     self.start_generation() | |||
|                 else: | |||
|                     print("\nGeneration already running.") | |||
|             elif command.lower() == 'stop': | |||
|                 print("\nStopping continuous generation.") | |||
|                 self.stop_generation() | |||
| 
 | |||
| 
 | |||
| if __name__ == '__main__': | |||
|     parser = argparse.ArgumentParser() | |||
|     parser.add_argument('config', type=str, help='Path to config file.') | |||
|     parser.add_argument('input_path', type=str, help='Path to input image or directory.') | |||
|     parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.') | |||
|     parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising Sampling steps.') | |||
|     parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.') | |||
|     parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.') | |||
|     parser.add_argument('--distance', type=float, default=4.5, help='Render distance.') | |||
|     parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.') | |||
|     parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.') | |||
|     parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.') | |||
|     parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.') | |||
|     parser.add_argument('--gltf', action='store_true', help='Export a gtlf file.') | |||
|     parser.add_argument('--remove', action='store_true', help='Removes obj, mtl, texmap, nv files.') | |||
|     parser.add_argument('--gltf_path', type=str, default='C:/Users/caile/Desktop/InstantMesh/ex', help='Output directory.') | |||
|     args = parser.parse_args() | |||
| 
 | |||
|     app = App(args) | |||
|     app.run() | |||
| After Width: | Height: | Size: 79 KiB | 
					Loading…
					
					
				
		Reference in new issue