You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							84 lines
						
					
					
						
							3.0 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							84 lines
						
					
					
						
							3.0 KiB
						
					
					
				| import itertools | |
| import json | |
| import zipfile | |
| from typing import BinaryIO, List, Tuple | |
| 
 | |
| import numpy as np | |
| from PIL import Image | |
| 
 | |
| from shap_e.rendering.view_data import Camera, ProjectiveCamera, ViewData | |
| 
 | |
| 
 | |
class BlenderViewData(ViewData):
    """
    Interact with a dataset zipfile exported by view_data.py.

    Archive layout, as read by this class:

    - ``info.json``: dataset metadata. Its optional ``"channels"`` entry lists
      the available channels (default ``"RGBAD"``) and must include all of
      R, G, B and A.
    - ``{i:05}.json``: per-view metadata for views ``i = 0, 1, 2, ...``. Views
      are numbered contiguously; the first missing index ends the view list.
    - ``{i:05}.png``: the RGBA image of view ``i``. It may be accompanied by
      ``{i:05}_depth.png`` (16-bit fixed-point depth) and
      ``{i:05}_MatAlpha.png`` when those channels are requested.

    The archive is kept open so views can be loaded lazily. Call :meth:`close`,
    or use the object as a context manager, to release it.
    """

    def __init__(self, f_obj: BinaryIO):
        """
        :param f_obj: binary file object holding the zip archive. Images are read
            from it lazily by load_view(), so it must stay open while this
            object is in use.
        :raises AssertionError: if the declared channels do not include all of
            R, G, B and A.
        """
        self.zipfile = zipfile.ZipFile(f_obj, mode="r")
        self.infos: List[dict] = []
        try:
            with self.zipfile.open("info.json", "r") as f:
                self.info = json.load(f)
            self.channels = list(self.info.get("channels", "RGBAD"))
            assert set("RGBA").issubset(
                set(self.channels)
            ), "The blender output should at least have RGBA images."
            names = set(self.zipfile.namelist())
            # Per-view metadata files are numbered contiguously from 00000; the
            # first gap marks the end of the dataset.
            for i in itertools.count():
                name = f"{i:05}.json"
                if name not in names:
                    break
                with self.zipfile.open(name, "r") as f:
                    self.infos.append(json.load(f))
        except BaseException:
            # Don't leak the open archive when construction fails part-way.
            self.zipfile.close()
            raise

    def close(self) -> None:
        """Close the underlying zip archive; views can no longer be loaded afterwards."""
        self.zipfile.close()

    def __enter__(self) -> "BlenderViewData":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    @property
    def num_views(self) -> int:
        """Number of views found in the archive (count of contiguous ``NNNNN.json`` files)."""
        return len(self.infos)

    @property
    def channel_names(self) -> List[str]:
        """A fresh copy of the channel names declared by ``info.json``."""
        return list(self.channels)

    def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:
        """
        Load the requested channels of a single view.

        :param index: view index; files are looked up by the zero-padded index.
        :param channels: channel names to return. The last axis of the output
            follows this order.
        :return: ``(camera, image)``, where ``image`` stacks the decoded channels
            along the last axis, giving shape ``(height, width, len(channels))``.
            RGBA values are scaled to [0, 1]. Depth pixels stored as 0xFFFF are
            decoded as ``inf``.
        :raises ValueError: if a requested channel is not declared in ``info.json``.
        """
        supported = set(self.channels)
        for ch in channels:
            if ch not in supported:
                raise ValueError(f"unsupported channel: {ch}")

        # Gather (a superset of) the requested channels.
        channel_map = {}
        if any(x in channels for x in "RGBA"):
            with self.zipfile.open(f"{index:05}.png", "r") as f:
                rgba = np.array(Image.open(f)).astype(np.float32) / 255.0
                # Split the trailing colour axis into one 2-D plane per channel.
                channel_map.update(zip("RGBA", rgba.transpose([2, 0, 1])))
        if "D" in channels:
            with self.zipfile.open(f"{index:05}_depth.png", "r") as f:
                # Decode a 16-bit fixed-point number, scaled by this view's max_depth;
                # the all-ones value 0xFFFF is reserved for "infinitely far".
                fp = np.array(Image.open(f))
                inf_dist = fp == 0xFFFF
                channel_map["D"] = np.where(
                    inf_dist,
                    np.inf,
                    self.infos[index]["max_depth"] * (fp.astype(np.float32) / 65536),
                )
        if "MatAlpha" in channels:
            with self.zipfile.open(f"{index:05}_MatAlpha.png", "r") as f:
                # NOTE(review): divides by 65536 like the depth decode, so a full-scale
                # 16-bit value maps to ~0.99998 rather than 1.0. Confirm the exporter's
                # encoding before switching to 65535.
                channel_map["MatAlpha"] = np.array(Image.open(f)).astype(np.float32) / 65536

        # The order of channels is user-specified.
        combined = np.stack([channel_map[k] for k in channels], axis=-1)

        h, w, _ = combined.shape
        return self.camera(index, w, h), combined

    def camera(self, index: int, width: int, height: int) -> ProjectiveCamera:
        """
        Build the camera for a view from its per-view metadata.

        :param index: view index into the per-view metadata.
        :param width: image width in pixels.
        :param height: image height in pixels.
        :return: a ProjectiveCamera built from the view's ``origin``, ``x``,
            ``y``, ``z``, ``x_fov`` and ``y_fov`` entries.
        """
        info = self.infos[index]
        return ProjectiveCamera(
            origin=np.array(info["origin"], dtype=np.float32),
            x=np.array(info["x"], dtype=np.float32),
            y=np.array(info["y"], dtype=np.float32),
            z=np.array(info["z"], dtype=np.float32),
            width=width,
            height=height,
            x_fov=info["x_fov"],
            y_fov=info["y_fov"],
        )
 | |
| 
 |