a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

75 lines
2.4 KiB

2 years ago
import torch
import torch.nn as nn
from .util import timestep_embedding
class PooledMLP(nn.Module):
    """Pointwise network conditioned on a timestep, built from pooled residual blocks.

    The input is projected to ``hidden_size`` channels with a kernel-size-1
    convolution, a timestep embedding is added to every position, the result
    is passed through ``resblocks`` ``ResBlock`` modules, and a final
    kernel-size-1 convolution maps to ``output_channels``.

    The output convolution starts with all-zero weight and bias, so a freshly
    constructed model returns zeros.

    Args:
        device: Device on which all parameters are created.
        input_channels: Channel count of the input tensor.
        output_channels: Channel count of the output tensor.
        hidden_size: Width of the hidden representation.
        resblocks: Number of residual blocks in the stack.
        pool_op: Pooling operation handed to each ``ResBlock``.
    """

    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        # Kernel size 1 means every position is projected independently.
        self.input_embed = nn.Conv1d(
            input_channels, hidden_size, kernel_size=1, device=device
        )
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)
        self.sequence = nn.Sequential(
            *(ResBlock(hidden_size, pool_op, device=device) for _ in range(resblocks))
        )
        self.out = nn.Conv1d(
            hidden_size, output_channels, kernel_size=1, device=device
        )
        # Zero the output projection so the untrained model emits zeros.
        with torch.no_grad():
            for param in (self.out.bias, self.out.weight):
                param.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` conditioned on timesteps ``t``.

        Args:
            x: Input tensor; assumed to be (batch, input_channels, length)
                as required by ``nn.Conv1d`` -- TODO confirm against callers.
            t: Timesteps passed to ``timestep_embedding``; presumably one
                value per batch element -- verify against that helper.

        Returns:
            The output of the final convolution, with ``output_channels``
            channels.
        """
        point_features = self.input_embed(x)
        embed_dim = point_features.shape[1]
        time_features = self.time_embed(timestep_embedding(t, embed_dim))
        # Trailing axis lets the time features broadcast over the last dimension.
        hidden = point_features + time_features[..., None]
        hidden = self.sequence(hidden)
        return self.out(hidden)
class ResBlock(nn.Module):
    """Residual block whose update is gated by a pooled summary of the input.

    For an input ``x`` of shape (N, C, T):

    * ``body`` (SiLU -> LayerNorm -> Linear, twice) is applied to the C
      channels of every one of the N * T positions independently.
    * ``x`` is pooled over its last dimension with ``pool_op`` and passed
      through ``gate`` (Linear -> Tanh), giving one gate vector per batch
      element.
    * The result is ``x + body(x) * gate``, with the gate broadcast over the
      last dimension.

    Args:
        hidden_size: Channel count C expected in the input.
        pool_op: Pooling operation, either ``"mean"`` or ``"max"``.
        device: Device on which all parameters are created.

    Raises:
        ValueError: If ``pool_op`` is not ``"mean"`` or ``"max"``.
    """

    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
        super().__init__()
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would defer the failure to the first forward pass.
        if pool_op not in ("mean", "max"):
            raise ValueError(f"unknown pool op: {pool_op}")
        self.pool_op = pool_op
        self.body = nn.Sequential(
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
        )
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated residual update to ``x`` of shape (N, C, T)."""
        N, C, T = x.shape
        # Move channels last and flatten to (N * T, C) so the per-position
        # layers see one row per position, then restore (N, C, T).
        flat = x.permute(0, 2, 1).reshape(N * T, C)
        out = self.body(flat).reshape([N, T, C]).permute(0, 2, 1)
        # Summary over the last dimension: (N, C).
        pooled = pool(self.pool_op, x)
        gate = self.gate(pooled)
        # Trailing axis broadcasts the (N, C) gate across all T positions.
        return x + out * gate[..., None]
def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    """Reduce the last dimension of ``x`` with the named pooling operation.

    Args:
        op_name: ``"max"`` or ``"mean"``.
        x: Tensor to reduce.

    Returns:
        A tensor with the last dimension of ``x`` removed.

    Raises:
        ValueError: If ``op_name`` is not a supported operation.
    """
    if op_name == "max":
        # With `dim` given, torch.max returns (values, indices); keep values.
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        # torch.mean returns a single tensor, so it must not be unpacked.
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled