diff --git a/run.py b/run.py index 5b09fed..1454253 100644 --- a/run.py +++ b/run.py @@ -9,6 +9,7 @@ from pytorch_lightning import seed_everything from omegaconf import OmegaConf from einops import rearrange, repeat from tqdm import tqdm +from huggingface_hub import hf_hub_download from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler from src.utils.train_util import instantiate_from_config @@ -106,7 +107,11 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( # load custom white-background UNet print('Loading custom white-background unet ...') -state_dict = torch.load(infer_config.unet_path, map_location='cpu') +if os.path.exists(infer_config.unet_path): + unet_ckpt_path = infer_config.unet_path +else: + unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") +state_dict = torch.load(unet_ckpt_path, map_location='cpu') pipeline.unet.load_state_dict(state_dict, strict=True) pipeline = pipeline.to(device) @@ -114,7 +119,11 @@ pipeline = pipeline.to(device) # load reconstruction model print('Loading reconstruction model ...') model = instantiate_from_config(model_config) -state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict'] +if os.path.exists(infer_config.model_path): + model_ckpt_path = infer_config.model_path +else: + model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model") +state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} model.load_state_dict(state_dict, strict=True)