You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
254 lines
8.8 KiB
254 lines
8.8 KiB
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from shap_e.models.nn.meta import MetaModule
|
|
from shap_e.models.nn.utils import ArrayType, safe_divide, to_torch
|
|
|
|
|
|
@dataclass
|
|
class VolumeRange:
|
|
t0: torch.Tensor
|
|
t1: torch.Tensor
|
|
intersected: torch.Tensor
|
|
|
|
def __post_init__(self):
|
|
assert self.t0.shape == self.t1.shape == self.intersected.shape
|
|
|
|
def next_t0(self):
|
|
"""
|
|
Given convex volume1 and volume2, where volume1 is contained in
|
|
volume2, this function returns the t0 at which rays leave volume1 and
|
|
intersect with volume2 \\ volume1.
|
|
"""
|
|
return self.t1 * self.intersected.float()
|
|
|
|
def extend(self, another: "VolumeRange") -> "VolumeRange":
|
|
"""
|
|
The ranges at which rays intersect with either one, or both, or none of
|
|
the self and another are merged together.
|
|
"""
|
|
return VolumeRange(
|
|
t0=torch.where(self.intersected, self.t0, another.t0),
|
|
t1=torch.where(another.intersected, another.t1, self.t1),
|
|
intersected=torch.logical_or(self.intersected, another.intersected),
|
|
)
|
|
|
|
def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Partitions t0 and t1 into n_samples intervals.
|
|
|
|
:param ts: [batch_size, *shape, n_samples, 1]
|
|
:return: a tuple of (
|
|
lower: [batch_size, *shape, n_samples, 1]
|
|
upper: [batch_size, *shape, n_samples, 1]
|
|
delta: [batch_size, *shape, n_samples, 1]
|
|
) where
|
|
|
|
ts \\in [lower, upper]
|
|
deltas = upper - lower
|
|
"""
|
|
mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
|
|
lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
|
|
upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
|
|
delta = upper - lower
|
|
assert lower.shape == upper.shape == delta.shape == ts.shape
|
|
return lower, upper, delta
|
|
|
|
|
|
class Volume(ABC):
|
|
"""
|
|
An abstraction of rendering volume.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def intersect(
|
|
self,
|
|
origin: torch.Tensor,
|
|
direction: torch.Tensor,
|
|
t0_lower: Optional[torch.Tensor] = None,
|
|
params: Optional[Dict] = None,
|
|
epsilon: float = 1e-6,
|
|
) -> VolumeRange:
|
|
"""
|
|
:param origin: [batch_size, *shape, 3]
|
|
:param direction: [batch_size, *shape, 3]
|
|
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
|
|
:param params: Optional meta parameters in case Volume is parametric
|
|
:param epsilon: to stabilize calculations
|
|
|
|
:return: A tuple of (t0, t1, intersected) where each has a shape
|
|
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
|
|
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
|
|
to be on the boundary of the volume.
|
|
"""
|
|
|
|
|
|
class BoundingBoxVolume(MetaModule, Volume):
|
|
"""
|
|
Axis-aligned bounding box defined by the two opposite corners.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
bbox_min: ArrayType,
|
|
bbox_max: ArrayType,
|
|
min_dist: float = 0.0,
|
|
min_t_range: float = 1e-3,
|
|
device: torch.device = torch.device("cuda"),
|
|
):
|
|
"""
|
|
:param bbox_min: the left/bottommost corner of the bounding box
|
|
:param bbox_max: the other corner of the bounding box
|
|
:param min_dist: all rays should start at least this distance away from the origin.
|
|
"""
|
|
super().__init__()
|
|
|
|
self.bbox_min = to_torch(bbox_min).to(device)
|
|
self.bbox_max = to_torch(bbox_max).to(device)
|
|
self.min_dist = min_dist
|
|
self.min_t_range = min_t_range
|
|
self.bbox = torch.stack([self.bbox_min, self.bbox_max])
|
|
assert self.bbox.shape == (2, 3)
|
|
assert self.min_dist >= 0.0
|
|
assert self.min_t_range > 0.0
|
|
self.device = device
|
|
|
|
def intersect(
|
|
self,
|
|
origin: torch.Tensor,
|
|
direction: torch.Tensor,
|
|
t0_lower: Optional[torch.Tensor] = None,
|
|
params: Optional[Dict] = None,
|
|
epsilon=1e-6,
|
|
) -> VolumeRange:
|
|
"""
|
|
:param origin: [batch_size, *shape, 3]
|
|
:param direction: [batch_size, *shape, 3]
|
|
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
|
|
:param params: Optional meta parameters in case Volume is parametric
|
|
:param epsilon: to stabilize calculations
|
|
|
|
:return: A tuple of (t0, t1, intersected) where each has a shape
|
|
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
|
|
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
|
|
to be on the boundary of the volume.
|
|
"""
|
|
|
|
batch_size, *shape, _ = origin.shape
|
|
ones = [1] * len(shape)
|
|
bbox = self.bbox.view(1, *ones, 2, 3)
|
|
ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
|
|
|
|
# Cases to think about:
|
|
#
|
|
# 1. t1 <= t0: the ray does not pass through the AABB.
|
|
# 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
|
|
# 3. t0 <= 0 <= t1: the ray starts from inside the BB
|
|
# 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
|
|
#
|
|
# 1 and 4 are clearly handled from t0 < t1 below.
|
|
# Making t0 at least min_dist (>= 0) takes care of 2 and 3.
|
|
t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
|
|
t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
|
|
assert t0.shape == t1.shape == (batch_size, *shape, 1)
|
|
if t0_lower is not None:
|
|
assert t0.shape == t0_lower.shape
|
|
t0 = torch.maximum(t0, t0_lower)
|
|
|
|
intersected = t0 + self.min_t_range < t1
|
|
t0 = torch.where(intersected, t0, torch.zeros_like(t0))
|
|
t1 = torch.where(intersected, t1, torch.ones_like(t1))
|
|
|
|
return VolumeRange(t0=t0, t1=t1, intersected=intersected)
|
|
|
|
|
|
class UnboundedVolume(MetaModule, Volume):
|
|
"""
|
|
Originally used in NeRF. Unbounded volume but with a limited visibility
|
|
when rendering (e.g. objects that are farther away than the max_dist from
|
|
the ray origin are not considered)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
max_dist: float,
|
|
min_dist: float = 0.0,
|
|
min_t_range: float = 1e-3,
|
|
device: torch.device = torch.device("cuda"),
|
|
):
|
|
super().__init__()
|
|
self.max_dist = max_dist
|
|
self.min_dist = min_dist
|
|
self.min_t_range = min_t_range
|
|
assert self.min_dist >= 0.0
|
|
assert self.min_t_range > 0.0
|
|
self.device = device
|
|
|
|
def intersect(
|
|
self,
|
|
origin: torch.Tensor,
|
|
direction: torch.Tensor,
|
|
t0_lower: Optional[torch.Tensor] = None,
|
|
params: Optional[Dict] = None,
|
|
) -> VolumeRange:
|
|
"""
|
|
:param origin: [batch_size, *shape, 3]
|
|
:param direction: [batch_size, *shape, 3]
|
|
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
|
|
:param params: Optional meta parameters in case Volume is parametric
|
|
:param epsilon: to stabilize calculations
|
|
|
|
:return: A tuple of (t0, t1, intersected) where each has a shape
|
|
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
|
|
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
|
|
to be on the boundary of the volume.
|
|
"""
|
|
|
|
batch_size, *shape, _ = origin.shape
|
|
t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)
|
|
if t0_lower is not None:
|
|
t0 = torch.maximum(t0, t0_lower)
|
|
t1 = t0 + self.max_dist
|
|
t0 = t0.clamp(self.min_dist)
|
|
return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)
|
|
|
|
|
|
class SphericalVolume(MetaModule, Volume):
|
|
"""
|
|
Used in NeRF++ but will not be used probably unless we want to reproduce
|
|
their results.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
radius: float,
|
|
center: ArrayType = (0.0, 0.0, 0.0),
|
|
min_dist: float = 0.0,
|
|
min_t_range: float = 1e-3,
|
|
device: torch.device = torch.device("cuda"),
|
|
):
|
|
super().__init__()
|
|
|
|
self.radius = radius
|
|
self.center = to_torch(center).to(device)
|
|
self.min_dist = min_dist
|
|
self.min_t_range = min_t_range
|
|
assert self.min_dist >= 0.0
|
|
assert self.min_t_range > 0.0
|
|
self.device = device
|
|
|
|
def intersect(
|
|
self,
|
|
origin: torch.Tensor,
|
|
direction: torch.Tensor,
|
|
t0_lower: Optional[torch.Tensor] = None,
|
|
params: Optional[Dict] = None,
|
|
epsilon=1e-6,
|
|
) -> VolumeRange:
|
|
raise NotImplementedError
|
|
|