|
@ -30,6 +30,9 @@ else: |
|
|
device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
device1 = device0 |
|
|
device1 = device0 |
|
|
|
|
|
|
|
|
|
|
|
# Define the cache directory for model files |
|
|
|
|
|
model_cache_dir = './models/' |
|
|
|
|
|
os.makedirs(model_cache_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): |
|
|
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): |
|
|
""" |
|
|
""" |
|
@ -89,7 +92,7 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
# load custom white-background UNet |
|
|
# load custom white-background UNet |
|
|
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") |
|
|
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir) |
|
|
state_dict = torch.load(unet_ckpt_path, map_location='cpu') |
|
|
state_dict = torch.load(unet_ckpt_path, map_location='cpu') |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
@ -97,7 +100,7 @@ pipeline = pipeline.to(device0) |
|
|
|
|
|
|
|
|
# load reconstruction model |
|
|
# load reconstruction model |
|
|
print('Loading reconstruction model ...') |
|
|
print('Loading reconstruction model ...') |
|
|
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model") |
|
|
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir) |
|
|
model = instantiate_from_config(model_config) |
|
|
model = instantiate_from_config(model_config) |
|
|
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] |
|
|
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} |
|
@ -375,4 +378,4 @@ with gr.Blocks() as demo: |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
demo.queue(max_size=10) |
|
|
demo.queue(max_size=10) |
|
|
demo.launch(server_name="0.0.0.0", server_port=43839) |
|
|
demo.launch(server_name="0.0.0.0", server_port=43839, share=True) |
|
|