You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
893 B
31 lines
893 B
2 years ago
|
from dataclasses import dataclass
|
||
|
from typing import Callable, Optional
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
@dataclass
class Query:
    """Bundle of tensors describing a batch of queries.

    Only ``position`` is required; every other field defaults to ``None``.
    Both methods build and return a new ``Query`` rather than assigning to
    the fields of the instance they are called on.
    """

    # Both of these are of shape [batch_size x ... x 3]
    position: torch.Tensor
    direction: Optional[torch.Tensor] = None

    # NOTE(review): presumably lower/upper bounds paired with `direction`;
    # their shape and meaning are not visible here -- confirm with callers.
    t_min: Optional[torch.Tensor] = None
    t_max: Optional[torch.Tensor] = None

    def copy(self) -> "Query":
        """Return a new ``Query`` holding the same field values.

        The copy is shallow: the tensor objects themselves are shared with
        the original, not cloned.
        """
        return Query(
            position=self.position,
            direction=self.direction,
            t_min=self.t_min,
            t_max=self.t_max,
        )

    def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query":
        """Return a new ``Query`` with ``f`` applied to each tensor field.

        Args:
            f: Function called once per field that is set. It is always
                called on ``position``; fields that are ``None`` are skipped
                and remain ``None`` in the result.

        Returns:
            A new ``Query`` built from the values returned by ``f``.
        """

        def apply_if_set(value: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # Optional fields pass through untouched when they are unset.
            if value is None:
                return None
            return f(value)

        # Keyword arguments are evaluated in order, so `f` sees the fields as
        # position, direction, t_min, t_max.
        return Query(
            position=f(self.position),
            direction=apply_if_set(self.direction),
            t_min=apply_if_set(self.t_min),
            t_max=apply_if_set(self.t_max),
        )
|