|
|
@ -2,6 +2,7 @@ |
|
|
|
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py |
|
|
|
""" |
|
|
|
|
|
|
|
import hashlib |
|
|
|
import os |
|
|
|
from functools import lru_cache |
|
|
|
from typing import Dict, Optional |
|
|
@ -27,6 +28,18 @@ CONFIG_PATHS = { |
|
|
|
"diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml", |
|
|
|
} |
|
|
|
|
|
|
|
URL_HASHES = { |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0", |
|
|
|
"https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57", |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@lru_cache() |
|
|
|
def default_cache_dir() -> str: |
|
|
@ -41,11 +54,14 @@ def fetch_file_cached( |
|
|
|
If cache_dir is specified, it will be used to download the files. |
|
|
|
Otherwise, default_cache_dir() is used. |
|
|
|
""" |
|
|
|
expected_hash = URL_HASHES[url] |
|
|
|
|
|
|
|
if cache_dir is None: |
|
|
|
cache_dir = default_cache_dir() |
|
|
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
|
local_path = os.path.join(cache_dir, url.split("/")[-1]) |
|
|
|
if os.path.exists(local_path): |
|
|
|
check_hash(local_path, expected_hash) |
|
|
|
return local_path |
|
|
|
|
|
|
|
response = requests.get(url, stream=True) |
|
|
@ -62,9 +78,30 @@ def fetch_file_cached( |
|
|
|
os.rename(tmp_path, local_path) |
|
|
|
if progress: |
|
|
|
pbar.close() |
|
|
|
check_hash(local_path, expected_hash) |
|
|
|
return local_path |
|
|
|
|
|
|
|
|
|
|
|
def check_hash(path: str, expected_hash: str): |
|
|
|
actual_hash = hash_file(path) |
|
|
|
if actual_hash != expected_hash: |
|
|
|
raise RuntimeError( |
|
|
|
f"The file {path} should have hash {expected_hash} but has {actual_hash}. " |
|
|
|
"Try deleting it and running this call again." |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def hash_file(path: str) -> str: |
|
|
|
sha256_hash = hashlib.sha256() |
|
|
|
with open(path, "rb") as file: |
|
|
|
while True: |
|
|
|
data = file.read(4096) |
|
|
|
if not len(data): |
|
|
|
break |
|
|
|
sha256_hash.update(data) |
|
|
|
return sha256_hash.hexdigest() |
|
|
|
|
|
|
|
|
|
|
|
def load_config( |
|
|
|
config_name: str, |
|
|
|
progress: bool = False, |
|
|
|