| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -23,6 +23,16 @@ from src.utils.infer_util import remove_background, resize_foreground, images_to | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import tempfile | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from huggingface_hub import hf_hub_download | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if torch.cuda.is_available() and torch.cuda.device_count() >= 2: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    device0 = torch.device('cuda:0') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    device1 = torch.device('cuda:1') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    device1 = device0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# Define the cache directory for model files | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model_cache_dir = './models/' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					os.makedirs(model_cache_dir, exist_ok=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -76,29 +86,30 @@ pipeline = DiffusionPipeline.from_pretrained( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    "sudo-ai/zero123plus-v1.2",  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    custom_pipeline="zero123plus", | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    torch_dtype=torch.float16, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    cache_dir=model_cache_dir | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    pipeline.scheduler.config, timestep_spacing='trailing' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# load custom white-background UNet | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					state_dict = torch.load(unet_ckpt_path, map_location='cpu') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					pipeline.unet.load_state_dict(state_dict, strict=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					pipeline = pipeline.to(device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					pipeline = pipeline.to(device0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# load reconstruction model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					print('Loading reconstruction model ...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model = instantiate_from_config(model_config) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model.load_state_dict(state_dict, strict=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model = model.to(device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model = model.to(device1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if IS_FLEXICUBES: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    model.init_flexicubes_geometry(device, fovy=30.0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    model.init_flexicubes_geometry(device1, fovy=30.0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					model = model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					print('Loading Finished!') | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -124,7 +135,7 @@ def generate_mvs(input_image, sample_steps, sample_seed): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    seed_everything(sample_seed) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # sampling | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    generator = torch.Generator(device=device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    generator = torch.Generator(device=device0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    z123_image = pipeline( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        input_image,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_inference_steps=sample_steps,  | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -172,11 +183,11 @@ def make3d(images): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    render_cameras = get_render_cameras( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    images = images.unsqueeze(0).to(device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    images = images.unsqueeze(0).to(device1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |