from typing import Any, Dict

import torch
import torch.nn as nn


class SplitVectorDiffusion(nn.Module):
    """Adapts a channels-first diffusion backbone to flat latent vectors.

    Inputs of shape [N, n_ctx * d_latent] are reshaped to [N, d_latent, n_ctx]
    for the wrapped model; its doubled-channel output is split into noise
    (eps) and variance predictions and flattened back into a single vector.
    """

    def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
        super().__init__()
        self.device = device
        self.n_ctx = n_ctx
        self.d_latent = d_latent
        self.wrapped = wrapped

        # Pass through the wrapped model's kwargs cache, if it defines one.
        if hasattr(self.wrapped, "cached_model_kwargs"):
            self.cached_model_kwargs = self.wrapped.cached_model_kwargs

    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):
        # [N, n_ctx * d_latent] -> [N, d_latent, n_ctx] (channels-first).
        h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)
        pre_channels = h.shape[1]
        h = self.wrapped(h, t, **kwargs)
        assert (
            h.shape[1] == pre_channels * 2
        ), "expected twice as many outputs for variance prediction"
        # First half of the channels is the predicted noise, second half the variance.
        eps, var = torch.chunk(h, 2, dim=1)
        # Flatten each half back to [N, n_ctx * d_latent] and concatenate.
        return torch.cat(
            [
                eps.permute(0, 2, 1).flatten(1),
                var.permute(0, 2, 1).flatten(1),
            ],
            dim=1,
        )
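

# --- Usage sketch (not part of the original file) ---
# A minimal example of how this wrapper might be exercised. `ToyChannelModel`
# is a hypothetical stand-in invented for illustration: the actual wrapped
# model is expected to be a diffusion backbone that consumes
# [batch, channels, n_ctx] tensors and returns twice as many channels
# (eps and variance), which is all this toy model mimics.
if __name__ == "__main__":

    class ToyChannelModel(nn.Module):
        def __init__(self, channels: int):
            super().__init__()
            # Doubles the channel count, mimicking eps + variance heads.
            self.proj = nn.Conv1d(channels, channels * 2, kernel_size=1)

        def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):
            return self.proj(x)

    n_ctx, d_latent, batch = 4, 8, 2
    model = SplitVectorDiffusion(
        device=torch.device("cpu"),
        wrapped=ToyChannelModel(d_latent),
        n_ctx=n_ctx,
        d_latent=d_latent,
    )
    x = torch.randn(batch, n_ctx * d_latent)
    t = torch.zeros(batch, dtype=torch.long)
    out = model(x, t)
    # Output packs flattened eps, then flattened variance, along dim=1.
    assert out.shape == (batch, 2 * n_ctx * d_latent)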