|
@ -27,6 +27,7 @@ from src.utils.infer_util import remove_background, resize_foreground, images_to |
|
|
|
|
|
|
|
|
import tempfile |
|
|
import tempfile |
|
|
from functools import partial |
|
|
from functools import partial |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): |
|
|
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): |
|
@ -65,7 +66,7 @@ def images_to_video(images, output_path, fps=30): |
|
|
|
|
|
|
|
|
seed_everything(0) |
|
|
seed_everything(0) |
|
|
|
|
|
|
|
|
config_path = 'configs/instant-mesh-large-eval.yaml' |
|
|
config_path = 'configs/instant-mesh-large.yaml' |
|
|
config = OmegaConf.load(config_path) |
|
|
config = OmegaConf.load(config_path) |
|
|
config_name = os.path.basename(config_path).replace('.yaml', '') |
|
|
config_name = os.path.basename(config_path).replace('.yaml', '') |
|
|
model_config = config.model_config |
|
|
model_config = config.model_config |
|
@ -87,15 +88,17 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
# load custom white-background UNet |
|
|
# load custom white-background UNet |
|
|
state_dict = torch.load(infer_config.unet_path, map_location='cpu') |
|
|
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") |
|
|
|
|
|
state_dict = torch.load(unet_ckpt_path, map_location='cpu') |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
|
pipeline = pipeline.to(device) |
|
|
pipeline = pipeline.to(device) |
|
|
|
|
|
|
|
|
# load reconstruction model |
|
|
# load reconstruction model |
|
|
print('Loading reconstruction model ...') |
|
|
print('Loading reconstruction model ...') |
|
|
|
|
|
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model") |
|
|
model = instantiate_from_config(model_config) |
|
|
model = instantiate_from_config(model_config) |
|
|
state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict'] |
|
|
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} |
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
@ -115,10 +118,9 @@ def check_input_image(input_image): |
|
|
def preprocess(input_image, do_remove_background): |
|
|
def preprocess(input_image, do_remove_background): |
|
|
|
|
|
|
|
|
rembg_session = rembg.new_session() if do_remove_background else None |
|
|
rembg_session = rembg.new_session() if do_remove_background else None |
|
|
|
|
|
|
|
|
#input_image = Image.open(image_file) |
|
|
|
|
|
if do_remove_background: |
|
|
if do_remove_background: |
|
|
input_image = remove_background(input_image, rembg_session) |
|
|
input_image = remove_background(input_image, rembg_session) |
|
|
|
|
|
input_image = resize_foreground(input_image, 0.85) |
|
|
|
|
|
|
|
|
return input_image |
|
|
return input_image |
|
|
|
|
|
|
|
@ -173,8 +175,8 @@ def make3d(images): |
|
|
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) |
|
|
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) |
|
|
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) |
|
|
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) |
|
|
|
|
|
|
|
|
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device) |
|
|
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) |
|
|
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device) |
|
|
render_cameras = get_render_cameras(batch_size=1, radius=4.0, is_flexicubes=IS_FLEXICUBES).to(device) |
|
|
|
|
|
|
|
|
images = images.unsqueeze(0).to(device) |
|
|
images = images.unsqueeze(0).to(device) |
|
|
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) |
|
|
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) |
|
@ -223,31 +225,31 @@ def make3d(images): |
|
|
return video_fpath, mesh_fpath |
|
|
return video_fpath, mesh_fpath |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_example(image_file): |
|
|
|
|
|
|
|
|
|
|
|
preprocessed = preprocess(image_file, False, 0.85) |
|
|
|
|
|
mv_images, _ = generate_mvs(preprocessed, 20, 0) |
|
|
|
|
|
video_name, mesh_fpath, planes = make3d(mv_images) |
|
|
|
|
|
mesh_name = make_mesh(mesh_fpath, planes) |
|
|
|
|
|
|
|
|
|
|
|
return preprocessed, mesh_name, video_name |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
HEADER = ''' |
|
|
_HEADER_ = ''' |
|
|
<h3> |
|
|
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2> |
|
|
<b>Official 🤗 Gradio demo</b> for |
|
|
|
|
|
<a href='https://github.com/TencentARC/InstantMesh' target='_blank'> |
|
|
|
|
|
<b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b> |
|
|
|
|
|
</a>. |
|
|
|
|
|
</h3> |
|
|
|
|
|
<br> |
|
|
|
|
|
* If the output is unsatisfying, try to use a different seed. |
|
|
|
|
|
''' |
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
_LINKS_ = ''' |
|
|
|
|
|
<h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3> |
|
|
|
|
|
<h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3> |
|
|
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
_CITE_ = r""" |
|
|
|
|
|
```bibtex |
|
|
|
|
|
@article{xu2024instantmesh, |
|
|
|
|
|
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models}, |
|
|
|
|
|
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying}, |
|
|
|
|
|
journal={arXiv preprint arXiv:2404.07191}, |
|
|
|
|
|
year={2024} |
|
|
|
|
|
} |
|
|
|
|
|
``` |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
with gr.Blocks() as demo: |
|
|
gr.Markdown(HEADER) |
|
|
gr.Markdown(_HEADER_) |
|
|
with gr.Row(variant="panel"): |
|
|
with gr.Row(variant="panel"): |
|
|
with gr.Column(): |
|
|
with gr.Column(): |
|
|
with gr.Row(): |
|
|
with gr.Row(): |
|
@ -273,7 +275,7 @@ with gr.Blocks() as demo: |
|
|
do_remove_background = gr.Checkbox( |
|
|
do_remove_background = gr.Checkbox( |
|
|
label="Remove Background", value=True |
|
|
label="Remove Background", value=True |
|
|
) |
|
|
) |
|
|
sample_seed = gr.Number(value=42, label="Seed", precision=0) |
|
|
sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0) |
|
|
|
|
|
|
|
|
sample_steps = gr.Slider( |
|
|
sample_steps = gr.Slider( |
|
|
label="Sample Steps", |
|
|
label="Sample Steps", |
|
@ -292,9 +294,6 @@ with gr.Blocks() as demo: |
|
|
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) |
|
|
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) |
|
|
], |
|
|
], |
|
|
inputs=[input_image], |
|
|
inputs=[input_image], |
|
|
# outputs=[processed_image, output_model_obj, output_video], |
|
|
|
|
|
# fn=partial(run_example), |
|
|
|
|
|
# cache_examples=True, |
|
|
|
|
|
label="Examples", |
|
|
label="Examples", |
|
|
examples_per_page=20 |
|
|
examples_per_page=20 |
|
|
) |
|
|
) |
|
@ -325,7 +324,8 @@ with gr.Blocks() as demo: |
|
|
width=768, |
|
|
width=768, |
|
|
interactive=False, |
|
|
interactive=False, |
|
|
) |
|
|
) |
|
|
|
|
|
gr.Markdown(_LINKS_) |
|
|
|
|
|
gr.Markdown(_CITE_) |
|
|
mv_images = gr.State() |
|
|
mv_images = gr.State() |
|
|
|
|
|
|
|
|
submit.click(fn=check_input_image, inputs=[input_image]).success( |
|
|
submit.click(fn=check_input_image, inputs=[input_image]).success( |
|
|