Browse Source

update app.py

main
xt4d 6 months ago
parent
commit
d23bedb93e
  1. 62
      app.py
  2. 2
      requirements.txt

62
app.py

@ -27,6 +27,7 @@ from src.utils.infer_util import remove_background, resize_foreground, images_to
import tempfile import tempfile
from functools import partial from functools import partial
from huggingface_hub import hf_hub_download
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
@ -65,7 +66,7 @@ def images_to_video(images, output_path, fps=30):
seed_everything(0) seed_everything(0)
config_path = 'configs/instant-mesh-large-eval.yaml' config_path = 'configs/instant-mesh-large.yaml'
config = OmegaConf.load(config_path) config = OmegaConf.load(config_path)
config_name = os.path.basename(config_path).replace('.yaml', '') config_name = os.path.basename(config_path).replace('.yaml', '')
model_config = config.model_config model_config = config.model_config
@ -87,15 +88,17 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
) )
# load custom white-background UNet # load custom white-background UNet
state_dict = torch.load(infer_config.unet_path, map_location='cpu') unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True) pipeline.unet.load_state_dict(state_dict, strict=True)
pipeline = pipeline.to(device) pipeline = pipeline.to(device)
# load reconstruction model # load reconstruction model
print('Loading reconstruction model ...') print('Loading reconstruction model ...')
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
model = instantiate_from_config(model_config) model = instantiate_from_config(model_config)
state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict'] state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
@ -115,10 +118,9 @@ def check_input_image(input_image):
def preprocess(input_image, do_remove_background): def preprocess(input_image, do_remove_background):
rembg_session = rembg.new_session() if do_remove_background else None rembg_session = rembg.new_session() if do_remove_background else None
#input_image = Image.open(image_file)
if do_remove_background: if do_remove_background:
input_image = remove_background(input_image, rembg_session) input_image = remove_background(input_image, rembg_session)
input_image = resize_foreground(input_image, 0.85)
return input_image return input_image
@ -173,8 +175,8 @@ def make3d(images):
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device) input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device) render_cameras = get_render_cameras(batch_size=1, radius=4.0, is_flexicubes=IS_FLEXICUBES).to(device)
images = images.unsqueeze(0).to(device) images = images.unsqueeze(0).to(device)
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
@ -223,31 +225,31 @@ def make3d(images):
return video_fpath, mesh_fpath return video_fpath, mesh_fpath
def run_example(image_file):
preprocessed = preprocess(image_file, False, 0.85)
mv_images, _ = generate_mvs(preprocessed, 20, 0)
video_name, mesh_fpath, planes = make3d(mv_images)
mesh_name = make_mesh(mesh_fpath, planes)
return preprocessed, mesh_name, video_name
import gradio as gr import gradio as gr
HEADER = ''' _HEADER_ = '''
<h3> <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
<b>Official 🤗 Gradio demo</b> for
<a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
<b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
</a>.
</h3>
<br>
* If the output is unsatisfying, try to use a different seed.
''' '''
_LINKS_ = '''
<h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
<h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
'''
_CITE_ = r"""
```bibtex
@article{xu2024instantmesh,
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
journal={arXiv preprint arXiv:2404.07191},
year={2024}
}
```
"""
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown(HEADER) gr.Markdown(_HEADER_)
with gr.Row(variant="panel"): with gr.Row(variant="panel"):
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
@ -273,7 +275,7 @@ with gr.Blocks() as demo:
do_remove_background = gr.Checkbox( do_remove_background = gr.Checkbox(
label="Remove Background", value=True label="Remove Background", value=True
) )
sample_seed = gr.Number(value=42, label="Seed", precision=0) sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0)
sample_steps = gr.Slider( sample_steps = gr.Slider(
label="Sample Steps", label="Sample Steps",
@ -292,9 +294,6 @@ with gr.Blocks() as demo:
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
], ],
inputs=[input_image], inputs=[input_image],
# outputs=[processed_image, output_model_obj, output_video],
# fn=partial(run_example),
# cache_examples=True,
label="Examples", label="Examples",
examples_per_page=20 examples_per_page=20
) )
@ -325,7 +324,8 @@ with gr.Blocks() as demo:
width=768, width=768,
interactive=False, interactive=False,
) )
gr.Markdown(_LINKS_)
gr.Markdown(_CITE_)
mv_images = gr.State() mv_images = gr.State()
submit.click(fn=check_input_image, inputs=[input_image]).success( submit.click(fn=check_input_image, inputs=[input_image]).success(

2
requirements.txt

@ -1,5 +1,5 @@
pytorch-lightning==2.1.2 pytorch-lightning==2.1.2
gradio gradio==3.41.2
huggingface-hub huggingface-hub
einops einops
omegaconf omegaconf

Loading…
Cancel
Save