|
@ -9,6 +9,7 @@ from pytorch_lightning import seed_everything |
|
|
from omegaconf import OmegaConf |
|
|
from omegaconf import OmegaConf |
|
|
from einops import rearrange, repeat |
|
|
from einops import rearrange, repeat |
|
|
from tqdm import tqdm |
|
|
from tqdm import tqdm |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler |
|
|
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler |
|
|
|
|
|
|
|
|
from src.utils.train_util import instantiate_from_config |
|
|
from src.utils.train_util import instantiate_from_config |
|
@ -106,7 +107,11 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|
|
|
|
|
|
|
|
# load custom white-background UNet |
|
|
# load custom white-background UNet |
|
|
print('Loading custom white-background unet ...') |
|
|
print('Loading custom white-background unet ...') |
|
|
state_dict = torch.load(infer_config.unet_path, map_location='cpu') |
|
|
if os.path.exists(infer_config.unet_path): |
|
|
|
|
|
unet_ckpt_path = infer_config.unet_path |
|
|
|
|
|
else: |
|
|
|
|
|
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") |
|
|
|
|
|
state_dict = torch.load(unet_ckpt_path, map_location='cpu') |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
pipeline.unet.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
|
pipeline = pipeline.to(device) |
|
|
pipeline = pipeline.to(device) |
|
@ -114,7 +119,11 @@ pipeline = pipeline.to(device) |
|
|
# load reconstruction model |
|
|
# load reconstruction model |
|
|
print('Loading reconstruction model ...') |
|
|
print('Loading reconstruction model ...') |
|
|
model = instantiate_from_config(model_config) |
|
|
model = instantiate_from_config(model_config) |
|
|
state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict'] |
|
|
if os.path.exists(infer_config.model_path): |
|
|
|
|
|
model_ckpt_path = infer_config.model_path |
|
|
|
|
|
else: |
|
|
|
|
|
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model") |
|
|
|
|
|
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} |
|
|
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')} |
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
|