import torch
import torch.nn as nn

from .util import timestep_embedding


class PooledMLP(nn.Module):
    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)

        blocks = []
        for _ in range(resblocks):
            blocks.append(ResBlock(hidden_size, pool_op, device=device))
        self.sequence = nn.Sequential(*blocks)

        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
        # Zero-initialize the output head so the model predicts zeros at the start
        # of training, a common stabilization trick for diffusion models.
        with torch.no_grad():
            self.out.bias.zero_()
            self.out.weight.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: [N, input_channels, T] per-point features; t: [N] diffusion timesteps.
        in_embed = self.input_embed(x)
        t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))
        # Broadcast the per-sample time embedding across the point dimension.
        h = in_embed + t_embed[..., None]
        h = self.sequence(h)
        h = self.out(h)
        return h


class ResBlock(nn.Module):
    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
        super().__init__()
        assert pool_op in ["mean", "max"]
        self.pool_op = pool_op
        self.body = nn.Sequential(
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
        )
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        N, C, T = x.shape
        # Per-point MLP: fold the point axis into the batch so LayerNorm and
        # Linear act on the channel dimension, then restore the [N, C, T] layout.
        out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)
        # Pool over the point axis to get a global feature, and use it to gate
        # the residual update so every point sees global context.
        pooled = pool(self.pool_op, x)
        gate = self.gate(pooled)
        return x + out * gate[..., None]


def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    if op_name == "max":
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        # Unlike torch.max, torch.mean returns a plain tensor rather than a
        # (values, indices) tuple, so there is nothing to unpack.
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled
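

# --- Minimal smoke-test sketch (not part of the module's public API) ---
# Assumes `timestep_embedding` from .util follows the usual sinusoidal
# convention, timestep_embedding(timesteps, dim) -> [N, dim]. Run with
# `python -m <package>.<this_module>` so the relative import resolves.
if __name__ == "__main__":
    device = torch.device("cpu")
    model = PooledMLP(device)
    # A batch of 2 point clouds, each with 1024 points of 3 input channels.
    x = torch.randn(2, 3, 1024, device=device)
    t = torch.randint(0, 1000, (2,), device=device)
    out = model(x, t)
    # Channels map 3 -> 6 (the defaults) while the point count is preserved.
    assert out.shape == (2, 6, 1024)
    # The output head is zero-initialized, so the first forward pass is all zeros.
    assert torch.all(out == 0)
    print("ok:", tuple(out.shape))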