a fork of shap-e for gc
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

289 lines
10 KiB

from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from shap_e.models.nerf.model import NeRFModel
from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
from shap_e.models.nn.meta import subdict
from shap_e.models.nn.utils import to_torch
from shap_e.models.query import Query
from shap_e.models.renderer import RayRenderer, render_views_from_rays
from shap_e.models.stf.base import Model
from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
from shap_e.util.collections import AttrDict
class NeRSTFRenderer(RayRenderer, STFRendererBase):
def __init__(
self,
sdf: Optional[Model],
tf: Optional[Model],
nerstf: Optional[Model],
void: NeRFModel,
volume: Volume,
grid_size: int,
n_coarse_samples: int,
n_fine_samples: int,
importance_sampling_options: Optional[Dict[str, Any]] = None,
separate_shared_samples: bool = False,
texture_channels: Sequence[str] = ("R", "G", "B"),
channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
specular_color: Union[float, Tuple[float]] = 0.0,
output_srgb: bool = True,
device: torch.device = torch.device("cuda"),
**kwargs,
):
super().__init__(**kwargs)
assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
assert (nerstf is not None) ^ (sdf is not None and tf is not None)
self.sdf = sdf
self.tf = tf
self.nerstf = nerstf
self.void = void
self.volume = volume
self.grid_size = grid_size
self.n_coarse_samples = n_coarse_samples
self.n_fine_samples = n_fine_samples
self.importance_sampling_options = AttrDict(importance_sampling_options or {})
self.separate_shared_samples = separate_shared_samples
self.texture_channels = texture_channels
self.channel_scale = to_torch(channel_scale).to(device)
self.ambient_color = ambient_color
self.diffuse_color = diffuse_color
self.specular_color = specular_color
self.output_srgb = output_srgb
self.device = device
self.to(device)
def _query(
self,
query: Query,
params: AttrDict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> AttrDict:
no_dir_query = query.copy()
no_dir_query.direction = None
if options.get("rendering_mode", "stf") == "stf":
assert query.direction is None
if self.nerstf is not None:
sdf = tf = self.nerstf(
query,
params=subdict(params, "nerstf"),
options=options,
)
else:
sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options)
tf = self.tf(query, params=subdict(params, "tf"), options=options)
return AttrDict(
density=sdf.density,
signed_distance=sdf.signed_distance,
channels=tf.channels,
aux_losses=dict(),
)
def render_rays(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[AttrDict] = None,
) -> AttrDict:
"""
:param batch: has
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
:param options: Optional[Dict]
"""
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
# Necessary to tell the TF to use specific NeRF channels.
options.rendering_mode = "nerf"
model = partial(self._query, params=params, options=options)
# First, render rays with coarse, stratified samples.
options.nerf_level = "coarse"
parts = [
RayVolumeIntegral(
model=model,
volume=self.volume,
sampler=StratifiedRaySampler(),
n_samples=self.n_coarse_samples,
),
]
coarse_results, samplers, coarse_raw_outputs = render_rays(
batch.rays,
parts,
self.void,
shared=not self.separate_shared_samples,
render_with_direction=options.render_with_direction,
importance_sampling_options=self.importance_sampling_options,
)
# Then, render with additional importance-weighted ray samples.
options.nerf_level = "fine"
parts = [
RayVolumeIntegral(
model=model,
volume=self.volume,
sampler=samplers[0],
n_samples=self.n_fine_samples,
),
]
fine_results, _, raw_outputs = render_rays(
batch.rays,
parts,
self.void,
shared=not self.separate_shared_samples,
prev_raw_outputs=coarse_raw_outputs,
render_with_direction=options.render_with_direction,
)
raw = raw_outputs[0]
aux_losses = fine_results.output.aux_losses.copy()
if self.separate_shared_samples:
for key, val in coarse_results.output.aux_losses.items():
aux_losses[key + "_coarse"] = val
channels = fine_results.output.channels
shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)]
channels = channels * self.channel_scale.view(*shape)
res = AttrDict(
channels=channels,
transmittance=fine_results.transmittance,
raw_signed_distance=raw.signed_distance,
raw_density=raw.density,
distances=fine_results.output.distances,
t0=fine_results.volume_range.t0,
t1=fine_results.volume_range.t1,
intersected=fine_results.volume_range.intersected,
aux_losses=aux_losses,
)
if self.separate_shared_samples:
res.update(
dict(
channels_coarse=(
coarse_results.output.channels * self.channel_scale.view(*shape)
),
distances_coarse=coarse_results.output.distances,
transmittance_coarse=coarse_results.transmittance,
)
)
return res
def render_views(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[AttrDict] = None,
) -> AttrDict:
"""
Returns a backproppable rendering of a view
:param batch: contains either ["poses", "camera"], or ["cameras"]. Can
optionally contain any of ["height", "width", "query_batch_size"]
:param params: Meta parameters
contains rendering_mode in ["stf", "nerf"]
:param options: controls checkpointing, caching, and rendering.
Can provide a `rendering_mode` in ["stf", "nerf"]
"""
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
if options.cache is None:
created_cache = True
options.cache = AttrDict()
else:
created_cache = False
rendering_mode = options.get("rendering_mode", "stf")
if rendering_mode == "nerf":
output = render_views_from_rays(
self.render_rays,
batch,
params=params,
options=options,
device=self.device,
)
elif rendering_mode == "stf":
sdf_fn = tf_fn = nerstf_fn = None
if self.nerstf is not None:
nerstf_fn = partial(
self.nerstf.forward_batched,
params=subdict(params, "nerstf"),
options=options,
)
else:
sdf_fn = partial(
self.sdf.forward_batched,
params=subdict(params, "sdf"),
options=options,
)
tf_fn = partial(
self.tf.forward_batched,
params=subdict(params, "tf"),
options=options,
)
output = render_views_from_stf(
batch,
options,
sdf_fn=sdf_fn,
tf_fn=tf_fn,
nerstf_fn=nerstf_fn,
volume=self.volume,
grid_size=self.grid_size,
channel_scale=self.channel_scale,
texture_channels=self.texture_channels,
ambient_color=self.ambient_color,
diffuse_color=self.diffuse_color,
specular_color=self.specular_color,
output_srgb=self.output_srgb,
device=self.device,
)
else:
raise NotImplementedError
if created_cache:
del options["cache"]
return output
def get_signed_distance(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
if self.sdf is not None:
return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance
assert self.nerstf is not None
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance
def get_texture(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
if self.tf is not None:
return self.tf(query, params=subdict(params, "tf"), options=options).channels
assert self.nerstf is not None
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels