|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import rembg
|
|
|
|
import time
|
|
|
|
from PIL import Image
|
|
|
|
from torchvision.transforms import v2
|
|
|
|
from tqdm import tqdm
|
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from pytorch_lightning import seed_everything
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from einops import rearrange
|
|
|
|
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
|
|
|
import subprocess
|
|
|
|
from ObjSender import ObjSender
|
|
|
|
|
|
|
|
from src.utils.train_util import instantiate_from_config
|
|
|
|
from src.utils.camera_util import (
|
|
|
|
FOV_to_intrinsics,
|
|
|
|
get_zero123plus_input_cameras,
|
|
|
|
get_circular_camera_poses,
|
|
|
|
)
|
|
|
|
from src.utils.mesh_util import save_obj, save_obj_with_mtl, save_gltf
|
|
|
|
from src.utils.infer_util import remove_background, resize_foreground, save_video
|
|
|
|
|
|
|
|
|
|
|
|
class MeshRenderingPipeline:
    """Image-to-mesh pipeline built on Zero123++ and an InstantMesh reconstructor.

    For every input image the pipeline optionally removes the background,
    samples a 3x2 grid of novel views with the ``sudo-ai/zero123plus-v1.2``
    diffusion pipeline, reconstructs triplanes from those views, and writes an
    ``.obj`` mesh. Optionally it also renders a circular turntable video,
    converts the mesh to glTF with the external ``obj2gltf`` command (sending
    the result through ``ObjSender``), and deletes intermediate files.
    """

    # File extensions (matched case-insensitively) picked up from an input directory.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, config_file, input_path, output_path='outputs/', diffusion_steps=75,
                 seed=42, scale=1.0, distance=4.5, view=4, no_rembg=False, export_texmap=False,
                 save_video=False, gltf=False, remove=False, gltf_path='C:/Users/caile/Desktop/', local=False):
        """Store the options, load config and models, and create output directories.

        Args:
            config_file: Path to the OmegaConf YAML config. Its basename (without
                ``.yaml``) names the output sub-directory and the checkpoint file.
            input_path: A single image file, or a directory of images.
            output_path: Root directory for generated images, meshes and videos.
            diffusion_steps: Number of inference steps for multi-view sampling.
            seed: Seed for ``seed_everything`` and for the diffusion generator.
            scale: Multiplier applied to the input-camera radius (``4.0 * scale``).
            distance: Camera radius used when rendering the video.
            view: 4 selects views 0, 2, 4 and 5 of the generated grid; any other
                value feeds all 6 views to the reconstructor.
            no_rembg: Skip background removal (and foreground resizing).
            export_texmap: Export UVs + texture map (.mtl/.png) instead of vertex colours.
            save_video: Render and save an ``.mp4`` orbiting the reconstruction.
            gltf: Convert the ``.obj`` to glTF with the ``obj2gltf`` command.
            remove: Delete intermediate files (obj/mtl/multi-view image/texture) afterwards.
            gltf_path: Directory the converted ``.gltf`` is written to.
            local: When False, successfully converted glTF files are sent via ``self.sender``.
        """
        self.config_file = config_file
        self.input_path = input_path
        self.output_path = output_path
        self.diffusion_steps = diffusion_steps
        self.seed = seed
        self.scale = scale
        self.distance = distance
        self.view = view
        self.no_rembg = no_rembg
        self.export_texmap = export_texmap
        self.save_video = save_video
        self.gltf = gltf
        self.remove = remove
        self.gltf_path = gltf_path
        self.sender = ObjSender("localhost", "3000")
        self.local = local

        # Parse configuration and setup
        self._parse_config()
        self._setup()

    def _parse_config(self):
        """Load the YAML config and expose its model/inference sections."""
        self.config = OmegaConf.load(self.config_file)
        self.config_name = os.path.basename(self.config_file).replace('.yaml', '')
        self.model_config = self.config.model_config
        self.infer_config = self.config.infer_config
        # Configs whose name starts with 'instant-mesh' use FlexiCubes geometry.
        self.IS_FLEXICUBES = self.config_name.startswith('instant-mesh')

    @staticmethod
    def _resolve_checkpoint(local_path, filename):
        """Return *local_path* if it exists, otherwise download *filename* from the
        ``TencentARC/InstantMesh`` Hugging Face model repo and return its local path."""
        if os.path.exists(local_path):
            return local_path
        return hf_hub_download(repo_id="TencentARC/InstantMesh", filename=filename, repo_type="model")

    def _setup(self):
        """Seed RNGs, load the diffusion and reconstruction models onto the GPU,
        and create the output directories."""
        # Seed for reproducibility
        seed_everything(self.seed)

        self.device = torch.device('cuda')

        # Load diffusion model
        print('Loading diffusion model ...')
        self.pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="zero123plus",
            torch_dtype=torch.float16,
        )
        self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipeline.scheduler.config, timestep_spacing='trailing'
        )

        # Load custom white-background UNet
        print('Loading custom white-background unet ...')
        unet_ckpt_path = self._resolve_checkpoint(self.infer_config.unet_path, "diffusion_pytorch_model.bin")
        # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(unet_ckpt_path, map_location='cpu')
        self.pipeline.unet.load_state_dict(state_dict, strict=True)
        self.pipeline = self.pipeline.to(self.device)

        # Load reconstruction model
        print('Loading reconstruction model ...')
        self.model = instantiate_from_config(self.model_config)
        model_ckpt_path = self._resolve_checkpoint(
            self.infer_config.model_path, f"{self.config_name.replace('-', '_')}.ckpt"
        )
        state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
        # Keep only the generator's weights and strip their 'lrm_generator.' prefix.
        prefix = 'lrm_generator.'
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        self.model.load_state_dict(state_dict, strict=True)

        self.model = self.model.to(self.device)
        if self.IS_FLEXICUBES:
            self.model.init_flexicubes_geometry(self.device, fovy=30.0)
        self.model = self.model.eval()

        # Make output directories
        self.image_path = os.path.join(self.output_path, self.config_name, 'images')
        self.mesh_path = os.path.join(self.output_path, self.config_name, 'meshes')
        self.video_path = os.path.join(self.output_path, self.config_name, 'videos')
        for directory in (self.image_path, self.mesh_path, self.video_path):
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _output_name(image_file):
        """Return the file name of *image_file* without its final extension.

        Unlike splitting on the first '.', this keeps names such as
        ``my.cat.png`` -> ``my.cat`` intact, so outputs do not collide.
        """
        return os.path.splitext(os.path.basename(image_file))[0]

    def _generate_multiview(self, input_image):
        """Sample the multi-view grid image for *input_image*.

        The diffusion pipeline is moved to the GPU for sampling and offloaded
        back to the CPU afterwards, even if sampling raises, to free VRAM for
        reconstruction.
        """
        # A dedicated, explicitly seeded generator makes `seed` control sampling;
        # an unseeded torch.Generator would ignore it.
        generator = torch.Generator(device=self.device).manual_seed(self.seed)
        self.pipeline = self.pipeline.to(self.device)
        try:
            output_image = self.pipeline(
                input_image,
                num_inference_steps=self.diffusion_steps,
                generator=generator,
            ).images[0]
        finally:
            self.pipeline.to("cpu")
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        return output_image

    def _export_mesh(self, planes, mesh_file):
        """Extract a mesh from *planes* and write it to *mesh_file* (.obj).

        With ``export_texmap`` the mesh is written with UVs and a texture map via
        ``save_obj_with_mtl``; otherwise with per-vertex colours via ``save_obj``.
        Call inside ``torch.no_grad()``.
        """
        mesh_out = self.model.extract_mesh(
            planes,
            use_texture_map=self.export_texmap,
            **self.infer_config,
        )
        if self.export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_file,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_file)

    def _get_render_cameras(self, num_views=120, elevation=20.0, fov=30.0):
        """Build cameras on a circle of radius ``self.distance`` for the video.

        Args:
            num_views: Number of camera poses (= video frames) on the circle.
            elevation: Elevation passed to ``get_circular_camera_poses``
                (presumably degrees; TODO confirm against camera_util).
            fov: Field of view passed to ``FOV_to_intrinsics`` in the
                non-FlexiCubes path (presumably degrees; TODO confirm).

        Returns:
            A tensor on ``self.device`` with a leading batch dimension of 1.
            For FlexiCubes models it holds inverted camera-to-world matrices;
            otherwise it holds flattened extrinsics concatenated with flattened
            intrinsics. Indexed as ``cameras[:, i:j]`` by ``_render_frames``.
        """
        c2ws = get_circular_camera_poses(M=num_views, radius=self.distance, elevation=elevation)
        if self.IS_FLEXICUBES:
            cameras = torch.linalg.inv(c2ws).unsqueeze(0)
        else:
            extrinsics = c2ws.flatten(-2)
            intrinsics = FOV_to_intrinsics(fov).unsqueeze(0).repeat(num_views, 1, 1).float().flatten(-2)
            cameras = torch.cat([extrinsics, intrinsics], dim=-1).unsqueeze(0)
        return cameras.to(self.device)

    def _convert_to_gltf(self, mesh_file, name):
        """Convert *mesh_file* to ``<gltf_path>/<name>.gltf`` with the ``obj2gltf`` command.

        On success the glTF is sent via ``self.sender`` unless ``self.local`` is set.
        Returns True if the conversion succeeded, False otherwise (errors are printed).
        """
        import shutil  # only needed for this optional step

        output_path = os.path.join(self.gltf_path, f'{name}.gltf')
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact. shutil.which honours PATHEXT on Windows, so
        # .cmd/.bat launchers are resolved as well.
        executable = shutil.which('obj2gltf') or 'obj2gltf'
        try:
            process = subprocess.run(
                [executable, '-i', mesh_file, '-o', output_path],
                capture_output=True,
                text=True,
            )
        except OSError as err:  # e.g. obj2gltf is not installed
            print(f'Error converting {mesh_file}: {err}')
            return False

        # Check if the process was successful
        if process.returncode != 0:
            print(f'Error converting {mesh_file}: {process.stderr}')
            return False

        print(f'Successfully converted {mesh_file} to {output_path}')
        if self.local is False:
            self.sender.send_file(output_path)
        return True

    @staticmethod
    def _remove_intermediate_files(paths):
        """Delete every file in *paths*, silently skipping files that do not exist.

        The .mtl and texture files only exist when ``export_texmap`` was used,
        so their absence is expected rather than an error.
        """
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def process_image(self, image_file, idx, total_num_files, rembg_session):
        """Run the full image -> multi-view -> mesh pipeline for one image.

        Args:
            image_file: Path of the input image.
            idx: Zero-based index of this image (used in progress messages only).
            total_num_files: Total number of images (used in progress messages only).
            rembg_session: ``rembg`` session to reuse, or None to create one here.
                No session is created when ``no_rembg`` is set.
        """
        if rembg_session is None and not self.no_rembg:
            rembg_session = rembg.new_session()

        name = self._output_name(image_file)
        print(f'[{idx + 1}/{total_num_files}] Creating novel viewpoints of {name} ...')

        # Remove background optionally
        input_image = Image.open(image_file)
        if not self.no_rembg:
            input_image = remove_background(input_image, rembg_session)
            input_image = resize_foreground(input_image, 0.85)

        # Sampling
        output_image = self._generate_multiview(input_image)
        img_path = os.path.join(self.image_path, f'{name}.png')
        output_image.save(img_path)
        print(f"Image of viewpoints saved to {img_path}")

        # Split the 3x2 grid into six views.
        images = np.asarray(output_image, dtype=np.float32) / 255.0
        images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
        images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320)

        input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * self.scale).to(self.device)
        # Number of cameras rendered per forward call when making the video.
        chunk_size = 20 if self.IS_FLEXICUBES else 1

        print(f'Creating {name} ...')
        start_time = time.time()
        images = images.unsqueeze(0).to(self.device)
        images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)

        if self.view == 4:
            indices = torch.tensor([0, 2, 4, 5]).long().to(self.device)
            images = images[:, indices]
            input_cameras = input_cameras[:, indices]

        mesh_path_idx = os.path.join(self.mesh_path, f'{name}.obj')
        mtl_path_idx = os.path.join(self.mesh_path, f'{name}.mtl')
        texmap_path_idx = os.path.join(self.mesh_path, f'{name}.png')

        with torch.no_grad():
            # Get triplane
            planes = self.model.forward_planes(images, input_cameras)

            # Get mesh
            self._export_mesh(planes, mesh_path_idx)
            print(f"Mesh saved to {mesh_path_idx}")

            # Get video
            if self.save_video:
                video_path_idx = os.path.join(self.video_path, f'{name}.mp4')
                frames = self._render_frames(
                    planes,
                    render_cameras=self._get_render_cameras(),
                    render_size=self.infer_config.render_resolution,
                    chunk_size=chunk_size,
                )
                save_video(frames, video_path_idx, fps=30)
                print(f"Video saved to {video_path_idx}")

        if self.gltf:
            self._convert_to_gltf(mesh_path_idx, name)

        if self.remove:
            print(f'Removing files (mtl, obj, diffusion, texmap) for {name}')
            self._remove_intermediate_files([mesh_path_idx, mtl_path_idx, img_path, texmap_path_idx])

        # Measures reconstruction + export (diffusion sampling happens before start_time).
        elapsed_time = time.time() - start_time
        print(f'Total Time: {elapsed_time}')

    def run_pipeline_sequence(self, stop_event=None):
        """Process every input image in turn.

        A directory input is scanned for files whose extension (case-insensitive)
        is in ``IMAGE_EXTENSIONS``. The files are processed in sorted order, so
        runs are reproducible.

        Args:
            stop_event: Object with an ``is_set()`` method (e.g. ``threading.Event``),
                checked before each image; processing stops once it is set.
                None disables the check.
        """
        # Process input files
        if os.path.isdir(self.input_path):
            input_files = sorted(
                os.path.join(self.input_path, file)
                for file in os.listdir(self.input_path)
                if file.lower().endswith(self.IMAGE_EXTENSIONS)
            )
        else:
            input_files = [self.input_path]
        print(f'\nTotal number of input images: {len(input_files)}')

        # One background-removal session is shared across all images.
        rembg_session = None if self.no_rembg else rembg.new_session()

        for idx, image_file in enumerate(input_files):
            if stop_event is not None and stop_event.is_set():
                print("\nStopping pipeline sequence.")
                break
            self.process_image(image_file, idx, len(input_files), rembg_session)

    def _render_frames(self, planes, render_cameras, render_size=512, chunk_size=1):
        """Render video frames from triplanes, *chunk_size* cameras at a time.

        Uses ``forward_geometry(...)['img']`` for FlexiCubes models and
        ``forward_synthesizer(...)['images_rgb']`` otherwise. Chunks are
        concatenated along dim 1 and batch element 0 is returned (batch size 1
        is assumed).
        """
        render = self.model.forward_geometry if self.IS_FLEXICUBES else self.model.forward_synthesizer
        output_key = 'img' if self.IS_FLEXICUBES else 'images_rgb'

        frames = []
        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
            frame = render(
                planes,
                render_cameras[:, i:i + chunk_size],
                render_size=render_size,
            )[output_key]
            frames.append(frame)

        return torch.cat(frames, dim=1)[0]  # we suppose batch size is always 1
|