Browse Source

add 2gpu support

main
catid 6 months ago
parent
commit
c030efb5e9
  1. 21
      app.py

21
app.py

@ -23,6 +23,13 @@ from src.utils.infer_util import remove_background, resize_foreground, images_to
import tempfile import tempfile
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
device0 = torch.device('cuda:0')
device1 = torch.device('cuda:1')
else:
device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device1 = device0
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
""" """
@ -86,7 +93,7 @@ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="dif
state_dict = torch.load(unet_ckpt_path, map_location='cpu') state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True) pipeline.unet.load_state_dict(state_dict, strict=True)
pipeline = pipeline.to(device) pipeline = pipeline.to(device0)
# load reconstruction model # load reconstruction model
print('Loading reconstruction model ...') print('Loading reconstruction model ...')
@ -96,9 +103,9 @@ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
model = model.to(device) model = model.to(device1)
if IS_FLEXICUBES: if IS_FLEXICUBES:
model.init_flexicubes_geometry(device, fovy=30.0) model.init_flexicubes_geometry(device1, fovy=30.0)
model = model.eval() model = model.eval()
print('Loading Finished!') print('Loading Finished!')
@ -124,7 +131,7 @@ def generate_mvs(input_image, sample_steps, sample_seed):
seed_everything(sample_seed) seed_everything(sample_seed)
# sampling # sampling
generator = torch.Generator(device=device) generator = torch.Generator(device=device0)
z123_image = pipeline( z123_image = pipeline(
input_image, input_image,
num_inference_steps=sample_steps, num_inference_steps=sample_steps,
@ -172,11 +179,11 @@ def make3d(images):
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device1)
render_cameras = get_render_cameras( render_cameras = get_render_cameras(
batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device) batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device1)
images = images.unsqueeze(0).to(device) images = images.unsqueeze(0).to(device1)
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name

Loading…
Cancel
Save