from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import torch

from shap_e.models.query import Query
from shap_e.models.renderer import append_tensor
from shap_e.util.collections import AttrDict


class Model(ABC):
    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Predict an attribute given position
        """

    def forward_batched(
        self,
        query: Query,
        query_batch_size: int = 4096,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
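        """
        Evaluate forward() on consecutive slices of the query containing at
        most query_batch_size positions (taken along dim 1) and concatenate
        the per-slice outputs along that same dimension. If options carries no
        cache, a temporary one is created for the duration of this call.
        """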
        if not query.position.numel():
            # Avoid torch.cat() of zero tensors.
            return self(query, params=params, options=options)

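        # Make sure every slice evaluation shares one cache, and remember
        # whether we created it here so it can be removed before returning.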
        if options is None:
            # Guard for the default value: the cache handling below expects
            # options to behave like an AttrDict.
            options = AttrDict()

        if options.cache is None:
            created_cache = True
            options.cache = AttrDict()
        else:
            created_cache = False

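        # Run the model on one chunk of positions at a time and collect the
        # outputs per key as lists of tensors.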
        results_list = AttrDict()
        for i in range(0, query.position.shape[1], query_batch_size):
            out = self(
                query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]),
                params=params,
                options=options,
            )
            results_list = results_list.combine(out, append_tensor)

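        # Drop the temporary cache so it does not leak back to the caller.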
        if created_cache:
            del options["cache"]

        return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1))
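

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the upstream module: a minimal concrete
# Model subclass plus a forward_batched() call. ConstantDensityModel and the
# "density" output key are made-up names, and the sketch assumes that Query
# can be built from a position tensor alone and that concrete models also
# derive from torch.nn.Module (so self(...) dispatches to forward()).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ConstantDensityModel(torch.nn.Module, Model):
        def forward(
            self,
            query: Query,
            params: Optional[Dict[str, torch.Tensor]] = None,
            options: Optional[Dict[str, Any]] = None,
        ) -> AttrDict[str, Any]:
            # Predict a single scalar attribute for every query position.
            batch_size, n_points, _ = query.position.shape
            return AttrDict(density=torch.zeros(batch_size, n_points, 1))

    model = ConstantDensityModel()
    query = Query(position=torch.randn(2, 10_000, 3))
    out = model.forward_batched(query, query_batch_size=4096, options=AttrDict())
    # The 10_000 positions are processed in chunks of 4096 and re-concatenated.
    assert out.density.shape == (2, 10_000, 1)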