"""Data utilities for a fork of shap-e (for gc): build point-cloud and
multiview encoder batches from 3D assets, with on-disk caching."""

import tempfile
from contextlib import contextmanager
from typing import Iterator, Optional, Union
import blobfile as bf
import numpy as np
import torch
from PIL import Image
from shap_e.rendering.blender.render import render_mesh, render_model
from shap_e.rendering.blender.view_data import BlenderViewData
from shap_e.rendering.mesh import TriMesh
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ViewData
from shap_e.util.collections import AttrDict
from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize

def load_or_create_multimodal_batch(
device: torch.device,
*,
mesh_path: Optional[str] = None,
model_path: Optional[str] = None,
cache_dir: Optional[str] = None,
point_count: int = 2**14,
random_sample_count: int = 2**19,
pc_num_views: int = 40,
mv_light_mode: Optional[str] = None,
mv_num_views: int = 20,
mv_image_size: int = 512,
mv_alpha_removal: str = "black",
verbose: bool = False,
) -> AttrDict:
    """Build an encoder batch for one 3D asset: a [1, 6, N] XYZRGB point
    tensor, plus rendered views, alphas, depths, and cameras when
    mv_light_mode is set. The batch is normalized before being returned."""
if verbose:
print("creating point cloud...")
pc = load_or_create_pc(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
random_sample_count=random_sample_count,
point_count=point_count,
num_views=pc_num_views,
verbose=verbose,
)
raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1)
encode_me = torch.from_numpy(raw_pc).float().to(device)
batch = AttrDict(points=encode_me.t()[None])
if mv_light_mode:
if verbose:
print("creating multiview...")
with load_or_create_multiview(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
num_views=mv_num_views,
extract_material=False,
light_mode=mv_light_mode,
verbose=verbose,
) as mv:
cameras, views, view_alphas, depths = [], [], [], []
for view_idx in range(mv.num_views):
camera, view = mv.load_view(
view_idx,
["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"],
)
depth = None
if "D" in mv.channel_names:
_, depth = mv.load_view(view_idx, ["D"])
depth = process_depth(depth, mv_image_size)
view, alpha = process_image(
np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size
)
camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)
cameras.append(camera)
views.append(view)
view_alphas.append(alpha)
depths.append(depth)
batch.depths = [depths]
batch.views = [views]
batch.view_alphas = [view_alphas]
batch.cameras = [cameras]
return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)
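

# A minimal usage sketch (not part of the original module). The paths below
# are hypothetical, and rendering shells out to Blender, so this only runs
# where Blender is available. "basic" as a light mode follows upstream
# shap-e examples; treat it as an assumption.
def _example_multimodal_batch() -> AttrDict:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return load_or_create_multimodal_batch(
        device,
        mesh_path="assets/chair.obj",  # hypothetical input mesh
        cache_dir="cache/",            # hypothetical cache directory
        mv_light_mode="basic",         # enables the multiview branch
        mv_num_views=20,
        verbose=True,
    )
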
def load_or_create_pc(
*,
mesh_path: Optional[str],
model_path: Optional[str],
cache_dir: Optional[str],
random_sample_count: int,
point_count: int,
num_views: int,
verbose: bool = False,
) -> PointCloud:
    """Render the asset to a multiview and extract a point_count-point RGB
    point cloud, loading from / saving to cache_dir when it is provided."""
assert (model_path is not None) ^ (
mesh_path is not None
), "must specify exactly one of model_path or mesh_path"
path = model_path if model_path is not None else mesh_path
if cache_dir is not None:
cache_path = bf.join(
cache_dir,
f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz",
)
if bf.exists(cache_path):
return PointCloud.load(cache_path)
else:
cache_path = None
with load_or_create_multiview(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
num_views=num_views,
verbose=verbose,
) as mv:
if verbose:
print("extracting point cloud from multiview...")
pc = mv_to_pc(
multiview=mv, random_sample_count=random_sample_count, point_count=point_count
)
if cache_path is not None:
pc.save(cache_path)
return pc
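

# Sketch: build (or load a cached) point cloud for one asset. Paths are
# hypothetical; the counts mirror the defaults used above.
def _example_point_cloud() -> PointCloud:
    return load_or_create_pc(
        mesh_path="assets/chair.obj",  # hypothetical input mesh
        model_path=None,
        cache_dir="cache/",            # hypothetical cache directory
        random_sample_count=2**19,
        point_count=2**14,
        num_views=40,
        verbose=True,
    )
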
@contextmanager
def load_or_create_multiview(
*,
mesh_path: Optional[str],
model_path: Optional[str],
cache_dir: Optional[str],
num_views: int = 20,
extract_material: bool = True,
light_mode: Optional[str] = None,
verbose: bool = False,
) -> Iterator[BlenderViewData]:
    """Context manager yielding Blender-rendered views of the asset: served
    from cache_dir when a matching zip exists, otherwise rendered into a
    temporary directory (and copied back to the cache if one is set)."""
assert (model_path is not None) ^ (
mesh_path is not None
), "must specify exactly one of model_path or mesh_path"
path = model_path if model_path is not None else mesh_path
if extract_material:
assert light_mode is None, "light_mode is ignored when extract_material=True"
else:
assert light_mode is not None, "must specify light_mode when extract_material=False"
if cache_dir is not None:
if extract_material:
cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip")
else:
cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip")
if bf.exists(cache_path):
with bf.BlobFile(cache_path, "rb") as f:
yield BlenderViewData(f)
return
else:
cache_path = None
common_kwargs = dict(
fast_mode=True,
extract_material=extract_material,
camera_pose="random",
light_mode=light_mode or "uniform",
verbose=verbose,
)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = bf.join(tmp_dir, "out.zip")
if mesh_path is not None:
mesh = TriMesh.load(mesh_path)
render_mesh(
mesh=mesh,
output_path=tmp_path,
num_images=num_views,
backend="BLENDER_EEVEE",
**common_kwargs,
)
elif model_path is not None:
render_model(
model_path,
output_path=tmp_path,
num_images=num_views,
backend="BLENDER_EEVEE",
**common_kwargs,
)
if cache_path is not None:
bf.copy(tmp_path, cache_path)
with bf.BlobFile(tmp_path, "rb") as f:
yield BlenderViewData(f)
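

# Because load_or_create_multiview is a context manager (the backing zip may
# live in a temporary directory), views must be consumed inside the `with`
# block. A sketch with a hypothetical mesh path; requires Blender.
def _example_iterate_views() -> None:
    with load_or_create_multiview(
        mesh_path="assets/chair.obj",  # hypothetical input mesh
        model_path=None,
        cache_dir=None,  # skip caching for this sketch
        num_views=4,
        extract_material=False,
        light_mode="uniform",
    ) as mv:
        for view_idx in range(mv.num_views):
            camera, view = mv.load_view(view_idx, ["R", "G", "B"])
            print(view_idx, view.shape)
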
def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:
    """Fuse a multiview's RGBD frames into a colored point cloud, then
    reduce it to exactly point_count points via random subsampling followed
    by farthest-point sampling."""
pc = PointCloud.from_rgbd(multiview)
# Handle empty samples.
if len(pc.coords) == 0:
pc = PointCloud(
coords=np.zeros([1, 3]),
channels=dict(zip("RGB", np.zeros([3, 1]))),
)
while len(pc.coords) < point_count:
pc = pc.combine(pc)
        # Jitter coordinates slightly: combine() duplicates every point
        # exactly, and some downstream models may not tolerate duplicates.
pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4
pc = pc.random_sample(random_sample_count)
pc = pc.farthest_point_sample(point_count, average_neighbors=True)
return pc

def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    """Scale point XYZ coordinates by pc_scale and RGB channels by
    color_scale, applying the same spatial scale to any cameras and
    depths present in the batch."""
res = batch.copy()
scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device)
res.points = res.points * scale_vec[:, None]
if "cameras" in res:
res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]
if "depths" in res:
res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths]
return res
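

# Worked sketch of the normalization: the first three point channels (XYZ)
# scale by pc_scale and the last three (RGB) by color_scale, matching the
# [1, 6, N] layout produced by load_or_create_multimodal_batch.
def _example_normalize() -> AttrDict:
    points = torch.ones(1, 6, 5)  # one asset, XYZRGB channels, 5 points
    batch = normalize_input_batch(
        AttrDict(points=points), pc_scale=2.0, color_scale=1.0 / 255.0
    )
    # batch.points[:, :3] is now all 2.0; batch.points[:, 3:] is all 1/255.
    return batch
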
def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
    """Center-crop and resize a rendered depth map, then drop its trailing
    singleton channel dimension."""
depth_img = center_crop(depth_img)
depth_img = resize(depth_img, width=image_size, height=image_size)
return np.squeeze(depth_img)
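

# Sketch: a synthetic single-channel depth map squeezed to a square of side
# image_size; a stand-in for the "D" channel loaded from a multiview.
def _example_process_depth() -> np.ndarray:
    depth = np.ones((64, 48, 1), dtype=np.float32)
    return process_depth(depth, image_size=32)  # -> shape (32, 32)
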
def process_image(
img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int
):
    """Center-crop the image, composite away its alpha channel according to
    alpha_removal, and resize both the RGB image and the alpha to
    image_size; returns the (image, alpha) pair."""
if isinstance(img_or_img_arr, np.ndarray):
img = Image.fromarray(img_or_img_arr)
img_arr = img_or_img_arr
else:
img = img_or_img_arr
img_arr = np.array(img)
if len(img_arr.shape) == 2:
# Grayscale
rgb = Image.new("RGB", img.size)
rgb.paste(img)
img = rgb
img_arr = np.array(img)
img = center_crop(img)
alpha = get_alpha(img)
img = remove_alpha(img, mode=alpha_removal)
alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR)
img = img.resize((image_size,) * 2, resample=Image.BILINEAR)
return img, alpha
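

# Sketch: run a random RGBA array (a stand-in for a rendered view) through
# the same crop / alpha-removal / resize path used for multiview frames above.
def _example_process_image() -> None:
    rgba = np.random.randint(0, 255, size=(64, 48, 4), dtype=np.uint8)
    img, alpha = process_image(rgba, alpha_removal="black", image_size=32)
    print(img.size, alpha.size)  # both (32, 32) after center crop + resize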