cailean
5 months ago
5 changed files with 406 additions and 1 deletion
@@ -0,0 +1,259 @@
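# MeshRenderingPipeline.py: generate six novel views of an input image with the
# Zero123++ diffusion model, then reconstruct a textured 3D mesh with InstantMesh.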
import os
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import subprocess


from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf
from src.utils.infer_util import remove_background, resize_foreground, save_video


class MeshRenderingPipeline:
    def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75,
                 seed=42, scale=1.0, distance=4.5, view=6, no_rembg=False, export_texmap=False,
                 save_video=False, gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/'):
        self.config_file = config_file
        self.input_path = input_path
        self.output_path = output_path
        self.diffusion_steps = diffusion_steps
        self.seed = seed
        self.scale = scale
        self.distance = distance
        self.view = view
        self.no_rembg = no_rembg
        self.export_texmap = export_texmap
        self.save_video = save_video
        self.gltf = gltf
        self.remove = remove
        self.gltf_path = gltf_path

        # Parse configuration and setup
        self._parse_config()
        self._setup()

    def _parse_config(self):
        # Parse configuration file
        self.config = OmegaConf.load(self.config_file)
        self.config_name = os.path.basename(self.config_file).replace('.yaml', '')
        self.model_config = self.config.model_config
        self.infer_config = self.config.infer_config
        # startswith() already returns a bool; FlexiCubes is used by the instant-mesh-* configs
        self.IS_FLEXICUBES = self.config_name.startswith('instant-mesh')

    def _setup(self):
        # Seed for reproducibility
        seed_everything(self.seed)

        # Device setup
        self.device = torch.device('cuda')
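        # NOTE: a CUDA-capable GPU is assumed; both models are moved onto it below.
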
        # Load diffusion model
        print('Loading diffusion model ...')
        self.pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="zero123plus",
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipeline.scheduler.config, timestep_spacing='trailing'
        )

        # Load custom white-background UNet
        print('Loading custom white-background unet ...')
        if os.path.exists(self.infer_config.unet_path):
            unet_ckpt_path = self.infer_config.unet_path
        else:
            unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh",
                                             filename="diffusion_pytorch_model.bin",
                                             repo_type="model")
        state_dict = torch.load(unet_ckpt_path, map_location='cpu')
        self.pipeline.unet.load_state_dict(state_dict, strict=True)

        self.pipeline = self.pipeline.to(self.device)

        # Load reconstruction model
        print('Loading reconstruction model ...')
        self.model = instantiate_from_config(self.model_config)
        if os.path.exists(self.infer_config.model_path):
            model_ckpt_path = self.infer_config.model_path
        else:
            model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh",
                                              filename=f"{self.config_name.replace('-', '_')}.ckpt",
                                              repo_type="model")
        state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
        # keep only the LRM generator weights and strip the 'lrm_generator.' prefix
        state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
        self.model.load_state_dict(state_dict, strict=True)

        self.model = self.model.to(self.device)
        if self.IS_FLEXICUBES:
            self.model.init_flexicubes_geometry(self.device, fovy=30.0)
        self.model = self.model.eval()

        # Make output directories
        self.image_path = os.path.join(self.output_path, self.config_name, 'images')
        self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes')
        self.video_path = os.path.join(self.output_path, self.config_name, 'videos')
        os.makedirs(self.image_path, exist_ok=True)
        os.makedirs(self.mesh_path, exist_ok=True)
        os.makedirs(self.video_path, exist_ok=True)

    def process_image(self, image_file, idx, total_num_files, rembg_session):
        if rembg_session is None:
            rembg_session = None if self.no_rembg else rembg.new_session()

        name = os.path.basename(image_file).split('.')[0]
        print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...')

        # Remove background optionally
        input_image = Image.open(image_file)
        if not self.no_rembg:
            input_image = remove_background(input_image, rembg_session)
            input_image = resize_foreground(input_image, 0.85)

        # Sampling
        output_image = self.pipeline(
            input_image,
            num_inference_steps=self.diffusion_steps,
        ).images[0]

        img_path = os.path.join(self.image_path, f'{name}.png')
        output_image.save(img_path)
        print(f"Image of viewpoints saved to {img_path}")

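        # Zero123++ packs its six output views into one 960x640 image as a
        # 3x2 grid; split it into six 320x320 per-view tensors.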
        images = np.asarray(output_image, dtype=np.float32) / 255.0
        images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
        images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)     # (6, 3, 320, 320)

        input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device)
        # FlexiCubes rendering can handle more cameras per forward pass
        chunk_size = 20 if self.IS_FLEXICUBES else 1
        print(f'Creating {name} ...')
        images = images.unsqueeze(0).to(self.device)
        images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)

        if self.view == 4:
            # keep 4 of the 6 generated views (and their cameras) for the sparse-view model
            indices = torch.tensor([0, 2, 4, 5]).long().to(self.device)
            images = images[:, indices]
            input_cameras = input_cameras[:, indices]

        with torch.no_grad():
            # Get triplane
            planes = self.model.forward_planes(images, input_cameras)

            # Get mesh
            mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj')
            mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl')
            texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png')

            mesh_out = self.model.extract_mesh(
                planes,
                use_texture_map=self.export_texmap,
                **self.infer_config,
            )

            if self.export_texmap:
                vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
                save_obj_with_mtl(
                    vertices.data.cpu().numpy(),
                    uvs.data.cpu().numpy(),
                    faces.data.cpu().numpy(),
                    mesh_tex_idx.data.cpu().numpy(),
                    tex_map.permute(1, 2, 0).data.cpu().numpy(),
                    mesh_path_idx,
                )
            else:
                vertices, faces, vertex_colors = mesh_out
                save_obj(vertices, faces, vertex_colors, mesh_path_idx)

            print(f"Mesh saved to {mesh_path_idx}")

            # Get video
            if self.save_video:
                video_path_idx = os.path.join(self.video_path, f'{name}.mp4')
                render_size = self.infer_config.render_resolution
                render_cameras = self._get_render_cameras()

                # _render_frames runs model forward passes, so stay inside no_grad
                frames = self._render_frames(
                    planes,
                    render_cameras=render_cameras,
                    render_size=render_size,
                    chunk_size=chunk_size,
                )

                save_video(
                    frames,
                    video_path_idx,
                    fps=30,
                )
                print(f"Video saved to {video_path_idx}")

        # Convert the OBJ to glTF (requires the Node.js obj2gltf CLI on PATH)
        if self.gltf:
            output_path = os.path.join(self.gltf_path, f'{name}.gltf')
            command = f'obj2gltf -i {mesh_path_idx} -o {output_path}'
            process = subprocess.run(command, shell=True, capture_output=True, text=True)
            # Check if the process was successful
            if process.returncode == 0:
                print(f'Successfully converted {mesh_path_idx} to {output_path}')
            else:
                print(f'Error converting {mesh_path_idx}: {process.stderr}')

        if self.remove:
            print(f'Removing files (mtl, obj, diffusion, texmap) for {name}')
            # Guard each removal: the .mtl and texture map only exist when export_texmap is set
            for path in (mesh_path_idx, mtl_path_idx, img_path, texmap_path_idx):
                if os.path.exists(path):
                    os.remove(path)

    def run_pipeline_sequence(self, stop_event):
        # Process input files
        if os.path.isdir(self.input_path):
            input_files = [
                os.path.join(self.input_path, file)
                for file in os.listdir(self.input_path)
                if file.endswith(('.png', '.jpg', '.webp'))
            ]
        else:
            input_files = [self.input_path]
        print(f'\nTotal number of input images: {len(input_files)}')

        rembg_session = None if self.no_rembg else rembg.new_session()

        for idx, image_file in enumerate(input_files):
            if stop_event.is_set():  # Check if stop event is set
                print("\nStopping pipeline sequence.")
                break
            self.process_image(image_file, idx, len(input_files), rembg_session)

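    # NOTE: process_image() calls self._get_render_cameras(), but the method does
    # not appear in this diff. A plausible sketch, inferred from the imported
    # camera helpers and the otherwise-unused self.distance attribute:
    def _get_render_cameras(self, M=120, elevation=20.0):
        # Cameras on a circular orbit around the object; radius comes from --distance
        c2ws = get_circular_camera_poses(M=M, radius=self.distance, elevation=elevation)  # (M, 4, 4)
        if self.IS_FLEXICUBES:
            cameras = torch.linalg.inv(c2ws).reshape(M, -1)  # (M, 16) world-to-camera
        else:
            extrinsics = c2ws.flatten(-2)  # (M, 16)
            intrinsics = FOV_to_intrinsics(30.0).flatten().unsqueeze(0).repeat(M, 1)  # (M, 9)
            cameras = torch.cat([extrinsics, intrinsics], dim=-1)  # (M, 25)
        return cameras.unsqueeze(0).to(self.device)  # batch dimension of 1
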
    def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1):
        # Render frames from triplanes
        frames = []
        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
            if self.IS_FLEXICUBES:
                frame = self.model.forward_geometry(
                    planes,
                    render_cameras[:, i:i + chunk_size],
                    render_size=render_size,
                )['img']
            else:
                frame = self.model.forward_synthesizer(
                    planes,
                    render_cameras[:, i:i + chunk_size],
                    render_size=render_size,
                )['images_rgb']
            frames.append(frame)

        frames = torch.cat(frames, dim=1)[0]  # we suppose batch size is always 1
        return frames
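
A minimal usage sketch (paths are illustrative; run_pipeline_sequence takes a threading.Event so callers can interrupt it):

    import threading
    from MeshRenderingPipeline import MeshRenderingPipeline

    pipeline = MeshRenderingPipeline('configs/instant-mesh-large.yaml', 'examples/chair.png', save_video=True)
    pipeline.run_pipeline_sequence(threading.Event())  # event never set, so the input is processed once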
@@ -0,0 +1,127 @@
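# App entry point: processes the images already in input_path, then watches the
# directory (via watchdog) and meshes new images as they arrive.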
import argparse
from MeshRenderingPipeline import MeshRenderingPipeline
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import threading
import time

class ImageHandler(FileSystemEventHandler):
    def __init__(self, mesh_pipeline):
        self.mesh_pipeline = mesh_pipeline

    def on_created(self, event):
        if event.is_directory:
            return
        if event.src_path.endswith(('.png', '.jpg', '.webp')):
            print(f"New image detected: {event.src_path}")
            self.mesh_pipeline.process_image(event.src_path, 0, 1, None)


class App:
    def __init__(self, args):
        self.args = args
        self.running = False
        self.stop_event = threading.Event()
        self.thread = None
        self.input_thread = None
        self.observer = None
        self.mesh_pipeline = MeshRenderingPipeline(
            args.config,
            args.input_path,
            args.output_path,
            args.diffusion_steps,
            args.seed,
            args.scale,
            args.distance,
            args.view,
            args.no_rembg,
            args.export_texmap,
            args.save_video,
            args.gltf,
            args.remove,
            args.gltf_path,
        )

    def start_generation(self):
        self.running = True
        self.stop_event.clear()
        self.thread = threading.Thread(target=self._generate_objects)
        self.thread.start()

    def stop_generation(self):
        self.stop_event.set()
        self.running = False
        if self.thread:
            self.thread.join()
        self.stop_observer()

    def monitor_new_images(self):
        event_handler = ImageHandler(self.mesh_pipeline)
        self.observer = Observer()
        self.observer.schedule(event_handler, self.args.input_path, recursive=False)
        self.observer.start()
        print("\nWaiting for new image..")
        while not self.stop_event.is_set():
            time.sleep(1)
        self.stop_observer()  # Ensure the observer is stopped on interruption


    def stop_observer(self):
        if self.observer:
            self.observer.stop()
            self.observer.join()
            self.observer = None

    def _generate_objects(self):
        while not self.stop_event.is_set():
            self.run_pipeline()
            time.sleep(1)

    def run_pipeline(self):
        # Process everything already in input_path, then block watching for new images
        self.mesh_pipeline.run_pipeline_sequence(self.stop_event)
        self.monitor_new_images()

    def run(self):
        self.input_thread = threading.Thread(target=self._handle_input)
        self.input_thread.start()
        self.input_thread.join()

    def _handle_input(self):
        while True:
            command = input("\nEnter a command, <start> <stop> <exit>: ")
            if command.lower() == 'exit':
                print("Exiting the program.")
                self.stop_generation()
                break
            elif command.lower() == 'start':
                if not self.running:
                    print("\nStarting continuous generation.")
                    self.start_generation()
                else:
                    print("\nGeneration already running.")
            elif command.lower() == 'stop':
                print("\nStopping continuous generation.")
                self.stop_generation()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('input_path', type=str, help='Path to input image or directory.')
    parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
    parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising sampling steps.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
    parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
    parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
    parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
    parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
    parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
    parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
    parser.add_argument('--gltf', action='store_true', help='Export a glTF file.')
    parser.add_argument('--remove', action='store_true', help='Remove the obj, mtl, texmap, and novel-view image files.')
    parser.add_argument('--gltf_path', type=str, default='C:/Users/caile/Desktop/InstantMesh/ex', help='glTF output directory.')
    args = parser.parse_args()

    app = App(args)
    app.run()
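
An illustrative invocation (the script name app.py and all paths are hypothetical):

    python app.py configs/instant-mesh-large.yaml examples/ --save_video --gltf --gltf_path exports/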