You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
2.0 KiB
58 lines
2.0 KiB
2 years ago
|
from typing import Optional, Sequence
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from shap_e.rendering.blender.constants import (
|
||
|
BASIC_AMBIENT_COLOR,
|
||
|
BASIC_DIFFUSE_COLOR,
|
||
|
UNIFORM_LIGHT_DIRECTION,
|
||
|
)
|
||
|
from shap_e.rendering.view_data import ProjectiveCamera
|
||
|
|
||
|
from .cast import cast_camera
|
||
|
from .types import RayCollisions, TriMesh
|
||
|
|
||
|
|
||
|
def render_diffuse_mesh(
    camera: ProjectiveCamera,
    mesh: TriMesh,
    light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION),
    diffuse: float = BASIC_DIFFUSE_COLOR,
    ambient: float = BASIC_AMBIENT_COLOR,
    ray_batch_size: Optional[int] = None,
    checkpoint: Optional[bool] = None,
) -> torch.Tensor:
    """
    Return an [H x W x 4] RGBA tensor of the rendered image.

    The pixels are floating points, with alpha in the range [0, 1] and the
    other colors matching the scale used by the mesh's vertex colors.

    If the mesh has no vertex colors, every vertex is shaded with a uniform
    light gray of (0.8, 0.8, 0.8).

    :param camera: the camera to cast rays from; its height and width give
                   the spatial size of the returned image.
    :param mesh: the triangle mesh to render.
    :param light_direction: direction of the light, converted to a tensor on
                            the same device and dtype as the mesh vertices.
    :param diffuse: weight of the direction-dependent lighting term.
    :param ambient: constant lighting term added to every pixel that is hit.
    :param ray_batch_size: forwarded unchanged to cast_camera().
    :param checkpoint: forwarded unchanged to cast_camera().
    :return: an RGBA image tensor; pixels whose ray missed the mesh are all
             zeros (including alpha).
    """
    light_direction = torch.tensor(
        light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype
    )

    all_collisions = RayCollisions.collect(
        cast_camera(
            camera=camera,
            mesh=mesh,
            ray_batch_size=ray_batch_size,
            checkpoint=checkpoint,
        )
    )

    if mesh.vertex_colors is None:
        # One default color per *vertex* (not per ray): the table is indexed
        # below with vertex indices taken from mesh.faces.
        vertex_colors = (
            torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(len(mesh.vertices), 1)
        )
    else:
        vertex_colors = mesh.vertex_colors

    # The absolute value of the normal/light dot product lights a surface
    # the same way regardless of which side the normal points to.
    light_coeffs = ambient + (
        diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs()
    )

    # Gather the colors of the corners of the triangle hit by each ray and
    # blend them with the barycentric weights of the hit point.
    corner_colors = vertex_colors[mesh.faces[all_collisions.tri_indices]]
    bary_products = torch.sum(corner_colors * all_collisions.barycentric[..., None], dim=-2)
    out_colors = bary_products * light_coeffs[..., None]

    # Rays that hit nothing get zero color; the hit mask doubles as alpha.
    res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors))
    return torch.cat([res, all_collisions.collides[:, None].float()], dim=-1).view(
        camera.height, camera.width, 4
    )
|