Browse Source

add models volume cache

main
catid 8 months ago
parent
commit
5b18954620
  1. 9
      app.py
  2. 5
      docker/Dockerfile

9
app.py

@ -30,6 +30,9 @@ else:
device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device1 = device0 device1 = device0
# Define the cache directory for model files
model_cache_dir = './models/'
os.makedirs(model_cache_dir, exist_ok=True)
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
""" """
@ -89,7 +92,7 @@ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
) )
# load custom white-background UNet # load custom white-background UNet
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir)
state_dict = torch.load(unet_ckpt_path, map_location='cpu') state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True) pipeline.unet.load_state_dict(state_dict, strict=True)
@ -97,7 +100,7 @@ pipeline = pipeline.to(device0)
# load reconstruction model # load reconstruction model
print('Loading reconstruction model ...') print('Loading reconstruction model ...')
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model") model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir)
model = instantiate_from_config(model_config) model = instantiate_from_config(model_config)
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
@ -375,4 +378,4 @@ with gr.Blocks() as demo:
) )
demo.queue(max_size=10) demo.queue(max_size=10)
demo.launch(server_name="0.0.0.0", server_port=43839) demo.launch(server_name="0.0.0.0", server_port=43839, share=True)

5
docker/Dockerfile

@ -51,5 +51,8 @@ RUN pip install -r requirements.txt
COPY . /workspace/instantmesh COPY . /workspace/instantmesh
# Add a volume for downloaded models
VOLUME /workspace/models
# Run the command when the container starts # Run the command when the container starts
CMD ["python", "app.py"] CMD ["python", "app.py"]

Loading…
Cancel
Save