You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
290 lines
10 KiB
290 lines
10 KiB
2 years ago
|
from functools import partial
|
||
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from shap_e.models.nerf.model import NeRFModel
|
||
|
from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
|
||
|
from shap_e.models.nn.meta import subdict
|
||
|
from shap_e.models.nn.utils import to_torch
|
||
|
from shap_e.models.query import Query
|
||
|
from shap_e.models.renderer import RayRenderer, render_views_from_rays
|
||
|
from shap_e.models.stf.base import Model
|
||
|
from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf
|
||
|
from shap_e.models.volume import BoundingBoxVolume, Volume
|
||
|
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
|
||
|
from shap_e.util.collections import AttrDict
|
||
|
|
||
|
|
||
|
class NeRSTFRenderer(RayRenderer, STFRendererBase):
|
||
|
def __init__(
|
||
|
self,
|
||
|
sdf: Optional[Model],
|
||
|
tf: Optional[Model],
|
||
|
nerstf: Optional[Model],
|
||
|
void: NeRFModel,
|
||
|
volume: Volume,
|
||
|
grid_size: int,
|
||
|
n_coarse_samples: int,
|
||
|
n_fine_samples: int,
|
||
|
importance_sampling_options: Optional[Dict[str, Any]] = None,
|
||
|
separate_shared_samples: bool = False,
|
||
|
texture_channels: Sequence[str] = ("R", "G", "B"),
|
||
|
channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
|
||
|
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
|
||
|
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
|
||
|
specular_color: Union[float, Tuple[float]] = 0.0,
|
||
|
output_srgb: bool = True,
|
||
|
device: torch.device = torch.device("cuda"),
|
||
|
**kwargs,
|
||
|
):
|
||
|
super().__init__(**kwargs)
|
||
|
assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
|
||
|
assert (nerstf is not None) ^ (sdf is not None and tf is not None)
|
||
|
self.sdf = sdf
|
||
|
self.tf = tf
|
||
|
self.nerstf = nerstf
|
||
|
self.void = void
|
||
|
self.volume = volume
|
||
|
self.grid_size = grid_size
|
||
|
self.n_coarse_samples = n_coarse_samples
|
||
|
self.n_fine_samples = n_fine_samples
|
||
|
self.importance_sampling_options = AttrDict(importance_sampling_options or {})
|
||
|
self.separate_shared_samples = separate_shared_samples
|
||
|
self.texture_channels = texture_channels
|
||
|
self.channel_scale = to_torch(channel_scale).to(device)
|
||
|
self.ambient_color = ambient_color
|
||
|
self.diffuse_color = diffuse_color
|
||
|
self.specular_color = specular_color
|
||
|
self.output_srgb = output_srgb
|
||
|
self.device = device
|
||
|
self.to(device)
|
||
|
|
||
|
def _query(
|
||
|
self,
|
||
|
query: Query,
|
||
|
params: AttrDict[str, torch.Tensor],
|
||
|
options: AttrDict[str, Any],
|
||
|
) -> AttrDict:
|
||
|
no_dir_query = query.copy()
|
||
|
no_dir_query.direction = None
|
||
|
|
||
|
if options.get("rendering_mode", "stf") == "stf":
|
||
|
assert query.direction is None
|
||
|
|
||
|
if self.nerstf is not None:
|
||
|
sdf = tf = self.nerstf(
|
||
|
query,
|
||
|
params=subdict(params, "nerstf"),
|
||
|
options=options,
|
||
|
)
|
||
|
else:
|
||
|
sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options)
|
||
|
tf = self.tf(query, params=subdict(params, "tf"), options=options)
|
||
|
|
||
|
return AttrDict(
|
||
|
density=sdf.density,
|
||
|
signed_distance=sdf.signed_distance,
|
||
|
channels=tf.channels,
|
||
|
aux_losses=dict(),
|
||
|
)
|
||
|
|
||
|
def render_rays(
|
||
|
self,
|
||
|
batch: AttrDict,
|
||
|
params: Optional[Dict] = None,
|
||
|
options: Optional[AttrDict] = None,
|
||
|
) -> AttrDict:
|
||
|
"""
|
||
|
:param batch: has
|
||
|
|
||
|
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
|
||
|
:param options: Optional[Dict]
|
||
|
"""
|
||
|
params = self.update(params)
|
||
|
options = AttrDict() if options is None else AttrDict(options)
|
||
|
|
||
|
# Necessary to tell the TF to use specific NeRF channels.
|
||
|
options.rendering_mode = "nerf"
|
||
|
|
||
|
model = partial(self._query, params=params, options=options)
|
||
|
|
||
|
# First, render rays with coarse, stratified samples.
|
||
|
options.nerf_level = "coarse"
|
||
|
parts = [
|
||
|
RayVolumeIntegral(
|
||
|
model=model,
|
||
|
volume=self.volume,
|
||
|
sampler=StratifiedRaySampler(),
|
||
|
n_samples=self.n_coarse_samples,
|
||
|
),
|
||
|
]
|
||
|
coarse_results, samplers, coarse_raw_outputs = render_rays(
|
||
|
batch.rays,
|
||
|
parts,
|
||
|
self.void,
|
||
|
shared=not self.separate_shared_samples,
|
||
|
render_with_direction=options.render_with_direction,
|
||
|
importance_sampling_options=self.importance_sampling_options,
|
||
|
)
|
||
|
|
||
|
# Then, render with additional importance-weighted ray samples.
|
||
|
options.nerf_level = "fine"
|
||
|
parts = [
|
||
|
RayVolumeIntegral(
|
||
|
model=model,
|
||
|
volume=self.volume,
|
||
|
sampler=samplers[0],
|
||
|
n_samples=self.n_fine_samples,
|
||
|
),
|
||
|
]
|
||
|
fine_results, _, raw_outputs = render_rays(
|
||
|
batch.rays,
|
||
|
parts,
|
||
|
self.void,
|
||
|
shared=not self.separate_shared_samples,
|
||
|
prev_raw_outputs=coarse_raw_outputs,
|
||
|
render_with_direction=options.render_with_direction,
|
||
|
)
|
||
|
raw = raw_outputs[0]
|
||
|
|
||
|
aux_losses = fine_results.output.aux_losses.copy()
|
||
|
if self.separate_shared_samples:
|
||
|
for key, val in coarse_results.output.aux_losses.items():
|
||
|
aux_losses[key + "_coarse"] = val
|
||
|
|
||
|
channels = fine_results.output.channels
|
||
|
shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)]
|
||
|
channels = channels * self.channel_scale.view(*shape)
|
||
|
|
||
|
res = AttrDict(
|
||
|
channels=channels,
|
||
|
transmittance=fine_results.transmittance,
|
||
|
raw_signed_distance=raw.signed_distance,
|
||
|
raw_density=raw.density,
|
||
|
distances=fine_results.output.distances,
|
||
|
t0=fine_results.volume_range.t0,
|
||
|
t1=fine_results.volume_range.t1,
|
||
|
intersected=fine_results.volume_range.intersected,
|
||
|
aux_losses=aux_losses,
|
||
|
)
|
||
|
|
||
|
if self.separate_shared_samples:
|
||
|
res.update(
|
||
|
dict(
|
||
|
channels_coarse=(
|
||
|
coarse_results.output.channels * self.channel_scale.view(*shape)
|
||
|
),
|
||
|
distances_coarse=coarse_results.output.distances,
|
||
|
transmittance_coarse=coarse_results.transmittance,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
return res
|
||
|
|
||
|
def render_views(
|
||
|
self,
|
||
|
batch: AttrDict,
|
||
|
params: Optional[Dict] = None,
|
||
|
options: Optional[AttrDict] = None,
|
||
|
) -> AttrDict:
|
||
|
"""
|
||
|
Returns a backproppable rendering of a view
|
||
|
|
||
|
:param batch: contains either ["poses", "camera"], or ["cameras"]. Can
|
||
|
optionally contain any of ["height", "width", "query_batch_size"]
|
||
|
|
||
|
:param params: Meta parameters
|
||
|
contains rendering_mode in ["stf", "nerf"]
|
||
|
:param options: controls checkpointing, caching, and rendering.
|
||
|
Can provide a `rendering_mode` in ["stf", "nerf"]
|
||
|
"""
|
||
|
params = self.update(params)
|
||
|
options = AttrDict() if options is None else AttrDict(options)
|
||
|
|
||
|
if options.cache is None:
|
||
|
created_cache = True
|
||
|
options.cache = AttrDict()
|
||
|
else:
|
||
|
created_cache = False
|
||
|
|
||
|
rendering_mode = options.get("rendering_mode", "stf")
|
||
|
|
||
|
if rendering_mode == "nerf":
|
||
|
|
||
|
output = render_views_from_rays(
|
||
|
self.render_rays,
|
||
|
batch,
|
||
|
params=params,
|
||
|
options=options,
|
||
|
device=self.device,
|
||
|
)
|
||
|
|
||
|
elif rendering_mode == "stf":
|
||
|
|
||
|
sdf_fn = tf_fn = nerstf_fn = None
|
||
|
if self.nerstf is not None:
|
||
|
nerstf_fn = partial(
|
||
|
self.nerstf.forward_batched,
|
||
|
params=subdict(params, "nerstf"),
|
||
|
options=options,
|
||
|
)
|
||
|
else:
|
||
|
sdf_fn = partial(
|
||
|
self.sdf.forward_batched,
|
||
|
params=subdict(params, "sdf"),
|
||
|
options=options,
|
||
|
)
|
||
|
tf_fn = partial(
|
||
|
self.tf.forward_batched,
|
||
|
params=subdict(params, "tf"),
|
||
|
options=options,
|
||
|
)
|
||
|
output = render_views_from_stf(
|
||
|
batch,
|
||
|
options,
|
||
|
sdf_fn=sdf_fn,
|
||
|
tf_fn=tf_fn,
|
||
|
nerstf_fn=nerstf_fn,
|
||
|
volume=self.volume,
|
||
|
grid_size=self.grid_size,
|
||
|
channel_scale=self.channel_scale,
|
||
|
texture_channels=self.texture_channels,
|
||
|
ambient_color=self.ambient_color,
|
||
|
diffuse_color=self.diffuse_color,
|
||
|
specular_color=self.specular_color,
|
||
|
output_srgb=self.output_srgb,
|
||
|
device=self.device,
|
||
|
)
|
||
|
|
||
|
else:
|
||
|
|
||
|
raise NotImplementedError
|
||
|
|
||
|
if created_cache:
|
||
|
del options["cache"]
|
||
|
|
||
|
return output
|
||
|
|
||
|
def get_signed_distance(
|
||
|
self,
|
||
|
query: Query,
|
||
|
params: Dict[str, torch.Tensor],
|
||
|
options: AttrDict[str, Any],
|
||
|
) -> torch.Tensor:
|
||
|
if self.sdf is not None:
|
||
|
return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance
|
||
|
assert self.nerstf is not None
|
||
|
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance
|
||
|
|
||
|
def get_texture(
|
||
|
self,
|
||
|
query: Query,
|
||
|
params: Dict[str, torch.Tensor],
|
||
|
options: AttrDict[str, Any],
|
||
|
) -> torch.Tensor:
|
||
|
if self.tf is not None:
|
||
|
return self.tf(query, params=subdict(params, "tf"), options=options).channels
|
||
|
assert self.nerstf is not None
|
||
|
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels
|