Browse Source

auto download model weights

main
bluestyle97 9 months ago
parent
commit
e17ac11d9e
  1. 13
      run.py

13
run.py

@@ -9,6 +9,7 @@ from pytorch_lightning import seed_everything
 from omegaconf import OmegaConf
 from einops import rearrange, repeat
 from tqdm import tqdm
+from huggingface_hub import hf_hub_download
 from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
 from src.utils.train_util import instantiate_from_config
@@ -106,7 +107,11 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
 # load custom white-background UNet
 print('Loading custom white-background unet ...')
-state_dict = torch.load(infer_config.unet_path, map_location='cpu')
+if os.path.exists(infer_config.unet_path):
+    unet_ckpt_path = infer_config.unet_path
+else:
+    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
 pipeline.unet.load_state_dict(state_dict, strict=True)
 pipeline = pipeline.to(device)
@@ -114,7 +119,11 @@ pipeline = pipeline.to(device)
 # load reconstruction model
 print('Loading reconstruction model ...')
 model = instantiate_from_config(model_config)
-state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict']
+if os.path.exists(infer_config.model_path):
+    model_ckpt_path = infer_config.model_path
+else:
+    model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
 model.load_state_dict(state_dict, strict=True)

Loading…
Cancel
Save