You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
8.4 KiB
257 lines
8.4 KiB
2 years ago
|
import tempfile
|
||
|
from contextlib import contextmanager
|
||
|
from typing import Iterator, Optional, Union
|
||
|
|
||
|
import blobfile as bf
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
|
||
|
from shap_e.rendering.blender.render import render_mesh, render_model
|
||
|
from shap_e.rendering.blender.view_data import BlenderViewData
|
||
|
from shap_e.rendering.mesh import TriMesh
|
||
|
from shap_e.rendering.point_cloud import PointCloud
|
||
|
from shap_e.rendering.view_data import ViewData
|
||
|
from shap_e.util.collections import AttrDict
|
||
|
from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize
|
||
|
|
||
|
|
||
|
def load_or_create_multimodal_batch(
    device: torch.device,
    *,
    mesh_path: Optional[str] = None,
    model_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    point_count: int = 2**14,
    random_sample_count: int = 2**19,
    pc_num_views: int = 40,
    mv_light_mode: Optional[str] = None,
    mv_num_views: int = 20,
    mv_image_size: int = 512,
    mv_alpha_removal: str = "black",
    verbose: bool = False,
) -> AttrDict:
    """
    Build a normalized input batch (point cloud plus optional multiview
    renders) for a single mesh or model file.

    Exactly one of ``mesh_path`` / ``model_path`` must be given; this is
    enforced by the loaders called below.

    :param device: device the point tensor is moved to.
    :param mesh_path: path to a mesh loaded with ``TriMesh.load``.
    :param model_path: path to a model file rendered with ``render_model``.
    :param cache_dir: optional directory used to cache renders and point clouds.
    :param point_count: number of points requested from ``mv_to_pc``.
    :param random_sample_count: number of points randomly sampled before
        farthest-point sampling in ``mv_to_pc``.
    :param pc_num_views: number of views rendered to extract the point cloud.
    :param mv_light_mode: if truthy, multiview renders with this light mode
        are also added to the batch; otherwise only the point cloud is built.
    :param mv_num_views: number of views requested for the multiview renders.
    :param mv_image_size: side length passed to the view/depth/camera resizing.
    :param mv_alpha_removal: mode forwarded to ``remove_alpha`` for each view.
    :param verbose: print progress messages.
    :return: the result of ``normalize_input_batch`` on an AttrDict holding
        ``points`` (xyz and RGB stacked along dim 1, shape ``[1, 6, N]``) and,
        when multiview is requested, ``depths``, ``views``, ``view_alphas`` and
        ``cameras``, each a one-element list wrapping the per-view list.
    """
    if verbose:
        print("creating point cloud...")
    pc = load_or_create_pc(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        random_sample_count=random_sample_count,
        point_count=point_count,
        num_views=pc_num_views,
        verbose=verbose,
    )
    raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1)
    encode_me = torch.from_numpy(raw_pc).float().to(device)
    # Transpose to channels-first and add a leading batch dimension.
    batch = AttrDict(points=encode_me.t()[None])

    if mv_light_mode:
        if verbose:
            print("creating multiview...")
        with load_or_create_multiview(
            mesh_path=mesh_path,
            model_path=model_path,
            cache_dir=cache_dir,
            num_views=mv_num_views,
            extract_material=False,
            light_mode=mv_light_mode,
            verbose=verbose,
        ) as mv:
            cameras, views, view_alphas, depths = [], [], [], []
            for view_idx in range(mv.num_views):
                camera, view = mv.load_view(
                    view_idx,
                    ["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"],
                )
                depth = None
                if "D" in mv.channel_names:
                    _, depth = mv.load_view(view_idx, ["D"])
                    depth = process_depth(depth, mv_image_size)
                # Clip before casting: values outside [0, 1] would otherwise
                # overflow uint8 and wrap around (e.g. 256 -> 0), since NumPy's
                # out-of-range float -> uint8 cast is not saturating.
                view_u8 = np.clip(np.round(view * 255.0), 0, 255).astype(np.uint8)
                view, alpha = process_image(view_u8, mv_alpha_removal, mv_image_size)
                # Keep the camera consistent with the cropped/resized image.
                camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)
                cameras.append(camera)
                views.append(view)
                view_alphas.append(alpha)
                depths.append(depth)
            batch.depths = [depths]
            batch.views = [views]
            batch.view_alphas = [view_alphas]
            batch.cameras = [cameras]
    return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)
|
||
|
|
||
|
|
||
|
def load_or_create_pc(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    random_sample_count: int,
    point_count: int,
    num_views: int,
    verbose: bool = False,
) -> PointCloud:
    """
    Return the point cloud for a mesh or model file.

    If ``cache_dir`` is set and a cached ``.npz`` exists there, it is loaded
    directly. Otherwise the asset is rendered via ``load_or_create_multiview``,
    converted with ``mv_to_pc``, and, when ``cache_dir`` is set, saved for
    later calls.

    Exactly one of ``mesh_path`` / ``model_path`` must be provided.

    NOTE(review): the cache file name uses only ``bf.basename`` of the input
    path, so distinct files sharing a basename in different directories map
    to the same cache entry. Confirm this is intended.
    """
    assert (model_path is None) != (
        mesh_path is None
    ), "must specify exactly one of model_path or mesh_path"
    source = mesh_path if model_path is None else model_path

    cache_path = None
    if cache_dir is not None:
        file_name = (
            f"pc_{bf.basename(source)}_mat_{num_views}_{random_sample_count}_{point_count}.npz"
        )
        cache_path = bf.join(cache_dir, file_name)
        if bf.exists(cache_path):
            return PointCloud.load(cache_path)

    with load_or_create_multiview(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        num_views=num_views,
        verbose=verbose,
    ) as views:
        if verbose:
            print("extracting point cloud from multiview...")
        cloud = mv_to_pc(
            multiview=views,
            random_sample_count=random_sample_count,
            point_count=point_count,
        )
        if cache_path is not None:
            cloud.save(cache_path)
        return cloud
|
||
|
|
||
|
|
||
|
@contextmanager
def load_or_create_multiview(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    num_views: int = 20,
    extract_material: bool = True,
    light_mode: Optional[str] = None,
    verbose: bool = False,
) -> Iterator[BlenderViewData]:
    """
    Context manager yielding a ``BlenderViewData`` over rendered views of a
    mesh or model.

    If ``cache_dir`` holds a matching cached ``.zip``, it is opened and yielded
    without rendering. Otherwise the asset is rendered with Blender (EEVEE
    backend, random camera poses) into a temporary directory. The result is
    copied into ``cache_dir`` when one is given, then yielded. The temporary
    directory is removed when the context exits.

    :param mesh_path: mesh file, loaded with ``TriMesh.load`` and passed to
        ``render_mesh``. Exactly one of mesh_path/model_path must be set.
    :param model_path: model file passed to ``render_model``.
    :param cache_dir: optional cache directory for the rendered ``.zip``.
    :param num_views: number of images requested from the renderer.
    :param extract_material: forwarded to the renderer; when True,
        ``light_mode`` must be None, and when False it must be given.
    :param light_mode: forwarded to the renderer (``"uniform"`` is used when
        None).
    :param verbose: forwarded to the renderer.

    NOTE(review): the cache key uses only ``bf.basename`` of the input path;
    identically named files in different directories share a cache entry.
    """
    assert (model_path is not None) ^ (
        mesh_path is not None
    ), "must specify exactly one of model_path or mesh_path"
    path = model_path if model_path is not None else mesh_path

    if extract_material:
        assert light_mode is None, "light_mode is ignored when extract_material=True"
    else:
        assert light_mode is not None, "must specify light_mode when extract_material=False"

    if cache_dir is not None:
        if extract_material:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip")
        else:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip")
        if bf.exists(cache_path):
            with bf.BlobFile(cache_path, "rb") as f:
                yield BlenderViewData(f)
            return
    else:
        cache_path = None

    common_kwargs = dict(
        fast_mode=True,
        extract_material=extract_material,
        camera_pose="random",
        light_mode=light_mode or "uniform",
        verbose=verbose,
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = bf.join(tmp_dir, "out.zip")
        if mesh_path is not None:
            mesh = TriMesh.load(mesh_path)
            render_mesh(
                mesh=mesh,
                output_path=tmp_path,
                num_images=num_views,
                backend="BLENDER_EEVEE",
                **common_kwargs,
            )
        elif model_path is not None:
            render_model(
                model_path,
                output_path=tmp_path,
                num_images=num_views,
                backend="BLENDER_EEVEE",
                **common_kwargs,
            )
        if cache_path is not None:
            # overwrite=True: another worker may have populated this cache entry
            # between the bf.exists() check above and now. Without it, bf.copy
            # raises because the destination exists, and the finished render is
            # thrown away. Both renders come from the same inputs and settings.
            bf.copy(tmp_path, cache_path, overwrite=True)
        with bf.BlobFile(tmp_path, "rb") as f:
            yield BlenderViewData(f)
|
||
|
|
||
|
|
||
|
def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:
    """
    Convert RGB-D multiview data into a point cloud of ``point_count`` points.

    The cloud from ``PointCloud.from_rgbd`` is first padded by repeated
    self-concatenation until it has at least ``point_count`` points. It is then
    reduced with ``random_sample(random_sample_count)`` followed by
    ``farthest_point_sample(point_count, average_neighbors=True)``.
    """
    cloud = PointCloud.from_rgbd(multiview)

    if not len(cloud.coords):
        # No points were recovered: substitute a single point at the origin
        # with zero RGB so the padding/sampling below has input to work on.
        cloud = PointCloud(
            coords=np.zeros([1, 3]),
            channels=dict(zip("RGB", np.zeros([3, 1]))),
        )

    while len(cloud.coords) < point_count:
        cloud = cloud.combine(cloud)
        # Prevent duplicate points; some models may not like it.
        cloud.coords += np.random.normal(size=cloud.coords.shape) * 1e-4

    sampled = cloud.random_sample(random_sample_count)
    return sampled.farthest_point_sample(point_count, average_neighbors=True)
|
||
|
|
||
|
|
||
|
def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    """
    Return a rescaled shallow copy of ``batch``; the input is not mutated.

    ``points`` is multiplied by a 6-element per-channel factor, broadcast along
    dim -2: the first three channels by ``pc_scale``, the last three by
    ``color_scale``. Dim -2 of ``points`` must therefore have size 6. If
    present, every camera is rescaled with ``scale_scene(pc_scale)`` and every
    depth map is multiplied by ``pc_scale``, so they stay consistent with the
    scaled coordinates.
    """
    factors = torch.tensor(
        [pc_scale, pc_scale, pc_scale, color_scale, color_scale, color_scale],
        device=batch.points.device,
    ).unsqueeze(1)

    out = batch.copy()
    out.points = out.points * factors

    if "cameras" in out:
        out.cameras = [
            [camera.scale_scene(pc_scale) for camera in group] for group in out.cameras
        ]

    if "depths" in out:
        out.depths = [[d * pc_scale for d in group] for group in out.depths]

    return out
|
||
|
|
||
|
|
||
|
def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
    """
    Apply ``center_crop``, resize to ``image_size`` x ``image_size`` with
    ``resize``, and drop singleton axes from the result.
    """
    return np.squeeze(
        resize(center_crop(depth_img), width=image_size, height=image_size)
    )
|
||
|
|
||
|
|
||
|
def process_image(
    img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int
):
    """
    Prepare one view image: crop it, split off alpha, and resize both parts.

    Accepts a PIL image or a NumPy array (converted with ``Image.fromarray``).
    Two-dimensional (single-channel) inputs are first pasted onto a new RGB
    canvas. The image then goes through ``center_crop``. Its alpha is taken
    with ``get_alpha`` and the colour part with
    ``remove_alpha(mode=alpha_removal)``. Both are resized to
    ``image_size`` x ``image_size`` with bilinear resampling.

    :return: ``(image, alpha)``, both produced by ``.resize`` (presumably PIL
        images; confirm against ``shap_e.util.image_util``).
    """
    if isinstance(img_or_img_arr, np.ndarray):
        pil_img = Image.fromarray(img_or_img_arr)
        ndim = img_or_img_arr.ndim
    else:
        pil_img = img_or_img_arr
        ndim = np.asarray(pil_img).ndim

    if ndim == 2:
        # Grayscale: paste onto a blank RGB canvas of the same size.
        canvas = Image.new("RGB", pil_img.size)
        canvas.paste(pil_img)
        pil_img = canvas

    cropped = center_crop(pil_img)
    alpha = get_alpha(cropped)
    color = remove_alpha(cropped, mode=alpha_removal)

    target_size = (image_size, image_size)
    alpha = alpha.resize(target_size, resample=Image.BILINEAR)
    color = color.resize(target_size, resample=Image.BILINEAR)
    return color, alpha
|