Browse Source

init commit

main
cailean 5 months ago
parent
commit
228de6a636
  1. 259
      MeshRenderingPipeline.py
  2. 127
      application.py
  3. 2
      configs/instant-mesh-large.yaml
  4. BIN
      examples/tayq.jpg
  5. 19
      src/utils/mesh_util.py

259
MeshRenderingPipeline.py

@ -0,0 +1,259 @@
import os
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import subprocess
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf
from src.utils.infer_util import remove_background, resize_foreground, save_video
class MeshRenderingPipeline:
def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75,
seed=42, scale=1.0, distance=4.5, view=6, no_rembg=False, export_texmap=False,
save_video=False, gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/'):
self.config_file = config_file
self.input_path = input_path
self.output_path = output_path
self.diffusion_steps = diffusion_steps
self.seed = seed
self.scale = scale
self.distance = distance
self.view = view
self.no_rembg = no_rembg
self.export_texmap = export_texmap
self.save_video = save_video
self.gltf = gltf
self.remove = remove
self.gltf_path = gltf_path
# Parse configuration and setup
self._parse_config()
self._setup()
def _parse_config(self):
# Parse configuration file
self.config = OmegaConf.load(self.config_file)
self.config_name = os.path.basename(self.config_file).replace('.yaml', '')
self.model_config = self.config.model_config
self.infer_config = self.config.infer_config
self.IS_FLEXICUBES = True if self.config_name.startswith('instant-mesh') else False
def _setup(self):
# Seed for reproducibility
seed_everything(self.seed)
# Device setup
self.device = torch.device('cuda')
# Load diffusion model
print('Loading diffusion model ...')
self.pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="zero123plus",
torch_dtype=torch.float16,
)
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
self.pipeline.scheduler.config, timestep_spacing='trailing'
)
# Load custom white-background UNet
print('Loading custom white-background unet ...')
if os.path.exists(self.infer_config.unet_path):
unet_ckpt_path = self.infer_config.unet_path
else:
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin",
repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
self.pipeline.unet.load_state_dict(state_dict, strict=True)
self.pipeline = self.pipeline.to(self.device)
# Load reconstruction model
print('Loading reconstruction model ...')
self.model = instantiate_from_config(self.model_config)
if os.path.exists(self.infer_config.model_path):
model_ckpt_path = self.infer_config.model_path
else:
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh",
filename=f"{self.config_name.replace('-', '_')}.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
self.model.load_state_dict(state_dict, strict=True)
self.model = self.model.to(self.device)
if self.IS_FLEXICUBES:
self.model.init_flexicubes_geometry(self.device, fovy=30.0)
self.model = self.model.eval()
# Make output directories
self.image_path = os.path.join(self.output_path, self.config_name, 'images')
self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes')
self.video_path = os.path.join(self.output_path, self.config_name, 'videos')
os.makedirs(self.image_path, exist_ok=True)
os.makedirs(self.mesh_path, exist_ok=True)
os.makedirs(self.video_path, exist_ok=True)
def process_image(self, image_file, idx, total_num_files, rembg_session):
if rembg_session == None:
rembg_session = None if self.no_rembg else rembg.new_session()
name = os.path.basename(image_file).split('.')[0]
print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...')
# Remove background optionally
input_image = Image.open(image_file)
if not self.no_rembg:
input_image = remove_background(input_image, rembg_session)
input_image = resize_foreground(input_image, 0.85)
# Sampling
output_image = self.pipeline(
input_image,
num_inference_steps=self.diffusion_steps,
).images[0]
img_path = os.path.join(self.image_path, f'{name}.png')
output_image.save(img_path)
print(f"Image of viewpoints saved to {os.path.join(self.image_path, f'{name}.png')}")
images = np.asarray(output_image, dtype=np.float32) / 255.0
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device)
chunk_size = 20 if self.IS_FLEXICUBES else 1
print(f'Creating {name} ...')
images = images.unsqueeze(0).to(self.device)
images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)
if self.view == 4:
indices = torch.tensor([0, 2, 4, 5]).long().to(self.device)
images = images[:, indices]
input_cameras = input_cameras[:, indices]
with torch.no_grad():
# Get triplane
planes = self.model.forward_planes(images, input_cameras)
# Get mesh
mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj')
mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl')
texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png')
mesh_out = self.model.extract_mesh(
planes,
use_texture_map=self.export_texmap,
**self.infer_config,
)
if self.export_texmap:
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
save_obj_with_mtl(
vertices.data.cpu().numpy(),
uvs.data.cpu().numpy(),
faces.data.cpu().numpy(),
mesh_tex_idx.data.cpu().numpy(),
tex_map.permute(1, 2, 0).data.cpu().numpy(),
mesh_path_idx,
)
else:
vertices, faces, vertex_colors = mesh_out
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
print(f"Mesh saved to {mesh_path_idx}")
# Get video
if self.save_video:
video_path_idx = os.path.join(self.video_path, f'{name}.mp4')
render_size = self.infer_config.render_resolution
render_cameras = self._get_render_cameras()
frames = self._render_frames(
planes,
render_cameras=render_cameras,
render_size=render_size,
chunk_size=chunk_size,
)
save_video(
frames,
video_path_idx,
fps=30,
)
print(f"Video saved to {video_path_idx}")
if self.gltf:
output_path = os.path.join(self.gltf_path, f'{name}.gltf')
command = f' obj2gltf -i {mesh_path_idx} -o {output_path}'
process = subprocess.run(command, shell=True, capture_output=True, text=True)
# Check if the process was successful
if process.returncode == 0:
print(f'Successfully converted {mesh_path_idx} to {output_path}')
else:
print(f'Error converting {mesh_path_idx}: {process.stderr}')
if self.remove:
print(f'Removing files (mtl, obj, diffusion, texmap) for {name}')
os.remove(mesh_path_idx)
os.remove(mtl_path_idx)
os.remove(img_path)
os.remove(texmap_path_idx)
pass
def run_pipeline_sequence(self, stop_event):
# Process input files
if os.path.isdir(self.input_path):
input_files = [
os.path.join(self.input_path, file)
for file in os.listdir(self.input_path)
if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
]
else:
input_files = [self.input_path]
print(f'\nTotal number of input images: {len(input_files)}')
rembg_session = None if self.no_rembg else rembg.new_session()
for idx, image_file in enumerate(input_files):
if stop_event.is_set(): # Check if stop event is set
print("\nStopping pipeline sequence.")
break
self.process_image(image_file, idx, len(input_files), rembg_session)
def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1):
# Render frames from triplanes
frames = []
for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
if self.IS_FLEXICUBES:
frame = self.model.forward_geometry(
planes,
render_cameras[:, i:i + chunk_size],
render_size=render_size,
)['img']
else:
frame = self.model.forward_synthesizer(
planes,
render_cameras[:, i:i + chunk_size],
render_size=render_size,
)['images_rgb']
frames.append(frame)
frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
return frames

127
application.py

@ -0,0 +1,127 @@
import argparse
from MeshRenderingPipeline import MeshRenderingPipeline
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import threading
import time
import random
class ImageHandler(FileSystemEventHandler):
def __init__(self, mesh_pipeline):
self.mesh_pipeline = mesh_pipeline
def on_created(self, event):
if event.is_directory:
return
if event.src_path.endswith(('.png', '.jpg', '.webp')):
print(f"New image detected: {event.src_path}")
self.mesh_pipeline.process_image(event.src_path, 0, 1, None)
class App:
def __init__(self, args):
self.args = args
self.running = False
self.stop_event = threading.Event()
self.thread = None
self.input_thread = None
self.observer = None
self.mesh_pipeline = MeshRenderingPipeline(
args.config,
args.input_path,
args.output_path,
args.diffusion_steps,
args.seed,
args.scale,
args.distance,
args.view,
args.no_rembg,
args.export_texmap,
args.save_video,
args.gltf,
args.remove,
args.gltf_path
)
def start_generation(self):
self.running = True
self.stop_event.clear()
self.thread = threading.Thread(target=self._generate_objects)
self.thread.start()
def stop_generation(self):
self.stop_event.set()
self.running = False
if self.thread:
self.thread.join()
self.stop_observer()
def monitor_new_images(self):
event_handler = ImageHandler(self.mesh_pipeline)
self.observer = Observer()
self.observer.schedule(event_handler, self.args.input_path, recursive=False)
self.observer.start()
print("\nWaiting for new image..")
while not self.stop_event.is_set():
time.sleep(1)
self.stop_observer() # Ensure the observer is stopped on interruption
def stop_observer(self):
if self.observer:
self.observer.stop()
self.observer.join()
self.observer = None
def _generate_objects(self):
while not self.stop_event.is_set():
self.run_pipeline()
time.sleep(1)
def run_pipeline(self):
self.mesh_pipeline.run_pipeline_sequence(self.stop_event)
self.monitor_new_images()
def run(self):
self.input_thread = threading.Thread(target=self._handle_input)
self.input_thread.start()
self.input_thread.join()
def _handle_input(self):
while True:
command = input("\nEnter a command, <start> <stop> <exit>: ")
if command.lower() == 'exit':
print("Exiting the program.")
self.stop_generation()
break
elif command.lower() == 'start':
if not self.running:
print("\nStarting continuous generation.")
self.start_generation()
else:
print("\nGeneration already running.")
elif command.lower() == 'stop':
print("\nStopping continuous generation.")
self.stop_generation()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('input_path', type=str, help='Path to input image or directory.')
parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising Sampling steps.')
parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
parser.add_argument('--gltf', action='store_true', help='Export a gtlf file.')
parser.add_argument('--remove', action='store_true', help='Removes obj, mtl, texmap, nv files.')
parser.add_argument('--gltf_path', type=str, default='C:/Users/caile/Desktop/InstantMesh/ex', help='Output directory.')
args = parser.parse_args()
app = App(args)
app.run()

2
configs/instant-mesh-large.yaml

@ -18,5 +18,5 @@ model_config:
infer_config: infer_config:
unet_path: ckpts/diffusion_pytorch_model.bin unet_path: ckpts/diffusion_pytorch_model.bin
model_path: ckpts/instant_mesh_large.ckpt model_path: ckpts/instant_mesh_large.ckpt
texture_resolution: 1024 texture_resolution: 512
render_resolution: 512 render_resolution: 512

BIN
examples/tayq.jpg

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

19
src/utils/mesh_util.py

@ -39,9 +39,28 @@ def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
) )
mesh.export(fpath, 'glb') mesh.export(fpath, 'glb')
def save_gltf(pointnp_px3, facenp_fx3, colornp_px3, fpath):
# Transform the points and faces to match the desired coordinate system
pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
facenp_fx3 = facenp_fx3[:, [2, 1, 0]]
# Create a Trimesh object
mesh = trimesh.Trimesh(
vertices=pointnp_px3,
faces=facenp_fx3,
vertex_colors=colornp_px3,
)
# Export the mesh to a GLTF file
# mesh.export(fpath, 'gltf')
trimesh_scene = trimesh.Scene(geometry=mesh)
with open(fpath, 'wb') as f:
f.write(trimesh.exchange.gltf.export_glb(trimesh_scene, include_normals=True))
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
import os import os
fol, na = os.path.split(fname) fol, na = os.path.split(fname)
na, _ = os.path.splitext(na) na, _ = os.path.splitext(na)

Loading…
Cancel
Save