You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							32 lines
						
					
					
						
							1.0 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							32 lines
						
					
					
						
							1.0 KiB
						
					
					
				| from typing import Any, Dict | |
| 
 | |
| import torch | |
| import torch.nn as nn | |
| 
 | |
| 
 | |
class SplitVectorDiffusion(nn.Module):
    """Adapter between flat per-sample vectors and a channels-first model.

    ``forward`` takes ``x`` of shape ``(batch, n_ctx * channels)``, reshapes it
    to ``(batch, channels, n_ctx)`` for the wrapped module, and flattens the
    two halves of the wrapped output back into one ``(batch, ...)`` tensor.

    ``device`` and ``d_latent`` are only stored on the instance here; nothing
    in this class reads them.
    """

    def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
        """Store the configuration and the wrapped module.

        If ``wrapped`` has a ``cached_model_kwargs`` attribute, it is
        re-exposed under the same name on this wrapper.
        """
        super().__init__()
        self.device = device
        self.n_ctx = n_ctx
        self.d_latent = d_latent
        self.wrapped = wrapped

        # Forward the optional attribute so this wrapper can stand in for the
        # wrapped module wherever that attribute is looked up.
        if hasattr(self.wrapped, "cached_model_kwargs"):
            self.cached_model_kwargs = self.wrapped.cached_model_kwargs

    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):
        """Run the wrapped module on ``x`` and return its output flattened.

        Args:
            x: tensor whose first dimension is the batch; the remaining
                elements are split into ``n_ctx`` rows per sample.
            t: passed through to the wrapped module unchanged.
            **kwargs: passed through to the wrapped module unchanged.

        Returns:
            The first and second halves (along dim 1) of the wrapped output,
            each moved back to token-major order, flattened per sample, and
            concatenated along dim 1.

        Raises:
            AssertionError: if the wrapped output does not have exactly twice
                as many channels (dim 1) as its input.
        """

        def _flatten_tokens(part: torch.Tensor) -> torch.Tensor:
            # (batch, channels, n_ctx) -> (batch, n_ctx * channels)
            return part.permute(0, 2, 1).flatten(1)

        batch_size = x.shape[0]
        # (batch, n_ctx, channels) -> (batch, channels, n_ctx)
        channels_first = x.reshape(batch_size, self.n_ctx, -1).transpose(1, 2)
        in_channels = channels_first.shape[1]

        out = self.wrapped(channels_first, t, **kwargs)
        assert (
            out.shape[1] == in_channels * 2
        ), "expected twice as many outputs for variance prediction"

        # First half is the eps prediction, second half the variance prediction.
        eps_part, var_part = torch.chunk(out, 2, dim=1)
        flattened = [_flatten_tokens(eps_part), _flatten_tokens(var_part)]
        return torch.cat(flattened, dim=1)
 | |
| 
 |