from functools import partial
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict

from .base import Model


class MLPModel(MetaModule, Model):
    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        if register_freqs:
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding
        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well, in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            input_widths[self.insert_direction_at] += d_posenc_dir

        linear_cls = lambda meta: (
            partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )
            if meta
            else nn.Linear
        )

        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

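    # Worked example of the meta-layer selection above: with n_hidden_layers=4
    # there are five linear layers in self.mlp. Passing meta_parameters=True,
    # meta_start=0, n_meta_layers=2 gives meta_stop=1, so self.mlp[0] and
    # self.mlp[1] become MetaLinear while the remaining three stay nn.Linear.
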
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the query, with query.position of shape [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        """

        # query.direction is typically None for SDF models and during training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: [batch_size x ... x 3]
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(list(subdict(params, f"mlp.{i}").values()))
            else:
                parameters.extend(layer.parameters())

        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )

        return h_final, h_directionless


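# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the shap_e API): a stripped-down version of
# the "concatenate the encoded view direction at layer k" pattern that
# MLPModel._run_mlp implements above, using plain nn.Linear layers and raw
# (unencoded) position/direction inputs so it can be read and run in
# isolation. The name TinyDirectionalMLP and all hyperparameters here are
# hypothetical.
# ---------------------------------------------------------------------------
class TinyDirectionalMLP(nn.Module):
    def __init__(
        self,
        d_pos: int = 3,
        d_dir: int = 3,
        d_hidden: int = 64,
        n_hidden_layers: int = 4,
        n_output: int = 1,
        insert_direction_at: int = 2,
    ):
        super().__init__()
        input_widths = [d_pos] + [d_hidden] * n_hidden_layers
        output_widths = [d_hidden] * n_hidden_layers + [n_output]
        # Widen the chosen layer's input to make room for the direction features.
        input_widths[insert_direction_at] += d_dir
        self.insert_direction_at = insert_direction_at
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]
        )

    def forward(self, position: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
        h = position
        for i, layer in enumerate(self.layers):
            if i == self.insert_direction_at:
                # Same trick as MLPModel: splice the direction in right before
                # the chosen layer, so earlier features stay view-independent
                # (the analogue of h_directionless above).
                h = torch.cat([h, direction], dim=-1)
            h = layer(h)
            if i < len(self.layers) - 1:
                h = torch.relu(h)
        return h

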
class MLPSDFModel(MLPModel):
    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        self.mlp[-1].bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        signed_distance = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=signed_distance)


class MLPTextureFieldModel(MLPModel):
    def __init__(
        self,
        n_channels: int = 3,
        **kwargs,
    ):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        channels = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=channels)
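

# Minimal usage sketch (illustrative, not from the original source). It
# assumes Query can be constructed directly from a position tensor and that
# the default paths through MetaModule.update and the options lookup tolerate
# params=None and an empty options dict.
if __name__ == "__main__":
    model = MLPSDFModel(device=torch.device("cpu"))
    positions = torch.randn(2, 1024, 3)  # [batch_size x n_points x 3]
    out = model(Query(position=positions))
    print(out.signed_distance.shape)  # expected: torch.Size([2, 1024, 1])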