|
@ -256,6 +256,16 @@ def render_views_from_stf( |
|
|
if output_srgb: |
|
|
if output_srgb: |
|
|
tf_out.channels = _convert_srgb_to_linear(tf_out.channels) |
|
|
tf_out.channels = _convert_srgb_to_linear(tf_out.channels) |
|
|
|
|
|
|
|
|
|
|
|
# Make sure the raw meshes have colors. |
|
|
|
|
|
with torch.autocast(device_type, enabled=False): |
|
|
|
|
|
textures = tf_out.channels.float() |
|
|
|
|
|
assert len(textures.shape) == 3 and textures.shape[-1] == len( |
|
|
|
|
|
texture_channels |
|
|
|
|
|
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" |
|
|
|
|
|
for m, texture in zip(raw_meshes, textures): |
|
|
|
|
|
texture = texture[: len(m.verts)] |
|
|
|
|
|
m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))} |
|
|
|
|
|
|
|
|
args = dict( |
|
|
args = dict( |
|
|
options=options, |
|
|
options=options, |
|
|
texture_channels=texture_channels, |
|
|
texture_channels=texture_channels, |
|
@ -315,6 +325,8 @@ def _render_with_pytorch3d( |
|
|
raw_meshes: List[TorchMesh], |
|
|
raw_meshes: List[TorchMesh], |
|
|
tf_out: AttrDict, |
|
|
tf_out: AttrDict, |
|
|
): |
|
|
): |
|
|
|
|
|
_ = tf_out |
|
|
|
|
|
|
|
|
# Lazy import because pytorch3d is installed lazily. |
|
|
# Lazy import because pytorch3d is installed lazily. |
|
|
from shap_e.rendering.pytorch3d_util import ( |
|
|
from shap_e.rendering.pytorch3d_util import ( |
|
|
blender_uniform_lights, |
|
|
blender_uniform_lights, |
|
@ -328,14 +340,6 @@ def _render_with_pytorch3d( |
|
|
device_type = device.type |
|
|
device_type = device.type |
|
|
|
|
|
|
|
|
with torch.autocast(device_type, enabled=False): |
|
|
with torch.autocast(device_type, enabled=False): |
|
|
textures = tf_out.channels.float() |
|
|
|
|
|
assert len(textures.shape) == 3 and textures.shape[-1] == len( |
|
|
|
|
|
texture_channels |
|
|
|
|
|
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" |
|
|
|
|
|
for m, texture in zip(raw_meshes, textures): |
|
|
|
|
|
texture = texture[: len(m.verts)] |
|
|
|
|
|
m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))} |
|
|
|
|
|
|
|
|
|
|
|
meshes = convert_meshes(raw_meshes) |
|
|
meshes = convert_meshes(raw_meshes) |
|
|
|
|
|
|
|
|
lights = blender_uniform_lights( |
|
|
lights = blender_uniform_lights( |
|
|