diff --git a/MeshRenderingPipeline.py b/MeshRenderingPipeline.py new file mode 100644 index 0000000..40e65d1 --- /dev/null +++ b/MeshRenderingPipeline.py @@ -0,0 +1,259 @@ +import os +import numpy as np +import torch +import rembg +from PIL import Image +from torchvision.transforms import v2 +from tqdm import tqdm +from huggingface_hub import hf_hub_download +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from einops import rearrange +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler +import subprocess + + +from src.utils.train_util import instantiate_from_config +from src.utils.camera_util import ( + FOV_to_intrinsics, + get_zero123plus_input_cameras, + get_circular_camera_poses, +) +from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf +from src.utils.infer_util import remove_background, resize_foreground, save_video + + +class MeshRenderingPipeline: + def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75, + seed=42, scale=1.0, distance=4.5, view=6, no_rembg=False, export_texmap=False, + save_video=False, gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/'): + self.config_file = config_file + self.input_path = input_path + self.output_path = output_path + self.diffusion_steps = diffusion_steps + self.seed = seed + self.scale = scale + self.distance = distance + self.view = view + self.no_rembg = no_rembg + self.export_texmap = export_texmap + self.save_video = save_video + self.gltf = gltf + self.remove = remove + self.gltf_path = gltf_path + + # Parse configuration and setup + self._parse_config() + self._setup() + + def _parse_config(self): + # Parse configuration file + self.config = OmegaConf.load(self.config_file) + self.config_name = os.path.basename(self.config_file).replace('.yaml', '') + self.model_config = self.config.model_config + self.infer_config = self.config.infer_config + self.IS_FLEXICUBES = True if self.config_name.startswith('instant-mesh') else False + + def _setup(self): + # Seed for reproducibility + seed_everything(self.seed) + + # Device setup + self.device = torch.device('cuda') + + # Load diffusion model + print('Loading diffusion model ...') + self.pipeline = DiffusionPipeline.from_pretrained( + "sudo-ai/zero123plus-v1.2", + custom_pipeline="zero123plus", + torch_dtype=torch.float16, + ) + self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + self.pipeline.scheduler.config, timestep_spacing='trailing' + ) + + # Load custom white-background UNet + print('Loading custom white-background unet ...') + if os.path.exists(self.infer_config.unet_path): + unet_ckpt_path = self.infer_config.unet_path + else: + unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", + repo_type="model") + state_dict = torch.load(unet_ckpt_path, map_location='cpu') + self.pipeline.unet.load_state_dict(state_dict, strict=True) + + self.pipeline = self.pipeline.to(self.device) + + # Load reconstruction model + print('Loading reconstruction model ...') + self.model = instantiate_from_config(self.model_config) + if os.path.exists(self.infer_config.model_path): + model_ckpt_path = self.infer_config.model_path + else: + model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", + filename=f"{self.config_name.replace('-', '_')}.ckpt", repo_type="model") + state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] + state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} + self.model.load_state_dict(state_dict, strict=True) + + self.model = self.model.to(self.device) + if self.IS_FLEXICUBES: + self.model.init_flexicubes_geometry(self.device, fovy=30.0) + self.model = self.model.eval() + + # Make output directories + self.image_path = os.path.join(self.output_path, self.config_name, 'images') + self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes') + self.video_path = os.path.join(self.output_path, self.config_name, 'videos') + os.makedirs(self.image_path, exist_ok=True) + os.makedirs(self.mesh_path, exist_ok=True) + os.makedirs(self.video_path, exist_ok=True) + + def process_image(self, image_file, idx, total_num_files, rembg_session): + if rembg_session == None: + rembg_session = None if self.no_rembg else rembg.new_session() + + name = os.path.basename(image_file).split('.')[0] + print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...') + + # Remove background optionally + input_image = Image.open(image_file) + if not self.no_rembg: + input_image = remove_background(input_image, rembg_session) + input_image = resize_foreground(input_image, 0.85) + + # Sampling + output_image = self.pipeline( + input_image, + num_inference_steps=self.diffusion_steps, + ).images[0] + + img_path = os.path.join(self.image_path, f'{name}.png') + output_image.save(img_path) + print(f"Image of viewpoints saved to {os.path.join(self.image_path, f'{name}.png')}") + + images = np.asarray(output_image, dtype=np.float32) / 255.0 + images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) + images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) + + input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device) + chunk_size = 20 if self.IS_FLEXICUBES else 1 + print(f'Creating {name} ...') + images = images.unsqueeze(0).to(self.device) + images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1) + + if self.view == 4: + indices = torch.tensor([0, 2, 4, 5]).long().to(self.device) + images = images[:, indices] + input_cameras = input_cameras[:, indices] + + with torch.no_grad(): + # Get triplane + planes = self.model.forward_planes(images, input_cameras) + + # Get mesh + mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj') + mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl') + texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png') + + mesh_out = self.model.extract_mesh( + planes, + use_texture_map=self.export_texmap, + **self.infer_config, + ) + + + if self.export_texmap: + vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out + save_obj_with_mtl( + vertices.data.cpu().numpy(), + uvs.data.cpu().numpy(), + faces.data.cpu().numpy(), + mesh_tex_idx.data.cpu().numpy(), + tex_map.permute(1, 2, 0).data.cpu().numpy(), + mesh_path_idx, + ) + else: + vertices, faces, vertex_colors = mesh_out + save_obj(vertices, faces, vertex_colors, mesh_path_idx) + + print(f"Mesh saved to {mesh_path_idx}") + + # Get video + if self.save_video: + video_path_idx = os.path.join(self.video_path, f'{name}.mp4') + render_size = self.infer_config.render_resolution + render_cameras = self._get_render_cameras() + + frames = self._render_frames( + planes, + render_cameras=render_cameras, + render_size=render_size, + chunk_size=chunk_size, + ) + + save_video( + frames, + video_path_idx, + fps=30, + ) + print(f"Video saved to {video_path_idx}") + + if self.gltf: + output_path = os.path.join(self.gltf_path, f'{name}.gltf') + command = f' obj2gltf -i {mesh_path_idx} -o {output_path}' + process = subprocess.run(command, shell=True, capture_output=True, text=True) + # Check if the process was successful + if process.returncode == 0: + print(f'Successfully converted {mesh_path_idx} to {output_path}') + else: + print(f'Error converting {mesh_path_idx}: {process.stderr}') + + if self.remove: + print(f'Removing files (mtl, obj, diffusion, texmap) for {name}') + os.remove(mesh_path_idx) + os.remove(mtl_path_idx) + os.remove(img_path) + os.remove(texmap_path_idx) + pass + + def run_pipeline_sequence(self, stop_event): + # Process input files + if os.path.isdir(self.input_path): + input_files = [ + os.path.join(self.input_path, file) + for file in os.listdir(self.input_path) + if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp') + ] + else: + input_files = [self.input_path] + print(f'\nTotal number of input images: {len(input_files)}') + + rembg_session = None if self.no_rembg else rembg.new_session() + + for idx, image_file in enumerate(input_files): + if stop_event.is_set(): # Check if stop event is set + print("\nStopping pipeline sequence.") + break + self.process_image(image_file, idx, len(input_files), rembg_session) + + def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1): + # Render frames from triplanes + frames = [] + for i in tqdm(range(0, render_cameras.shape[1], chunk_size)): + if self.IS_FLEXICUBES: + frame = self.model.forward_geometry( + planes, + render_cameras[:, i:i + chunk_size], + render_size=render_size, + )['img'] + else: + frame = self.model.forward_synthesizer( + planes, + render_cameras[:, i:i + chunk_size], + render_size=render_size, + )['images_rgb'] + frames.append(frame) + + frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1 + return frames diff --git a/application.py b/application.py new file mode 100644 index 0000000..52686f1 --- /dev/null +++ b/application.py @@ -0,0 +1,127 @@ +import argparse +from MeshRenderingPipeline import MeshRenderingPipeline +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler +import threading +import time +import random + +class ImageHandler(FileSystemEventHandler): + def __init__(self, mesh_pipeline): + self.mesh_pipeline = mesh_pipeline + + def on_created(self, event): + if event.is_directory: + return + if event.src_path.endswith(('.png', '.jpg', '.webp')): + print(f"New image detected: {event.src_path}") + self.mesh_pipeline.process_image(event.src_path, 0, 1, None) + +class App: + def __init__(self, args): + self.args = args + self.running = False + self.stop_event = threading.Event() + self.thread = None + self.input_thread = None + self.observer = None + self.mesh_pipeline = MeshRenderingPipeline( + args.config, + args.input_path, + args.output_path, + args.diffusion_steps, + args.seed, + args.scale, + args.distance, + args.view, + args.no_rembg, + args.export_texmap, + args.save_video, + args.gltf, + args.remove, + args.gltf_path + ) + + def start_generation(self): + self.running = True + self.stop_event.clear() + self.thread = threading.Thread(target=self._generate_objects) + self.thread.start() + + def stop_generation(self): + self.stop_event.set() + self.running = False + if self.thread: + self.thread.join() + self.stop_observer() + + def monitor_new_images(self): + event_handler = ImageHandler(self.mesh_pipeline) + self.observer = Observer() + self.observer.schedule(event_handler, self.args.input_path, recursive=False) + self.observer.start() + print("\nWaiting for new image..") + while not self.stop_event.is_set(): + time.sleep(1) + self.stop_observer() # Ensure the observer is stopped on interruption + + + def stop_observer(self): + if self.observer: + self.observer.stop() + self.observer.join() + self.observer = None + + def _generate_objects(self): + while not self.stop_event.is_set(): + self.run_pipeline() + time.sleep(1) + + + def run_pipeline(self): + self.mesh_pipeline.run_pipeline_sequence(self.stop_event) + self.monitor_new_images() + + def run(self): + self.input_thread = threading.Thread(target=self._handle_input) + self.input_thread.start() + self.input_thread.join() + + def _handle_input(self): + while True: + command = input("\nEnter a command, : ") + if command.lower() == 'exit': + print("Exiting the program.") + self.stop_generation() + break + elif command.lower() == 'start': + if not self.running: + print("\nStarting continuous generation.") + self.start_generation() + else: + print("\nGeneration already running.") + elif command.lower() == 'stop': + print("\nStopping continuous generation.") + self.stop_generation() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('input_path', type=str, help='Path to input image or directory.') + parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.') + parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising Sampling steps.') + parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.') + parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.') + parser.add_argument('--distance', type=float, default=4.5, help='Render distance.') + parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.') + parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.') + parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.') + parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.') + parser.add_argument('--gltf', action='store_true', help='Export a gtlf file.') + parser.add_argument('--remove', action='store_true', help='Removes obj, mtl, texmap, nv files.') + parser.add_argument('--gltf_path', type=str, default='C:/Users/caile/Desktop/InstantMesh/ex', help='Output directory.') + args = parser.parse_args() + + app = App(args) + app.run() \ No newline at end of file diff --git a/configs/instant-mesh-large.yaml b/configs/instant-mesh-large.yaml index e296bc8..066240d 100644 --- a/configs/instant-mesh-large.yaml +++ b/configs/instant-mesh-large.yaml @@ -18,5 +18,5 @@ model_config: infer_config: unet_path: ckpts/diffusion_pytorch_model.bin model_path: ckpts/instant_mesh_large.ckpt - texture_resolution: 1024 + texture_resolution: 512 render_resolution: 512 \ No newline at end of file diff --git a/examples/tayq.jpg b/examples/tayq.jpg new file mode 100644 index 0000000..5909390 Binary files /dev/null and b/examples/tayq.jpg differ diff --git a/src/utils/mesh_util.py b/src/utils/mesh_util.py index 0ec4663..1ce3d15 100644 --- a/src/utils/mesh_util.py +++ b/src/utils/mesh_util.py @@ -39,9 +39,28 @@ def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): ) mesh.export(fpath, 'glb') +def save_gltf(pointnp_px3, facenp_fx3, colornp_px3, fpath): + # Transform the points and faces to match the desired coordinate system + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + # Create a Trimesh object + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + + # Export the mesh to a GLTF file + # mesh.export(fpath, 'gltf') + trimesh_scene = trimesh.Scene(geometry=mesh) + with open(fpath, 'wb') as f: + f.write(trimesh.exchange.gltf.export_glb(trimesh_scene, include_normals=True)) def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): import os + + fol, na = os.path.split(fname) na, _ = os.path.splitext(na)