You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.7 KiB
113 lines
3.7 KiB
2 years ago
|
"""
|
||
|
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
|
||
|
"""
|
||
|
|
||
|
import os
|
||
|
from functools import lru_cache
|
||
|
from typing import Dict, Optional
|
||
|
|
||
|
import requests
|
||
|
import torch
|
||
|
import yaml
|
||
|
from filelock import FileLock
|
||
|
from tqdm.auto import tqdm
|
||
|
|
||
|
# Checkpoint URLs for every downloadable model, keyed by the name that
# load_checkpoint() (and therefore load_model()) accepts.
MODEL_PATHS = dict(
    transmitter="https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
    decoder="https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
    text300M="https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
    image300M="https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
)
|
||
|
|
||
|
# YAML config URLs keyed by the name that load_config() accepts. Unlike
# MODEL_PATHS this also has a "diffusion" entry, which has no checkpoint.
CONFIG_PATHS = dict(
    transmitter="https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
    decoder="https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
    text300M="https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
    image300M="https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
    diffusion="https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
)
|
||
|
|
||
|
|
||
|
@lru_cache()
def default_cache_dir() -> str:
    """
    Return the absolute path of ``shap_e_model_cache`` inside the working directory.

    The result is memoized by ``lru_cache``. It therefore reflects the working
    directory at the time of the *first* call, even if the process later chdirs.
    """
    cwd = os.path.abspath(os.getcwd())
    return os.path.join(cwd, "shap_e_model_cache")
|
||
|
|
||
|
|
||
|
def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.

    If cache_dir is specified, it will be used to download the files.
    Otherwise, default_cache_dir() is used.

    The file is cached under the last path component of ``url``. The body is
    streamed into a ``.tmp`` sibling and moved into place only after it has
    been fully received, so a partial download is never served from the cache.
    A ``.lock`` file (via FileLock) serializes concurrent downloads of the
    same file.

    :param url: URL of the file to download.
    :param progress: if True, show a tqdm progress bar while downloading.
    :param cache_dir: directory to cache into; defaults to default_cache_dir().
    :param chunk_size: number of bytes requested per streamed chunk.
    :return: the local path of the cached file.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(local_path):
        return local_path

    with FileLock(local_path + ".lock"):
        # Another process may have finished the download while we waited.
        if os.path.exists(local_path):
            return local_path

        tmp_path = local_path + ".tmp"
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True) as response:
            # Without this check, an HTTP error page would be cached as the file.
            response.raise_for_status()
            size = int(response.headers.get("content-length", "0"))
            pbar = tqdm(total=size, unit="iB", unit_scale=True) if progress else None
            try:
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size):
                        if pbar is not None:
                            pbar.update(len(chunk))
                        f.write(chunk)
                # os.replace overwrites portably; os.rename fails on Windows if the target exists.
                os.replace(tmp_path, local_path)
            except BaseException:
                # Don't leave a half-written temp file behind on failure or interrupt.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            finally:
                if pbar is not None:
                    pbar.close()
    return local_path
|
||
|
|
||
|
|
||
|
def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    """
    Fetch a named YAML config, or reuse its cached copy, and parse it.

    :param config_name: one of the keys of CONFIG_PATHS.
    :param progress: if True, show a progress bar while downloading.
    :param cache_dir: cache directory forwarded to fetch_file_cached().
    :param chunk_size: download chunk size forwarded to fetch_file_cached().
    :return: the document parsed with yaml.safe_load.
    :raises ValueError: if config_name is not a key of CONFIG_PATHS.
    """
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )
    path = fetch_file_cached(
        CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    # Decode explicitly as UTF-8 (the YAML default) instead of the platform's
    # locale-dependent encoding.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
||
|
|
||
|
|
||
|
def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Fetch a named checkpoint, or reuse its cached copy, and load it onto ``device``.

    :param checkpoint_name: one of the keys of MODEL_PATHS.
    :param device: passed to torch.load as ``map_location``.
    :param progress: if True, show a progress bar while downloading.
    :param cache_dir: cache directory forwarded to fetch_file_cached().
    :param chunk_size: download chunk size forwarded to fetch_file_cached().
    :return: whatever torch.load returns for the file (annotated as a state dict).
    :raises ValueError: if checkpoint_name is not a key of MODEL_PATHS.
    """
    checkpoint_url = MODEL_PATHS.get(checkpoint_name)
    if checkpoint_url is None:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    local_file = fetch_file_cached(
        checkpoint_url,
        progress=progress,
        cache_dir=cache_dir,
        chunk_size=chunk_size,
    )
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted URLs. Consider weights_only=True if the installed torch supports it.
    state = torch.load(local_file, map_location=device)
    return state
|
||
|
|
||
|
|
||
|
def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> torch.nn.Module:
    """
    Build a model from its named config, load its checkpoint, and put it in eval mode.

    :param model_name: a name present in both CONFIG_PATHS and MODEL_PATHS.
    :param device: device passed to model_from_config and used as the checkpoint's map_location.
    :param kwargs: forwarded unchanged to BOTH load_config() and load_checkpoint()
        (e.g. progress, cache_dir, chunk_size). If ``progress`` is not given, the
        config download defaults to no progress bar and the checkpoint download
        to showing one.
    :return: the model returned by model_from_config, with weights loaded, after
        ``.eval()``.
    :raises ValueError: if model_name is unknown to either loader.
    """
    # Imported lazily (kept from the original), presumably to avoid a circular
    # import with .configs; TODO confirm.
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model
|