You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
248 lines
8.1 KiB
248 lines
8.1 KiB
import copy
|
|
import inspect
|
|
from typing import Any, Callable, List, Sequence, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from pytorch3d.renderer import (
|
|
BlendParams,
|
|
DirectionalLights,
|
|
FoVPerspectiveCameras,
|
|
MeshRasterizer,
|
|
MeshRenderer,
|
|
RasterizationSettings,
|
|
SoftPhongShader,
|
|
TexturesVertex,
|
|
)
|
|
from pytorch3d.renderer.utils import TensorProperties
|
|
from pytorch3d.structures import Meshes
|
|
|
|
from shap_e.models.nn.checkpoint import checkpoint
|
|
|
|
from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION
|
|
from .torch_mesh import TorchMesh
|
|
from .view_data import ProjectiveCamera
|
|
|
|
# Default blending parameters; render_images uses them as the defaults for
# its `sigma` and `gamma` arguments.
#
# Using a lower value like 1e-4 seems to result in weird issues
# for our high-poly meshes.
# NOTE(review): 1e-4 is numerically larger than the 1e-5 chosen below, so
# "lower" above presumably refers to something other than magnitude (or is a
# slip for "larger") -- confirm the intended wording.
DEFAULT_RENDER_SIGMA = 1e-5

DEFAULT_RENDER_GAMMA = 1e-4
|
|
|
|
|
|
def render_images(
    image_size: int,
    meshes: Meshes,
    cameras: Any,
    lights: Any,
    sigma: float = DEFAULT_RENDER_SIGMA,
    gamma: float = DEFAULT_RENDER_GAMMA,
    max_faces_per_bin=100000,
    faces_per_pixel=50,
    bin_size=None,
    use_checkpoint: bool = False,
) -> torch.Tensor:
    """
    Render a batch of meshes with a soft rasterizer and a soft Phong shader.

    :param image_size: forwarded to RasterizationSettings as `image_size`.
    :param meshes: the meshes to render; when `use_checkpoint` is set their
                   textures must be a TexturesVertex.
    :param cameras: cameras handed to both the rasterizer and the shader.
    :param lights: lights handed to the shader; when `use_checkpoint` is set
                   this must be a BidirectionalLights.
    :param sigma: blending sigma; also scales the rasterizer blur radius.
    :param gamma: blending gamma.
    :param max_faces_per_bin: forwarded to RasterizationSettings.
    :param faces_per_pixel: forwarded to RasterizationSettings.
    :param bin_size: forwarded to RasterizationSettings.
    :param use_checkpoint: if true, run the render through `checkpoint` with
                           every tensor input passed explicitly.
    :return: whatever the MeshRenderer returns for `meshes` (annotated as a
             tensor; layout is not established here).
    """
    if not use_checkpoint:
        soft_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
            faces_per_pixel=faces_per_pixel,
            max_faces_per_bin=max_faces_per_bin,
            bin_size=bin_size,
            perspective_correct=False,
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=soft_settings)
        shader = SoftPhongShader(
            device=meshes.device,
            cameras=cameras,
            lights=lights,
            blend_params=BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0)),
        )
        return MeshRenderer(rasterizer=rasterizer, shader=shader)(meshes)

    # Flatten every argument into plain tensors so that autograd can see
    # exactly which tensors the checkpointed op depends on.
    mesh_verts = meshes.verts_list()
    mesh_faces = meshes.faces_list()
    assert isinstance(meshes.textures, TexturesVertex)
    assert isinstance(lights, BidirectionalLights)
    vertex_colors = meshes.textures.verts_features_padded()
    light_tensors, rebuild_lights = _deconstruct_tensor_props(lights)
    camera_tensors, rebuild_cameras = _deconstruct_tensor_props(cameras)

    # Offsets into the flat argument tuple: [verts | lights | cameras | textures].
    verts_end = len(mesh_verts)
    lights_end = verts_end + len(light_tensors)
    cameras_end = lights_end + len(camera_tensors)

    def render_from_flat(*flat: torch.Tensor):
        # Rebuild the structured inputs from the flat tensors, then render
        # without checkpointing to avoid recursing back into this branch.
        rebuilt_meshes = Meshes(
            verts=list(flat[:verts_end]),
            faces=mesh_faces,
            textures=TexturesVertex(flat[cameras_end]),
        )
        rebuilt_lights = rebuild_lights(list(flat[verts_end:lights_end]))
        rebuilt_cameras = rebuild_cameras(list(flat[lights_end:cameras_end]))
        return render_images(
            image_size=image_size,
            meshes=rebuilt_meshes,
            cameras=rebuilt_cameras,
            lights=rebuilt_lights,
            sigma=sigma,
            gamma=gamma,
            max_faces_per_bin=max_faces_per_bin,
            faces_per_pixel=faces_per_pixel,
            bin_size=bin_size,
            use_checkpoint=False,
        )

    flat_inputs = (*mesh_verts, *light_tensors, *camera_tensors, vertex_colors)
    return checkpoint(render_from_flat, flat_inputs, (), True)
|
|
|
|
|
|
def _deconstruct_tensor_props(
|
|
props: TensorProperties,
|
|
) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], TensorProperties]]:
|
|
vecs = []
|
|
names = []
|
|
other_props = {}
|
|
for k in dir(props):
|
|
if k.startswith("__"):
|
|
continue
|
|
v = getattr(props, k)
|
|
if inspect.ismethod(v):
|
|
continue
|
|
if torch.is_tensor(v):
|
|
vecs.append(v)
|
|
names.append(k)
|
|
else:
|
|
other_props[k] = v
|
|
|
|
def recreate_fn(vecs_arg):
|
|
other = type(props)(device=props.device)
|
|
for k, v in other_props.items():
|
|
setattr(other, k, copy.deepcopy(v))
|
|
for name, vec in zip(names, vecs_arg):
|
|
setattr(other, name, vec)
|
|
return other
|
|
|
|
return vecs, recreate_fn
|
|
|
|
|
|
|
|
def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes:
    """
    Convert a sequence of TorchMesh objects into a single textured Meshes batch.

    :param raw_meshes: the meshes to convert; each one supplies `verts`,
                       `faces` and (optionally) `vertex_channels`.
    :param default_brightness: constant value used for all three color
                               channels of meshes that do not provide a full
                               set of "R", "G" and "B" vertex channels.
    :return: a Meshes object whose textures are a TexturesVertex holding one
             per-vertex RGB tensor per input mesh.
    """
    meshes = Meshes(
        verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes]
    )
    rgbs = []
    for mesh in raw_meshes:
        if mesh.vertex_channels and all(k in mesh.vertex_channels for k in "RGB"):
            rgbs.append(torch.stack([mesh.vertex_channels[k] for k in "RGB"], dim=-1))
        else:
            # Scale the values, not the size: multiplying the vertex count by
            # the brightness yields a float, which torch.ones rejects as a size.
            flat_color = torch.ones(
                len(mesh.verts),
                3,
                device=mesh.verts.device,
                dtype=mesh.verts.dtype,
            )
            rgbs.append(flat_color * default_brightness)
    meshes.textures = TexturesVertex(verts_features=rgbs)
    return meshes
|
|
|
|
|
|
def convert_cameras(
    cameras: Sequence[ProjectiveCamera], device: torch.device
) -> FoVPerspectiveCameras:
    """
    Build one batched FoVPerspectiveCameras from a sequence of ProjectiveCamera.

    Every camera must have a square viewport (equal width/height and equal
    x/y field-of-view), and all cameras must share the same field-of-view;
    both conditions are checked with assertions.

    :param cameras: the source cameras, each exposing `x`, `y`, `z`, `origin`,
                    `width`, `height`, `x_fov` and `y_fov`.
    :param device: the device passed on to FoVPerspectiveCameras.
    :return: cameras with stacked rotations and translations; the
             field-of-view is passed with `degrees=False`.
    """
    rotations = []
    translations = []
    for cam in cameras:
        assert cam.width == cam.height, "viewports must be square"
        assert cam.x_fov == cam.y_fov, "viewports must be square"
        assert cam.x_fov == cameras[0].x_fov, "all cameras must have same field-of-view"
        # Rows are the camera axes with x and y negated; the rotation is the
        # transpose, so the translation (-R^T @ origin) uses the rows directly.
        axes = np.stack([-cam.x, -cam.y, cam.z], axis=0)
        rotations.append(axes.T)
        translations.append(-axes @ cam.origin)
    return FoVPerspectiveCameras(
        R=np.stack(rotations, axis=0),
        T=np.stack(translations, axis=0),
        fov=cameras[0].x_fov,
        degrees=False,
        device=device,
    )
|
|
|
|
|
|
def convert_cameras_torch(
    origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float
) -> FoVPerspectiveCameras:
    """
    Tensor-based counterpart of convert_cameras.

    The four tensors are iterated in lockstep along their first dimension,
    producing one rotation and one translation per camera.

    :param origins: camera origins, one per entry along the first dimension.
    :param xs: camera x axes, iterated alongside `origins`.
    :param ys: camera y axes, iterated alongside `origins`.
    :param zs: camera z axes, iterated alongside `origins`.
    :param fov: shared field-of-view, passed with `degrees=False`.
    :return: cameras created on the device of `origins`.
    """
    rotations = []
    translations = []
    for origin, x_axis, y_axis, z_axis in zip(origins, xs, ys, zs):
        # Rows are the camera axes with x and y negated; the rotation is the
        # transpose, so the translation (-R^T @ origin) uses the rows directly.
        axes = torch.stack([-x_axis, -y_axis, z_axis], dim=0)
        rotations.append(axes.T)
        translations.append(-axes @ origin)
    return FoVPerspectiveCameras(
        R=torch.stack(rotations, dim=0),
        T=torch.stack(translations, dim=0),
        fov=fov,
        degrees=False,
        device=origins.device,
    )
|
|
|
|
|
|
def blender_uniform_lights(
    batch_size: int,
    device: torch.device,
    ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
    diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
    specular_color: Union[float, Tuple[float]] = 0.0,
) -> "BidirectionalLights":
    """
    Create a light that attempts to match the light used by the Blender
    renderer when run with `--light_mode basic`.

    A color given as a single float is expanded to an RGB triple; every color
    and the light direction are then repeated once per batch element.
    """

    def as_rgb(color):
        # Only float scalars are broadcast; anything else is passed through.
        return (color,) * 3 if isinstance(color, float) else color

    def per_batch(value):
        return (value,) * batch_size

    return BidirectionalLights(
        ambient_color=per_batch(as_rgb(ambient_color)),
        diffuse_color=per_batch(as_rgb(diffuse_color)),
        specular_color=per_batch(as_rgb(specular_color)),
        direction=per_batch(UNIFORM_LIGHT_DIRECTION),
        device=device,
    )
|
|
|
|
|
|
class BidirectionalLights(DirectionalLights):
    """
    Directional lights that effectively shine in both the positive and the
    negative direction. Adapted from:
    https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159
    """

    def diffuse(self, normals, points=None) -> torch.Tensor:
        # Evaluate the parent's term for the normal and its negation, and keep
        # the element-wise larger of the two.
        facing = super().diffuse(normals, points=points)
        opposing = super().diffuse(-normals, points=points)
        return torch.maximum(facing, opposing)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        # Same two-sided treatment as diffuse, applied to the specular term.
        facing = super().specular(normals, points, camera_position, shininess)
        opposing = super().specular(-normals, points, camera_position, shininess)
        return torch.maximum(facing, opposing)
|
|
|