Commit by bluestyle97, 9 months ago
46 changed files with 6747 additions and 0 deletions
@@ -0,0 +1,44 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
eggs/
.eggs/
.vscode/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

.DS_Store

tools/objaverse_rendering/blender-3.2.2-linux-x64/
tools/objaverse_rendering/output/
ckpts/
data/
lightning_logs/
logs/
.trash/
.env/
outputs/
figures*/

# Useless Files
*.sh
blender/
.restore/
Binary file not shown.
@@ -0,0 +1,22 @@
model_config:
  target: src.models.lrm_mesh.InstantMesh
  params:
    encoder_feat_dim: 768
    encoder_freeze: false
    encoder_model_name: facebook/dino-vitb16
    transformer_dim: 1024
    transformer_layers: 12
    transformer_heads: 16
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 40
    rendering_samples_per_ray: 96
    grid_res: 128
    grid_scale: 2.1


infer_config:
  unet_path: ckpts/diffusion_pytorch_model.bin
  model_path: ckpts/instant_mesh_base.ckpt
  texture_resolution: 1024
  render_resolution: 512
@@ -0,0 +1,22 @@
model_config:
  target: src.models.lrm_mesh.InstantMesh
  params:
    encoder_feat_dim: 768
    encoder_freeze: false
    encoder_model_name: facebook/dino-vitb16
    transformer_dim: 1024
    transformer_layers: 16
    transformer_heads: 16
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 80
    rendering_samples_per_ray: 128
    grid_res: 128
    grid_scale: 2.1


infer_config:
  unet_path: ckpts/diffusion_pytorch_model.bin
  model_path: ckpts/instant_mesh_large.ckpt
  texture_resolution: 1024
  render_resolution: 512
@@ -0,0 +1,21 @@
model_config:
  target: src.models.lrm.InstantNeRF
  params:
    encoder_feat_dim: 768
    encoder_freeze: false
    encoder_model_name: facebook/dino-vitb16
    transformer_dim: 1024
    transformer_layers: 12
    transformer_heads: 16
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 40
    rendering_samples_per_ray: 96


infer_config:
  unet_path: ckpts/diffusion_pytorch_model.bin
  model_path: ckpts/instant_nerf_base.ckpt
  mesh_threshold: 10.0
  mesh_resolution: 256
  render_resolution: 384
@@ -0,0 +1,21 @@
model_config:
  target: src.models.lrm.InstantNeRF
  params:
    encoder_feat_dim: 768
    encoder_freeze: false
    encoder_model_name: facebook/dino-vitb16
    transformer_dim: 1024
    transformer_layers: 16
    transformer_heads: 16
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 80
    rendering_samples_per_ray: 128


infer_config:
  unet_path: ckpts/diffusion_pytorch_model.bin
  model_path: ckpts/instant_nerf_large.ckpt
  mesh_threshold: 10.0
  mesh_resolution: 256
  render_resolution: 384
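
The four configs above are consumed through one loading path. A minimal sketch of how one of them becomes a model (this mirrors what run.py below does; the file name is an assumption, since this diff view drops file paths):

from omegaconf import OmegaConf
from src.utils.train_util import instantiate_from_config

config = OmegaConf.load('configs/instant-nerf-base.yaml')   # assumed config path
model = instantiate_from_config(config.model_config)        # builds the `target` class with `params`
# config.infer_config then supplies checkpoint paths and render settings at inference time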
@@ -0,0 +1,19 @@
pytorch-lightning==2.1.2
gradio
huggingface-hub
einops
omegaconf
torchmetrics
webdataset
accelerate
tensorboard
PyMCubes
trimesh
rembg
transformers==4.34.1
diffusers==0.20.2
bitsandbytes
imageio[ffmpeg]
xatlas
plyfile
git+https://github.com/NVlabs/nvdiffrast/
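
For reference, a typical environment setup from this file would be the command below (note that torch itself is not pinned here, and the nvdiffrast git install assumes a working CUDA build toolchain):

pip install -r requirements.txt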
@@ -0,0 +1,253 @@
import os
import argparse
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_obj_with_mtl
from src.utils.infer_util import remove_background, resize_foreground, save_video


def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False):
    """
    Get the rendering camera parameters.
    """
    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
    if is_flexicubes:
        cameras = torch.linalg.inv(c2ws)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    else:
        extrinsics = c2ws.flatten(-2)
        intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
    return cameras


def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False):
    """
    Render frames from triplanes.
    """
    frames = []
    for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
        if is_flexicubes:
            frame = model.forward_geometry(
                planes,
                render_cameras[:, i:i+chunk_size],
                render_size=render_size,
            )['img']
        else:
            frame = model.forward_synthesizer(
                planes,
                render_cameras[:, i:i+chunk_size],
                render_size=render_size,
            )['images_rgb']
        frames.append(frame)

    frames = torch.cat(frames, dim=1)[0]    # we assume batch size is always 1
    return frames


###############################################################################
# Arguments.
###############################################################################

parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('input_path', type=str, help='Path to input image or directory.')
parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising sampling steps.')
parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
args = parser.parse_args()
seed_everything(args.seed)

###############################################################################
# Stage 0: Configuration.
###############################################################################

config = OmegaConf.load(args.config)
config_name = os.path.basename(args.config).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config

IS_FLEXICUBES = config_name.startswith('instant-mesh')

device = torch.device('cuda')

# load diffusion model
print('Loading diffusion model ...')
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing='trailing'
)

# load custom white-background UNet
print('Loading custom white-background unet ...')
state_dict = torch.load(infer_config.unet_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)

pipeline = pipeline.to(device)

# load reconstruction model
print('Loading reconstruction model ...')
model = instantiate_from_config(model_config)
state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
if IS_FLEXICUBES:
    model.init_flexicubes_geometry(device, fovy=30.0)
model = model.eval()

# make output directories
image_path = os.path.join(args.output_path, config_name, 'images')
mesh_path = os.path.join(args.output_path, config_name, 'meshes')
video_path = os.path.join(args.output_path, config_name, 'videos')
os.makedirs(image_path, exist_ok=True)
os.makedirs(mesh_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)

# process input files
if os.path.isdir(args.input_path):
    input_files = [
        os.path.join(args.input_path, file)
        for file in os.listdir(args.input_path)
        if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
    ]
else:
    input_files = [args.input_path]
print(f'Total number of input images: {len(input_files)}')


###############################################################################
# Stage 1: Multiview generation.
###############################################################################

rembg_session = None if args.no_rembg else rembg.new_session()

outputs = []
for idx, image_file in enumerate(input_files):
    name = os.path.basename(image_file).split('.')[0]
    print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')

    # remove background optionally
    input_image = Image.open(image_file)
    if not args.no_rembg:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)

    # sampling
    output_image = pipeline(
        input_image,
        num_inference_steps=args.diffusion_steps,
    ).images[0]

    output_image.save(os.path.join(image_path, f'{name}.png'))
    print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")

    images = np.asarray(output_image, dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()    # (3, 960, 640)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)       # (6, 3, 320, 320)

    outputs.append({'name': name, 'images': images})

# delete pipeline to save memory
del pipeline

###############################################################################
# Stage 2: Reconstruction.
###############################################################################

input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*args.scale).to(device)
chunk_size = 20 if IS_FLEXICUBES else 1

for idx, sample in enumerate(outputs):
    name = sample['name']
    print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')

    images = sample['images'].unsqueeze(0).to(device)
    images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)

    if args.view == 4:
        indices = torch.tensor([0, 2, 4, 5]).long().to(device)
        images = images[:, indices]
        input_cameras = input_cameras[:, indices]

    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        # get mesh
        mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=args.export_texmap,
            **infer_config,
        )
        if args.export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        # get video
        if args.save_video:
            video_path_idx = os.path.join(video_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            render_cameras = get_render_cameras(
                batch_size=1,
                M=120,
                radius=args.distance,
                elevation=20.0,
                is_flexicubes=IS_FLEXICUBES,
            ).to(device)

            frames = render_frames(
                model,
                planes,
                render_cameras=render_cameras,
                render_size=render_size,
                chunk_size=chunk_size,
                is_flexicubes=IS_FLEXICUBES,
            )

            save_video(
                frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")
@@ -0,0 +1,310 @@
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat

from src.utils.train_util import instantiate_from_config


class MVRecon(pl.LightningModule):
    def __init__(
        self,
        lrm_generator_config,
        lrm_path=None,
        input_size=256,
        render_size=192,
    ):
        super(MVRecon, self).__init__()

        self.input_size = input_size
        self.render_size = render_size

        # init modules
        self.lrm_generator = instantiate_from_config(lrm_generator_config)
        if lrm_path is not None:
            lrm_ckpt = torch.load(lrm_path)
            self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)

        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

        self.validation_step_outputs = []

    def on_fit_start(self):
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)

    def prepare_batch_data(self, batch):
        lrm_generator_input = {}
        render_gt = {}   # for supervision

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras and render cameras
        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)
        target_c2ws = batch['target_c2ws'].flatten(-2)
        target_Ks = batch['target_Ks'].flatten(-2)
        render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
        render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
        render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)

        input_extrinsics = input_c2ws[:, :, :12]
        input_intrinsics = torch.stack([
            input_Ks[:, :, 0], input_Ks[:, :, 4],
            input_Ks[:, :, 2], input_Ks[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        # add noise to input cameras
        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02

        lrm_generator_input['cameras'] = cameras.to(self.device)
        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)

        # target images
        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)

        # random crop
        render_size = np.random.randint(self.render_size, 513)
        target_images = v2.functional.resize(
            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
        target_depths = v2.functional.resize(
            target_depths, render_size, interpolation=0, antialias=True)
        target_alphas = v2.functional.resize(
            target_alphas, render_size, interpolation=0, antialias=True)

        crop_params = v2.RandomCrop.get_params(
            target_images, output_size=(self.render_size, self.render_size))
        target_images = v2.functional.crop(target_images, *crop_params)
        target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
        target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]

        lrm_generator_input['render_size'] = render_size
        lrm_generator_input['crop_params'] = crop_params

        render_gt['target_images'] = target_images.to(self.device)
        render_gt['target_depths'] = target_depths.to(self.device)
        render_gt['target_alphas'] = target_alphas.to(self.device)

        return lrm_generator_input, render_gt

    def prepare_validation_batch_data(self, batch):
        lrm_generator_input = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)

        input_extrinsics = input_c2ws[:, :, :12]
        input_intrinsics = torch.stack([
            input_Ks[:, :, 0], input_Ks[:, :, 4],
            input_Ks[:, :, 2], input_Ks[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        lrm_generator_input['cameras'] = cameras.to(self.device)

        render_c2ws = batch['render_c2ws'].flatten(-2)
        render_Ks = batch['render_Ks'].flatten(-2)
        render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)

        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
        lrm_generator_input['render_size'] = 384
        lrm_generator_input['crop_params'] = None

        return lrm_generator_input

    def forward_lrm_generator(
        self,
        images,
        cameras,
        render_cameras,
        render_size=192,
        crop_params=None,
        chunk_size=1,
    ):
        planes = torch.utils.checkpoint.checkpoint(
            self.lrm_generator.forward_planes,
            images,
            cameras,
            use_reentrant=False,
        )
        frames = []
        for i in range(0, render_cameras.shape[1], chunk_size):
            frames.append(
                torch.utils.checkpoint.checkpoint(
                    self.lrm_generator.synthesizer,
                    planes,
                    cameras=render_cameras[:, i:i+chunk_size],
                    render_size=render_size,
                    crop_params=crop_params,
                    use_reentrant=False,
                )
            )
        frames = {
            k: torch.cat([r[k] for r in frames], dim=1)
            for k in frames[0].keys()
        }
        return frames

    def forward(self, lrm_generator_input):
        images = lrm_generator_input['images']
        cameras = lrm_generator_input['cameras']
        render_cameras = lrm_generator_input['render_cameras']
        render_size = lrm_generator_input['render_size']
        crop_params = lrm_generator_input['crop_params']

        out = self.forward_lrm_generator(
            images,
            cameras,
            render_cameras,
            render_size=render_size,
            crop_params=crop_params,
            chunk_size=1,
        )
        render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
        render_depths = out['images_depth']
        render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)

        out = {
            'render_images': render_images,
            'render_depths': render_depths,
            'render_alphas': render_alphas,
        }
        return out

    def training_step(self, batch, batch_idx):
        lrm_generator_input, render_gt = self.prepare_batch_data(batch)

        render_out = self.forward(lrm_generator_input)

        loss, loss_dict = self.compute_loss(render_out, render_gt)

        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.global_step % 1000 == 0 and self.global_rank == 0:
            B, N, C, H, W = render_gt['target_images'].shape
            N_in = lrm_generator_input['images'].shape[1]

            input_images = v2.functional.resize(
                lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
            input_images = torch.cat(
                [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)

            input_images = rearrange(
                input_images, 'b n c h w -> b c h (n w)')
            target_images = rearrange(
                render_gt['target_images'], 'b n c h w -> b c h (n w)')
            render_images = rearrange(
                render_out['render_images'], 'b n c h w -> b c h (n w)')
            target_alphas = rearrange(
                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_alphas = rearrange(
                repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_depths = rearrange(
                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_depths = rearrange(
                repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            MAX_DEPTH = torch.max(target_depths)
            target_depths = target_depths / MAX_DEPTH * target_alphas
            render_depths = render_depths / MAX_DEPTH

            grid = torch.cat([
                input_images,
                target_images, render_images,
                target_alphas, render_alphas,
                target_depths, render_depths,
            ], dim=-2)
            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))

            save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))

        return loss

    def compute_loss(self, render_out, render_gt):
        # NOTE: the rgb value range of OpenLRM is [0, 1]
        render_images = render_out['render_images']
        target_images = render_gt['target_images'].to(render_images)
        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0

        loss_mse = F.mse_loss(render_images, target_images)
        loss_lpips = 2.0 * self.lpips(render_images, target_images)

        render_alphas = render_out['render_alphas']
        target_alphas = render_gt['target_alphas']
        loss_mask = F.mse_loss(render_alphas, target_alphas)

        loss = loss_mse + loss_lpips + loss_mask

        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        lrm_generator_input = self.prepare_validation_batch_data(batch)

        render_out = self.forward(lrm_generator_input)
        render_images = render_out['render_images']
        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')

        self.validation_step_outputs.append(render_images)

    def on_validation_epoch_end(self):
        images = torch.cat(self.validation_step_outputs, dim=-1)

        all_images = self.all_gather(images)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        lr = self.learning_rate

        params = []

        lrm_params_fast, lrm_params_slow = [], []
        for n, p in self.lrm_generator.named_parameters():
            if 'adaLN_modulation' in n or 'camera_embedder' in n:
                lrm_params_fast.append(p)
            else:
                lrm_params_slow.append(p)
        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01})
        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01})

        optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
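
A note on the 16-dimensional camera conditioning assembled in prepare_batch_data above: the first 12 entries are the flattened 3x4 [R|t] block of the 4x4 c2w matrix (the constant last row is dropped), and the last 4 are (fx, fy, cx, cy) read from the flattened 3x3 intrinsics at indices 0, 4, 2, 5. A minimal sketch of the same packing, with dummy tensors standing in for real poses:

import torch

c2w = torch.eye(4).expand(1, 1, 4, 4)     # (B, N, 4, 4) dummy camera-to-world pose
K = torch.eye(3).expand(1, 1, 3, 3)       # (B, N, 3, 3) dummy intrinsics
extr = c2w.flatten(-2)[:, :, :12]         # first 12 of 16 values = the 3x4 [R|t] block
Kf = K.flatten(-2)
intr = torch.stack([Kf[:, :, 0], Kf[:, :, 4], Kf[:, :, 2], Kf[:, :, 5]], dim=-1)  # fx, fy, cx, cy
cameras = torch.cat([extr, intr], dim=-1) # (1, 1, 16) conditioning vector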
@@ -0,0 +1,325 @@
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import pytorch_lightning as pl
from einops import rearrange, repeat

from src.utils.train_util import instantiate_from_config


# Regularization loss for FlexiCubes
def sdf_reg_loss_batch(sdf, all_edges):
    sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
    mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
    sdf_f1x6x2 = sdf_f1x6x2[mask]
    sdf_diff = F.binary_cross_entropy_with_logits(
        sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
        F.binary_cross_entropy_with_logits(
            sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
    return sdf_diff


class MVRecon(pl.LightningModule):
    def __init__(
        self,
        lrm_generator_config,
        input_size=256,
        render_size=512,
        init_ckpt=None,
    ):
        super(MVRecon, self).__init__()

        self.input_size = input_size
        self.render_size = render_size

        # init modules
        self.lrm_generator = instantiate_from_config(lrm_generator_config)

        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

        # Load weights from pretrained MVRecon model, and use the mlp
        # weights to initialize the weights of sdf and rgb mlps.
        if init_ckpt is not None:
            sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
            sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
            sd_fc = {}
            for k, v in sd.items():
                if k.startswith('lrm_generator.synthesizer.decoder.net.'):
                    if k.startswith('lrm_generator.synthesizer.decoder.net.6.'):    # last layer
                        # Here we assume the density field's isosurface threshold is t,
                        # and we reverse the sign of the density field to initialize the SDF field:
                        # -(w*x + b - t) = (-w)*x + (t - b)
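                        # (worked check with assumed scalar values: for w = 2, b = 1, t = 3.0, the
                        # density 2x + 1 crosses the threshold at x = 1; the initialized SDF
                        # -2x + (3 - 1) is zero at the same x, negative where the density exceeds t
                        # and positive elsewhere, so the zero level set is preserved. The constant
                        # 3.0 below plays the role of t.)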
                        if 'weight' in k:
                            sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
                        else:
                            sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
                        sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
                    else:
                        sd_fc[k.replace('net.', 'net_sdf.')] = v
                        sd_fc[k.replace('net.', 'net_rgb.')] = v
                else:
                    sd_fc[k] = v
            sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
            # missing `net_deformation` and `net_weight` parameters
            self.lrm_generator.load_state_dict(sd_fc, strict=False)
            print(f'Loaded weights from {init_ckpt}')

        self.validation_step_outputs = []

    def on_fit_start(self):
        device = torch.device(f'cuda:{self.global_rank}')
        self.lrm_generator.init_flexicubes_geometry(device)
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)

    def prepare_batch_data(self, batch):
        lrm_generator_input = {}
        render_gt = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras and render cameras
        input_c2ws = batch['input_c2ws']
        input_Ks = batch['input_Ks']
        target_c2ws = batch['target_c2ws']

        render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
        render_w2cs = torch.linalg.inv(render_c2ws)

        input_extrinsics = input_c2ws.flatten(-2)
        input_extrinsics = input_extrinsics[:, :, :12]
        input_intrinsics = input_Ks.flatten(-2)
        input_intrinsics = torch.stack([
            input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
            input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        # add noise to input_cameras
        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02

        lrm_generator_input['cameras'] = cameras.to(self.device)
        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)

        # target images
        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
        target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)

        render_size = self.render_size
        target_images = v2.functional.resize(
            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
        target_depths = v2.functional.resize(
            target_depths, render_size, interpolation=0, antialias=True)
        target_alphas = v2.functional.resize(
            target_alphas, render_size, interpolation=0, antialias=True)
        target_normals = v2.functional.resize(
            target_normals, render_size, interpolation=3, antialias=True)

        lrm_generator_input['render_size'] = render_size

        render_gt['target_images'] = target_images.to(self.device)
        render_gt['target_depths'] = target_depths.to(self.device)
        render_gt['target_alphas'] = target_alphas.to(self.device)
        render_gt['target_normals'] = target_normals.to(self.device)

        return lrm_generator_input, render_gt

    def prepare_validation_batch_data(self, batch):
        lrm_generator_input = {}

        # input images
        images = batch['input_images']
        images = v2.functional.resize(
            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)

        lrm_generator_input['images'] = images.to(self.device)

        # input cameras
        input_c2ws = batch['input_c2ws'].flatten(-2)
        input_Ks = batch['input_Ks'].flatten(-2)

        input_extrinsics = input_c2ws[:, :, :12]
        input_intrinsics = torch.stack([
            input_Ks[:, :, 0], input_Ks[:, :, 4],
            input_Ks[:, :, 2], input_Ks[:, :, 5],
        ], dim=-1)
        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)

        lrm_generator_input['cameras'] = cameras.to(self.device)

        # render cameras
        render_c2ws = batch['render_c2ws']
        render_w2cs = torch.linalg.inv(render_c2ws)

        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
        lrm_generator_input['render_size'] = 384

        return lrm_generator_input

    def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
        planes = torch.utils.checkpoint.checkpoint(
            self.lrm_generator.forward_planes,
            images,
            cameras,
            use_reentrant=False,
        )
        out = self.lrm_generator.forward_geometry(
            planes,
            render_cameras,
            render_size,
        )
        return out

    def forward(self, lrm_generator_input):
        images = lrm_generator_input['images']
        cameras = lrm_generator_input['cameras']
        render_cameras = lrm_generator_input['render_cameras']
        render_size = lrm_generator_input['render_size']

        out = self.forward_lrm_generator(
            images, cameras, render_cameras, render_size=render_size)

        return out

    def training_step(self, batch, batch_idx):
        lrm_generator_input, render_gt = self.prepare_batch_data(batch)

        render_out = self.forward(lrm_generator_input)

        loss, loss_dict = self.compute_loss(render_out, render_gt)

        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.global_step % 1000 == 0 and self.global_rank == 0:
            B, N, C, H, W = render_gt['target_images'].shape
            N_in = lrm_generator_input['images'].shape[1]

            target_images = rearrange(
                render_gt['target_images'], 'b n c h w -> b c h (n w)')
            render_images = rearrange(
                render_out['img'], 'b n c h w -> b c h (n w)')
            target_alphas = rearrange(
                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_alphas = rearrange(
                repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_depths = rearrange(
                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            render_depths = rearrange(
                repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
            target_normals = rearrange(
                render_gt['target_normals'], 'b n c h w -> b c h (n w)')
            render_normals = rearrange(
                render_out['normal'], 'b n c h w -> b c h (n w)')
            MAX_DEPTH = torch.max(target_depths)
            target_depths = target_depths / MAX_DEPTH * target_alphas
            render_depths = render_depths / MAX_DEPTH

            grid = torch.cat([
                target_images, render_images,
                target_alphas, render_alphas,
                target_depths, render_depths,
                target_normals, render_normals,
            ], dim=-2)
            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))

            image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        return loss

    def compute_loss(self, render_out, render_gt):
        # NOTE: the rgb value range of OpenLRM is [0, 1]
        render_images = render_out['img']
        target_images = render_gt['target_images'].to(render_images)
        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
        loss_mse = F.mse_loss(render_images, target_images)
        loss_lpips = 2.0 * self.lpips(render_images, target_images)

        render_alphas = render_out['mask']
        target_alphas = render_gt['target_alphas']
        loss_mask = F.mse_loss(render_alphas, target_alphas)

        render_depths = render_out['depth']
        target_depths = render_gt['target_depths']
        loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])

        render_normals = render_out['normal'] * 2.0 - 1.0
        target_normals = render_gt['target_normals'] * 2.0 - 1.0
        similarity = (render_normals * target_normals).sum(dim=-3).abs()
        normal_mask = target_alphas.squeeze(-3)
        loss_normal = 1 - similarity[normal_mask>0].mean()
        loss_normal = 0.2 * loss_normal

        # flexicubes regularization loss
        sdf = render_out['sdf']
        sdf_reg_loss = render_out['sdf_reg_loss']
        sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
        _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
        flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
        flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1

        loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg

        loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg

        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
        loss_dict.update({f'{prefix}/loss_normal': loss_normal})
        loss_dict.update({f'{prefix}/loss_depth': loss_depth})
        loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
        loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
        loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        lrm_generator_input = self.prepare_validation_batch_data(batch)

        render_out = self.forward(lrm_generator_input)
        render_images = render_out['img']
        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')

        self.validation_step_outputs.append(render_images)

    def on_validation_epoch_end(self):
        images = torch.cat(self.validation_step_outputs, dim=-1)

        all_images = self.all_gather(images)
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            print(f"Saved image to {image_path}")

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        lr = self.learning_rate

        optimizer = torch.optim.AdamW(
            self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
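
For reference, the total objective assembled in compute_loss above works out to

    loss = L_mse + 2.0 * L_lpips + L_mask + 0.2 * L_normal + 0.01 * L_sdf_entropy + 0.5 * L_surface + 0.1 * L_weights

where the scale factors are the ones already folded into each term in the code. Note that loss_depth is computed and logged but not added to the total in this version of the code.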
@@ -0,0 +1,123 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class BasicTransformerBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    """
    # use attention from torch.nn.MultiHeadAttention
    # Block contains a cross-attention layer, a self-attention layer, and a MLP
    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.,
        attn_bias: bool = False,
        mlp_ratio: float = 4.,
        mlp_drop: float = 0.,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(inner_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
        x = x + self.mlp(self.norm3(x))
        return x


class TriplaneTransformer(nn.Module):
    """
    Transformer with condition that generates a triplane representation.

    Reference:
    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(
        self,
        inner_dim: int,
        image_feat_dim: int,
        triplane_low_res: int,
        triplane_high_res: int,
        triplane_dim: int,
        num_layers: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
        self.layers = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats):
        # image_feats: [N, L_cond, D_cond]

        N = image_feats.shape[0]
        H = W = self.triplane_low_res
        L = 3 * H * W

        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        for layer in self.layers:
            x = layer(x, image_feats)
        x = self.norm(x)

        # separate each plane and apply deconv
        x = x.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)     # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.deconv(x)                      # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])         # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)     # [N, 3, D', H', W']
        x = x.contiguous()

        return x
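
A quick shape sketch for the module above, using the "large" hyperparameters from the configs earlier in this diff (the pairing of config keys to constructor arguments is an assumption, though the names line up one to one: transformer_dim -> inner_dim, encoder_feat_dim -> image_feat_dim, and so on):

import torch

model = TriplaneTransformer(
    inner_dim=1024, image_feat_dim=768,
    triplane_low_res=32, triplane_high_res=64, triplane_dim=80,
    num_layers=16, num_heads=16,
)
feats = torch.randn(2, 197, 768)   # (N, L_cond, D_cond) image features, e.g. from a ViT encoder
planes = model(feats)              # (2, 3, 80, 64, 64): the stride-2 deconv upsamples 32 -> 64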
@@ -0,0 +1,550 @@
# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViT model."""


import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
)
from transformers import PreTrainedModel, ViTConfig
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer


class ViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class ViTSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class ViTOutput(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states
def modulate(x, shift, scale): |
|||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
|||
|
|||
|
|||
class ViTLayer(nn.Module): |
|||
"""This corresponds to the Block class in the timm implementation.""" |
|||
|
|||
def __init__(self, config: ViTConfig) -> None: |
|||
super().__init__() |
|||
self.chunk_size_feed_forward = config.chunk_size_feed_forward |
|||
self.seq_len_dim = 1 |
|||
self.attention = ViTAttention(config) |
|||
self.intermediate = ViTIntermediate(config) |
|||
self.output = ViTOutput(config) |
|||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|||
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|||
|
|||
self.adaLN_modulation = nn.Sequential( |
|||
nn.SiLU(), |
|||
nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) |
|||
) |
|||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0) |
|||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0) |
|||
|
|||
def forward( |
|||
self, |
|||
hidden_states: torch.Tensor, |
|||
adaln_input: Optional[torch.Tensor] = None,
|||
head_mask: Optional[torch.Tensor] = None, |
|||
output_attentions: bool = False, |
|||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: |
|||
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) |
|||
|
|||
self_attention_outputs = self.attention( |
|||
modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention |
|||
head_mask, |
|||
output_attentions=output_attentions, |
|||
) |
|||
attention_output = self_attention_outputs[0] |
|||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights |
|||
|
|||
# first residual connection |
|||
hidden_states = attention_output + hidden_states |
|||
|
|||
# in ViT, layernorm is also applied after self-attention |
|||
layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) |
|||
layer_output = self.intermediate(layer_output) |
|||
|
|||
# second residual connection is done here |
|||
layer_output = self.output(layer_output, hidden_states) |
|||
|
|||
outputs = (layer_output,) + outputs |
|||
|
|||
return outputs |
|||
|
|||
|
|||
class ViTEncoder(nn.Module): |
|||
def __init__(self, config: ViTConfig) -> None: |
|||
super().__init__() |
|||
self.config = config |
|||
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) |
|||
self.gradient_checkpointing = False |
|||
|
|||
def forward( |
|||
self, |
|||
hidden_states: torch.Tensor, |
|||
adaln_input: Optional[torch.Tensor] = None,
|||
head_mask: Optional[torch.Tensor] = None, |
|||
output_attentions: bool = False, |
|||
output_hidden_states: bool = False, |
|||
return_dict: bool = True, |
|||
) -> Union[tuple, BaseModelOutput]: |
|||
all_hidden_states = () if output_hidden_states else None |
|||
all_self_attentions = () if output_attentions else None |
|||
|
|||
for i, layer_module in enumerate(self.layer): |
|||
if output_hidden_states: |
|||
all_hidden_states = all_hidden_states + (hidden_states,) |
|||
|
|||
layer_head_mask = head_mask[i] if head_mask is not None else None |
|||
|
|||
if self.gradient_checkpointing and self.training: |
|||
layer_outputs = self._gradient_checkpointing_func( |
|||
layer_module.__call__, |
|||
hidden_states, |
|||
adaln_input, |
|||
layer_head_mask, |
|||
output_attentions, |
|||
) |
|||
else: |
|||
layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) |
|||
|
|||
hidden_states = layer_outputs[0] |
|||
|
|||
if output_attentions: |
|||
all_self_attentions = all_self_attentions + (layer_outputs[1],) |
|||
|
|||
if output_hidden_states: |
|||
all_hidden_states = all_hidden_states + (hidden_states,) |
|||
|
|||
if not return_dict: |
|||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) |
|||
return BaseModelOutput( |
|||
last_hidden_state=hidden_states, |
|||
hidden_states=all_hidden_states, |
|||
attentions=all_self_attentions, |
|||
) |
|||
|
|||
|
|||
class ViTPreTrainedModel(PreTrainedModel): |
|||
""" |
|||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|||
models. |
|||
""" |
|||
|
|||
config_class = ViTConfig |
|||
base_model_prefix = "vit" |
|||
main_input_name = "pixel_values" |
|||
supports_gradient_checkpointing = True |
|||
_no_split_modules = ["ViTEmbeddings", "ViTLayer"] |
|||
|
|||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: |
|||
"""Initialize the weights""" |
|||
if isinstance(module, (nn.Linear, nn.Conv2d)): |
|||
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid |
|||
# `trunc_normal_cpu` not implemented in `half` issues |
|||
module.weight.data = nn.init.trunc_normal_( |
|||
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range |
|||
).to(module.weight.dtype) |
|||
if module.bias is not None: |
|||
module.bias.data.zero_() |
|||
elif isinstance(module, nn.LayerNorm): |
|||
module.bias.data.zero_() |
|||
module.weight.data.fill_(1.0) |
|||
elif isinstance(module, ViTEmbeddings): |
|||
module.position_embeddings.data = nn.init.trunc_normal_( |
|||
module.position_embeddings.data.to(torch.float32), |
|||
mean=0.0, |
|||
std=self.config.initializer_range, |
|||
).to(module.position_embeddings.dtype) |
|||
|
|||
module.cls_token.data = nn.init.trunc_normal_( |
|||
module.cls_token.data.to(torch.float32), |
|||
mean=0.0, |
|||
std=self.config.initializer_range, |
|||
).to(module.cls_token.dtype) |
|||
|
|||
|
|||
class ViTModel(ViTPreTrainedModel): |
|||
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): |
|||
super().__init__(config) |
|||
self.config = config |
|||
|
|||
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) |
|||
self.encoder = ViTEncoder(config) |
|||
|
|||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|||
self.pooler = ViTPooler(config) if add_pooling_layer else None |
|||
|
|||
# Initialize weights and apply final processing |
|||
self.post_init() |
|||
|
|||
def get_input_embeddings(self) -> ViTPatchEmbeddings: |
|||
return self.embeddings.patch_embeddings |
|||
|
|||
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: |
|||
""" |
|||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base |
|||
class PreTrainedModel |
|||
""" |
|||
for layer, heads in heads_to_prune.items(): |
|||
self.encoder.layer[layer].attention.prune_heads(heads) |
|||
|
|||
def forward( |
|||
self, |
|||
pixel_values: Optional[torch.Tensor] = None, |
|||
adaln_input: Optional[torch.Tensor] = None, |
|||
bool_masked_pos: Optional[torch.BoolTensor] = None, |
|||
head_mask: Optional[torch.Tensor] = None, |
|||
output_attentions: Optional[bool] = None, |
|||
output_hidden_states: Optional[bool] = None, |
|||
interpolate_pos_encoding: Optional[bool] = None, |
|||
return_dict: Optional[bool] = None, |
|||
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|||
r""" |
|||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): |
|||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). |
|||
""" |
|||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|||
output_hidden_states = ( |
|||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|||
) |
|||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|||
|
|||
if pixel_values is None: |
|||
raise ValueError("You have to specify pixel_values") |
|||
|
|||
# Prepare head mask if needed |
|||
# 1.0 in head_mask indicate we keep the head |
|||
# attention_probs has shape bsz x n_heads x N x N |
|||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] |
|||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] |
|||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) |
|||
|
|||
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) |
|||
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype |
|||
if pixel_values.dtype != expected_dtype: |
|||
pixel_values = pixel_values.to(expected_dtype) |
|||
|
|||
embedding_output = self.embeddings( |
|||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding |
|||
) |
|||
|
|||
encoder_outputs = self.encoder( |
|||
embedding_output, |
|||
adaln_input=adaln_input, |
|||
head_mask=head_mask, |
|||
output_attentions=output_attentions, |
|||
output_hidden_states=output_hidden_states, |
|||
return_dict=return_dict, |
|||
) |
|||
sequence_output = encoder_outputs[0] |
|||
sequence_output = self.layernorm(sequence_output) |
|||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None |
|||
|
|||
if not return_dict: |
|||
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) |
|||
return head_outputs + encoder_outputs[1:] |
|||
|
|||
return BaseModelOutputWithPooling( |
|||
last_hidden_state=sequence_output, |
|||
pooler_output=pooled_output, |
|||
hidden_states=encoder_outputs.hidden_states, |
|||
attentions=encoder_outputs.attentions, |
|||
) |
|||
|
|||
|
|||
class ViTPooler(nn.Module): |
|||
def __init__(self, config: ViTConfig): |
|||
super().__init__() |
|||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|||
self.activation = nn.Tanh() |
|||
|
|||
def forward(self, hidden_states): |
|||
# We "pool" the model by simply taking the hidden state corresponding |
|||
# to the first token. |
|||
first_token_tensor = hidden_states[:, 0] |
|||
pooled_output = self.dense(first_token_tensor) |
|||
pooled_output = self.activation(pooled_output) |
|||
return pooled_output |
@ -0,0 +1,80 @@ |
|||
# Copyright (c) 2023, Zexin He |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# https://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
import torch.nn as nn |
|||
from transformers import ViTImageProcessor |
|||
from einops import rearrange, repeat |
|||
from .dino import ViTModel |
|||
|
|||
|
|||
class DinoWrapper(nn.Module): |
|||
""" |
|||
Dino v1 wrapper using huggingface transformer implementation. |
|||
""" |
|||
def __init__(self, model_name: str, freeze: bool = True): |
|||
super().__init__() |
|||
self.model, self.processor = self._build_dino(model_name) |
|||
self.camera_embedder = nn.Sequential( |
|||
nn.Linear(16, self.model.config.hidden_size, bias=True), |
|||
nn.SiLU(), |
|||
nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) |
|||
) |
|||
if freeze: |
|||
self._freeze() |
|||
|
|||
def forward(self, image, camera): |
|||
# image: [B, N, C, H, W] |
|||
# camera: [B, N, D] |
|||
# RGB image with [0,1] scale and properly sized |
|||
if image.ndim == 5: |
|||
image = rearrange(image, 'b n c h w -> (b n) c h w') |
|||
dtype = image.dtype |
|||
inputs = self.processor( |
|||
images=image.float(), |
|||
return_tensors="pt", |
|||
do_rescale=False, |
|||
do_resize=False, |
|||
).to(self.model.device).to(dtype) |
|||
# embed camera
|||
camera_embeddings = self.camera_embedder(camera) |
|||
camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') |
|||
embeddings = camera_embeddings |
|||
# This resampling of positional embedding uses bicubic interpolation |
|||
outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) |
|||
last_hidden_states = outputs.last_hidden_state |
|||
return last_hidden_states |
|||
|
|||
def _freeze(self): |
|||
print(f"======== Freezing DinoWrapper ========") |
|||
self.model.eval() |
|||
for param in self.model.parameters():
|||
param.requires_grad = False |
|||
|
|||
@staticmethod |
|||
def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): |
|||
import requests |
|||
try: |
|||
model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) |
|||
processor = ViTImageProcessor.from_pretrained(model_name) |
|||
return model, processor |
|||
except requests.exceptions.ProxyError as err: |
|||
if proxy_error_retries > 0: |
|||
print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") |
|||
import time |
|||
time.sleep(proxy_error_cooldown) |
|||
return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) |
|||
else: |
|||
raise err |
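# Editor's sketch (not part of the original code): illustrates the tensor shapes
# DinoWrapper.forward expects. The model name matches the configs shipped with this
# repo; the batch size, view count, and image resolution are illustrative assumptions.
def _example_dino_wrapper_usage():
    import torch
    wrapper = DinoWrapper(model_name='facebook/dino-vitb16', freeze=True)
    images = torch.rand(2, 6, 3, 320, 320)   # [B, N, C, H, W], RGB in [0, 1]
    cameras = torch.rand(2, 6, 16)            # [B, N, 16], flattened 4x4 camera matrices
    return wrapper(images, cameras)           # [(B*N), seq_len, hidden_size]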
@ -0,0 +1,7 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
@ -0,0 +1,16 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
from torch import nn |
|||
|
|||
|
|||
class Camera(nn.Module): |
|||
def __init__(self): |
|||
super(Camera, self).__init__() |
|||
pass |
@ -0,0 +1,35 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
from . import Camera |
|||
import numpy as np |
|||
|
|||
|
|||
def projection(x=0.1, n=1.0, f=50.0, near_plane=None): |
|||
if near_plane is None: |
|||
near_plane = n |
|||
return np.array( |
|||
[[n / x, 0, 0, 0], |
|||
[0, n / -x, 0, 0], |
|||
[0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], |
|||
[0, 0, -1, 0]]).astype(np.float32) |
|||
|
|||
|
|||
class PerspectiveCamera(Camera): |
|||
def __init__(self, fovy=49.0, device='cuda'): |
|||
super(PerspectiveCamera, self).__init__() |
|||
self.device = device |
|||
focal = np.tan(fovy / 180.0 * np.pi * 0.5)  # tan(fovy / 2); projection() divides n by this to obtain the focal length
|||
self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) |
|||
|
|||
def project(self, points_bxnx4): |
|||
out = torch.matmul( |
|||
points_bxnx4, |
|||
torch.transpose(self.proj_mtx, 1, 2)) |
|||
return out |
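# Editor's sketch (not part of the original code): projecting homogeneous
# camera-space points to clip space with PerspectiveCamera. The 'cpu' device is an
# assumption for illustration; the class defaults to 'cuda'.
def _example_perspective_projection():
    cam = PerspectiveCamera(fovy=49.0, device='cpu')
    points_bxnx4 = torch.tensor([[[0.0, 0.0, -2.0, 1.0]]])  # one point, 2 units in front of the camera
    return cam.project(points_bxnx4)  # [1, 1, 4] clip-space coordinates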
@ -0,0 +1,8 @@ |
|||
import torch |
|||
|
|||
class Renderer(): |
|||
def __init__(self): |
|||
pass |
|||
|
|||
def forward(self): |
|||
pass |
@ -0,0 +1,121 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import torch.nn.functional as F |
|||
import nvdiffrast.torch as dr |
|||
from . import Renderer |
|||
|
|||
_FG_LUT = None |
|||
|
|||
|
|||
def interpolate(attr, rast, attr_idx, rast_db=None): |
|||
return dr.interpolate( |
|||
attr.contiguous(), rast, attr_idx, rast_db=rast_db, |
|||
diff_attrs=None if rast_db is None else 'all') |
|||
|
|||
|
|||
def xfm_points(points, matrix, use_python=True): |
|||
'''Transform points. |
|||
Args: |
|||
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] |
|||
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] |
|||
use_python: Use PyTorch's torch.matmul (for validation) |
|||
Returns: |
|||
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. |
|||
''' |
|||
out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) |
|||
if torch.is_anomaly_enabled(): |
|||
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" |
|||
return out |
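# Editor's sketch (not part of the original code): with an identity transform,
# xfm_points simply pads the input points to homogeneous coordinates.
def _example_xfm_points():
    pts = torch.zeros(1, 5, 3)        # [minibatch, num_vertices, 3]
    eye = torch.eye(4).unsqueeze(0)   # [minibatch, 4, 4]
    return xfm_points(pts, eye)       # [1, 5, 4]; the last column is all ones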
|||
|
|||
|
|||
def dot(x, y): |
|||
return torch.sum(x * y, -1, keepdim=True) |
|||
|
|||
|
|||
def compute_vertex_normal(v_pos, t_pos_idx): |
|||
i0 = t_pos_idx[:, 0] |
|||
i1 = t_pos_idx[:, 1] |
|||
i2 = t_pos_idx[:, 2] |
|||
|
|||
v0 = v_pos[i0, :] |
|||
v1 = v_pos[i1, :] |
|||
v2 = v_pos[i2, :] |
|||
|
|||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)  # pass dim explicitly; the dimensionless form is deprecated
|||
|
|||
# Splat face normals to vertices |
|||
v_nrm = torch.zeros_like(v_pos) |
|||
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) |
|||
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) |
|||
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) |
|||
|
|||
# Normalize, replace zero (degenerated) normals with some default value |
|||
v_nrm = torch.where( |
|||
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) |
|||
) |
|||
v_nrm = F.normalize(v_nrm, dim=1) |
|||
assert torch.all(torch.isfinite(v_nrm)) |
|||
|
|||
return v_nrm |
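# Editor's sketch (not part of the original code): vertex normals of a single
# triangle in the xy-plane; all three vertices should receive the same +z normal.
def _example_vertex_normal():
    v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
    f = torch.tensor([[0, 1, 2]], dtype=torch.long)
    return compute_vertex_normal(v, f)  # approximately [[0, 0, 1]] for each vertex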
|||
|
|||
|
|||
class NeuralRender(Renderer): |
|||
def __init__(self, device='cuda', camera_model=None): |
|||
super(NeuralRender, self).__init__() |
|||
self.device = device |
|||
self.ctx = dr.RasterizeCudaContext(device=device) |
|||
self.projection_mtx = None |
|||
self.camera = camera_model |
|||
|
|||
def render_mesh( |
|||
self, |
|||
mesh_v_pos_bxnx3, |
|||
mesh_t_pos_idx_fx3, |
|||
camera_mv_bx4x4, |
|||
mesh_v_feat_bxnxd, |
|||
resolution=256, |
|||
spp=1, |
|||
device='cuda', |
|||
hierarchical_mask=False |
|||
): |
|||
assert not hierarchical_mask |
|||
|
|||
mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 |
|||
v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates |
|||
v_pos_clip = self.camera.project(v_pos) # Projection in the camera |
|||
|
|||
v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates |
|||
|
|||
# Render the image.
# We only return the per-pixel features (3D locations) here; they are used as the input to the neural renderer.
|||
num_layers = 1 |
|||
mask_pyramid = None |
|||
assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure the mesh has at least one face
|||
mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos |
|||
|
|||
with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: |
|||
for _ in range(num_layers): |
|||
rast, db = peeler.rasterize_next_layer() |
|||
gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) |
|||
|
|||
hard_mask = torch.clamp(rast[..., -1:], 0, 1) |
|||
antialias_mask = dr.antialias( |
|||
hard_mask.clone().contiguous(), rast, v_pos_clip, |
|||
mesh_t_pos_idx_fx3) |
|||
|
|||
depth = gb_feat[..., -2:-1] |
|||
ori_mesh_feature = gb_feat[..., :-4] |
|||
|
|||
normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) |
|||
normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) |
|||
normal = F.normalize(normal, dim=-1) |
|||
normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background |
|||
|
|||
return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal |
@ -0,0 +1,18 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import numpy as np |
|||
|
|||
|
|||
class Geometry(): |
|||
def __init__(self): |
|||
pass |
|||
|
|||
def forward(self): |
|||
pass |
@ -0,0 +1,504 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import numpy as np |
|||
import os |
|||
from . import Geometry |
|||
from .dmtet_utils import get_center_boundary_index |
|||
import torch.nn.functional as F |
|||
|
|||
|
|||
############################################################################### |
|||
# DMTet utility functions |
|||
############################################################################### |
|||
def create_mt_variable(device): |
|||
triangle_table = torch.tensor( |
|||
[ |
|||
[-1, -1, -1, -1, -1, -1], |
|||
[1, 0, 2, -1, -1, -1], |
|||
[4, 0, 3, -1, -1, -1], |
|||
[1, 4, 2, 1, 3, 4], |
|||
[3, 1, 5, -1, -1, -1], |
|||
[2, 3, 0, 2, 5, 3], |
|||
[1, 4, 0, 1, 5, 4], |
|||
[4, 2, 5, -1, -1, -1], |
|||
[4, 5, 2, -1, -1, -1], |
|||
[4, 1, 0, 4, 5, 1], |
|||
[3, 2, 0, 3, 5, 2], |
|||
[1, 3, 5, -1, -1, -1], |
|||
[4, 1, 2, 4, 3, 1], |
|||
[3, 0, 4, -1, -1, -1], |
|||
[2, 0, 1, -1, -1, -1], |
|||
[-1, -1, -1, -1, -1, -1] |
|||
], dtype=torch.long, device=device) |
|||
|
|||
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) |
|||
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) |
|||
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) |
|||
return triangle_table, num_triangles_table, base_tet_edges, v_id |
|||
|
|||
|
|||
def sort_edges(edges_ex2): |
|||
with torch.no_grad(): |
|||
order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() |
|||
order = order.unsqueeze(dim=1) |
|||
a = torch.gather(input=edges_ex2, index=order, dim=1) |
|||
b = torch.gather(input=edges_ex2, index=1 - order, dim=1) |
|||
return torch.stack([a, b], -1) |
|||
|
|||
|
|||
############################################################################### |
|||
# marching tetrahedrons (differentiable) |
|||
############################################################################### |
|||
|
|||
def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): |
|||
with torch.no_grad(): |
|||
occ_n = sdf_n > 0 |
|||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) |
|||
occ_sum = torch.sum(occ_fx4, -1) |
|||
valid_tets = (occ_sum > 0) & (occ_sum < 4) |
|||
occ_sum = occ_sum[valid_tets] |
|||
|
|||
# find all vertices |
|||
all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) |
|||
all_edges = sort_edges(all_edges) |
|||
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) |
|||
|
|||
unique_edges = unique_edges.long() |
|||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 |
|||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 |
|||
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) |
|||
idx_map = mapping[idx_map] # map edges to verts |
|||
|
|||
interp_v = unique_edges[mask_edges] # .long() |
|||
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) |
|||
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) |
|||
edges_to_interp_sdf[:, -1] *= -1 |
|||
|
|||
denominator = edges_to_interp_sdf.sum(1, keepdim=True) |
|||
|
|||
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator |
|||
verts = (edges_to_interp * edges_to_interp_sdf).sum(1) |
|||
|
|||
idx_map = idx_map.reshape(-1, 6) |
|||
|
|||
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) |
|||
num_triangles = num_triangles_table[tetindex] |
|||
|
|||
# Generate triangle indices |
|||
faces = torch.cat( |
|||
( |
|||
torch.gather( |
|||
input=idx_map[num_triangles == 1], dim=1, |
|||
index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), |
|||
torch.gather( |
|||
input=idx_map[num_triangles == 2], dim=1, |
|||
index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), |
|||
), dim=0) |
|||
return verts, faces |
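# Editor's sketch (not part of the original code): marching_tets on a single
# tetrahedron whose SDF changes sign at one vertex; this yields three interpolated
# vertices and one triangle. The 'cpu' device is an illustrative assumption.
def _example_marching_tets(device='cpu'):
    tri_table, num_tri_table, base_edges, v_id = create_mt_variable(device)
    pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], device=device)
    sdf = torch.tensor([-0.5, 0.5, 0.5, 0.5], device=device)  # sdf > 0 counts as occupied here
    tets = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)
    return marching_tets(pos, sdf, tets, tri_table, num_tri_table, base_edges, v_id)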
|||
|
|||
|
|||
def create_tetmesh_variables(device='cuda'): |
|||
tet_table = torch.tensor( |
|||
[[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], |
|||
[0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], |
|||
[1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], |
|||
[1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], |
|||
[2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], |
|||
[2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], |
|||
[2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], |
|||
[6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], |
|||
[3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], |
|||
[3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], |
|||
[3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], |
|||
[5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], |
|||
[3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], |
|||
[4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], |
|||
[4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], |
|||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) |
|||
num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) |
|||
return tet_table, num_tets_table |
|||
|
|||
|
|||
def marching_tets_tetmesh( |
|||
pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, |
|||
return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): |
|||
with torch.no_grad(): |
|||
occ_n = sdf_n > 0 |
|||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) |
|||
occ_sum = torch.sum(occ_fx4, -1) |
|||
valid_tets = (occ_sum > 0) & (occ_sum < 4) |
|||
occ_sum = occ_sum[valid_tets] |
|||
|
|||
# find all vertices |
|||
all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) |
|||
all_edges = sort_edges(all_edges) |
|||
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) |
|||
|
|||
unique_edges = unique_edges.long() |
|||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 |
|||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 |
|||
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) |
|||
idx_map = mapping[idx_map] # map edges to verts |
|||
|
|||
interp_v = unique_edges[mask_edges] # .long() |
|||
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) |
|||
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) |
|||
edges_to_interp_sdf[:, -1] *= -1 |
|||
|
|||
denominator = edges_to_interp_sdf.sum(1, keepdim=True) |
|||
|
|||
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator |
|||
verts = (edges_to_interp * edges_to_interp_sdf).sum(1) |
|||
|
|||
idx_map = idx_map.reshape(-1, 6) |
|||
|
|||
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) |
|||
num_triangles = num_triangles_table[tetindex] |
|||
|
|||
# Generate triangle indices |
|||
faces = torch.cat( |
|||
( |
|||
torch.gather( |
|||
input=idx_map[num_triangles == 1], dim=1, |
|||
index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), |
|||
torch.gather( |
|||
input=idx_map[num_triangles == 2], dim=1, |
|||
index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), |
|||
), dim=0) |
|||
if not return_tet_mesh: |
|||
return verts, faces |
|||
occupied_verts = ori_v[occ_n] |
|||
# Use the input tensor's device instead of a hard-coded "cuda" string
mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
mapping[occ_n] = torch.arange(occupied_verts.shape[0], device=sdf_n.device)
|||
tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) |
|||
|
|||
idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 |
|||
tet_verts = torch.cat([verts, occupied_verts], 0) |
|||
num_tets = num_tets_table[tetindex] |
|||
|
|||
tets = torch.cat( |
|||
( |
|||
torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( |
|||
-1, |
|||
4), |
|||
torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( |
|||
-1, |
|||
4), |
|||
), dim=0) |
|||
# add fully occupied tets |
|||
fully_occupied = occ_fx4.sum(-1) == 4 |
|||
tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] |
|||
tets = torch.cat([tets, tet_fully_occupied]) |
|||
|
|||
return verts, faces, tet_verts, tets |
|||
|
|||
|
|||
############################################################################### |
|||
# Compact tet grid |
|||
############################################################################### |
|||
|
|||
def compact_tets(pos_nx3, sdf_n, tet_fx4): |
|||
with torch.no_grad(): |
|||
# Find surface tets |
|||
occ_n = sdf_n > 0 |
|||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) |
|||
occ_sum = torch.sum(occ_fx4, -1) |
|||
valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets |
|||
|
|||
valid_vtx = tet_fx4[valid_tets].reshape(-1) |
|||
unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) |
|||
new_pos = pos_nx3[unique_vtx] |
|||
new_sdf = sdf_n[unique_vtx] |
|||
new_tets = idx_map.reshape(-1, 4) |
|||
return new_pos, new_sdf, new_tets |
|||
|
|||
|
|||
############################################################################### |
|||
# Subdivide volume |
|||
############################################################################### |
|||
|
|||
def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): |
|||
device = tet_pos_bxnx3.device |
|||
# get new verts |
|||
tet_fx4 = tet_bxfx4[0] |
|||
edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] |
|||
all_edges = tet_fx4[:, edges].reshape(-1, 2) |
|||
all_edges = sort_edges(all_edges) |
|||
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) |
|||
idx_map = idx_map + tet_pos_bxnx3.shape[1] |
|||
all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) |
|||
mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( |
|||
all_values.shape[0], -1, 2, |
|||
all_values.shape[-1]).mean(2) |
|||
new_v = torch.cat([all_values, mid_points_pos], 1) |
|||
new_v, new_sdf = new_v[..., :3], new_v[..., 3] |
|||
|
|||
# get new tets |
|||
|
|||
idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] |
|||
idx_ab = idx_map[0::6] |
|||
idx_ac = idx_map[1::6] |
|||
idx_ad = idx_map[2::6] |
|||
idx_bc = idx_map[3::6] |
|||
idx_bd = idx_map[4::6] |
|||
idx_cd = idx_map[5::6] |
|||
|
|||
tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) |
|||
tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) |
|||
tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) |
|||
tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) |
|||
tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) |
|||
tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) |
|||
tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) |
|||
tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) |
|||
|
|||
tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) |
|||
tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) |
|||
tet = tet_np.long().to(device) |
|||
|
|||
return new_v, tet, new_sdf |
|||
|
|||
|
|||
############################################################################### |
|||
# Adjacency |
|||
############################################################################### |
|||
def tet_to_tet_adj_sparse(tet_tx4): |
|||
# NOTE: the adjacency includes self-connections
|||
with torch.no_grad(): |
|||
t = tet_tx4.shape[0] |
|||
device = tet_tx4.device |
|||
idx_array = torch.LongTensor( |
|||
[0, 1, 2, |
|||
1, 0, 3, |
|||
2, 3, 0, |
|||
3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) |
|||
|
|||
# get all faces |
|||
all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( |
|||
-1, |
|||
3) # (tx4, 3) |
|||
all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) |
|||
# sort and group |
|||
all_faces_sorted, _ = torch.sort(all_faces, dim=1) |
|||
|
|||
all_faces_unique, inverse_indices, counts = torch.unique( |
|||
all_faces_sorted, dim=0, return_counts=True, |
|||
return_inverse=True) |
|||
tet_face_fx3 = all_faces_unique[counts == 2] |
|||
counts = counts[inverse_indices] # tx4 |
|||
valid = (counts == 2) |
|||
|
|||
group = inverse_indices[valid] |
|||
# print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) |
|||
_, indices = torch.sort(group) |
|||
all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] |
|||
tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) |
|||
|
|||
tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) |
|||
adj_self = torch.arange(t, device=tet_tx4.device) |
|||
adj_self = torch.stack([adj_self, adj_self], -1) |
|||
tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) |
|||
|
|||
tet_adj_idx = torch.unique(tet_adj_idx, dim=0) |
|||
values = torch.ones( |
|||
tet_adj_idx.shape[0], device=tet_tx4.device).float() |
|||
adj_sparse = torch.sparse_coo_tensor(
    tet_adj_idx.t(), values, torch.Size([t, t]))
|||
|
|||
# normalization |
|||
neighbor_num = 1.0 / torch.sparse.sum( |
|||
adj_sparse, dim=1).to_dense() |
|||
values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) |
|||
adj_sparse = torch.sparse_coo_tensor(
    tet_adj_idx.t(), values, torch.Size([t, t]))
|||
return adj_sparse |
|||
|
|||
|
|||
############################################################################### |
|||
# Compact grid |
|||
############################################################################### |
|||
|
|||
def get_tet_bxfx4x3(bxnxz, bxfx4): |
|||
n_batch, z = bxnxz.shape[0], bxnxz.shape[2] |
|||
gather_input = bxnxz.unsqueeze(2).expand( |
|||
n_batch, bxnxz.shape[1], 4, z) |
|||
gather_index = bxfx4.unsqueeze(-1).expand( |
|||
n_batch, bxfx4.shape[1], 4, z).long() |
|||
tet_bxfx4xz = torch.gather( |
|||
input=gather_input, dim=1, index=gather_index) |
|||
|
|||
return tet_bxfx4xz |
|||
|
|||
|
|||
def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): |
|||
with torch.no_grad(): |
|||
assert tet_pos_bxnx3.shape[0] == 1 |
|||
|
|||
occ = grid_sdf[0] > 0 |
|||
occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) |
|||
mask = (occ_sum > 0) & (occ_sum < 4) |
|||
|
|||
# build connectivity graph |
|||
adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) |
|||
mask = mask.float().unsqueeze(-1) |
|||
|
|||
# Dilate the mask by one ring of neighbors
for _ in range(1):
|||
mask = torch.sparse.mm(adj_matrix, mask) |
|||
mask = mask.squeeze(-1) > 0 |
|||
|
|||
mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) |
|||
new_tet_bxfx4 = tet_bxfx4[:, mask].long() |
|||
selected_verts_idx = torch.unique(new_tet_bxfx4) |
|||
new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] |
|||
mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) |
|||
new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) |
|||
new_grid_sdf = grid_sdf[:, selected_verts_idx] |
|||
return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf |
|||
|
|||
|
|||
############################################################################### |
|||
# Regularizer |
|||
############################################################################### |
|||
|
|||
def sdf_reg_loss(sdf, all_edges): |
|||
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) |
|||
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) |
|||
sdf_f1x6x2 = sdf_f1x6x2[mask] |
|||
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( |
|||
sdf_f1x6x2[..., 0], |
|||
(sdf_f1x6x2[..., 1] > 0).float()) + \ |
|||
torch.nn.functional.binary_cross_entropy_with_logits( |
|||
sdf_f1x6x2[..., 1], |
|||
(sdf_f1x6x2[..., 0] > 0).float()) |
|||
return sdf_diff |
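# Editor's note: for every grid edge whose endpoint SDF values disagree in sign,
# the loss above applies a symmetric binary cross-entropy that pulls each endpoint
# toward the sign of the other, discouraging spurious sign flips in the SDF field.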
|||
|
|||
|
|||
def sdf_reg_loss_batch(sdf, all_edges): |
|||
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) |
|||
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) |
|||
sdf_f1x6x2 = sdf_f1x6x2[mask] |
|||
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ |
|||
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) |
|||
return sdf_diff |
|||
|
|||
|
|||
############################################################################### |
|||
# Geometry interface |
|||
############################################################################### |
|||
class DMTetGeometry(Geometry): |
|||
def __init__( |
|||
self, grid_res=64, scale=2.0, device='cuda', renderer=None, |
|||
render_type='neural_render', args=None): |
|||
super(DMTetGeometry, self).__init__() |
|||
self.grid_res = grid_res |
|||
self.device = device |
|||
self.args = args |
|||
tets = np.load('data/tets/%d_compress.npz' % (grid_res)) |
|||
self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) |
|||
# Make sure the tet grid is zero-centered and its longest side has unit length
|||
length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] |
|||
length = length.max() |
|||
mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 |
|||
self.verts = (self.verts - mid.unsqueeze(dim=0)) / length |
|||
if isinstance(scale, list): |
|||
self.verts[:, 0] = self.verts[:, 0] * scale[0] |
|||
self.verts[:, 1] = self.verts[:, 1] * scale[1] |
|||
self.verts[:, 2] = self.verts[:, 2] * scale[1] |
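# NOTE: z reuses the y scale factor, so the list form is effectively [scale_x, scale_yz];
# this appears to match the upstream GET3D code this file derives from.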
|||
else: |
|||
self.verts = self.verts * scale |
|||
self.indices = torch.from_numpy(tets['tets']).long().to(self.device) |
|||
self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) |
|||
self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) |
|||
# Parameters for regularization computation |
|||
edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) |
|||
all_edges = self.indices[:, edges].reshape(-1, 2) |
|||
all_edges_sorted = torch.sort(all_edges, dim=1)[0] |
|||
self.all_edges = torch.unique(all_edges_sorted, dim=0) |
|||
|
|||
# Parameters used to fix the boundary SDF
|||
self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) |
|||
self.renderer = renderer |
|||
self.render_type = render_type |
|||
|
|||
def getAABB(self): |
|||
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values |
|||
|
|||
def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): |
|||
if indices is None: |
|||
indices = self.indices |
|||
verts, faces = marching_tets( |
|||
v_deformed_nx3, sdf_n, indices, self.triangle_table, |
|||
self.num_triangles_table, self.base_tet_edges, self.v_id) |
|||
faces = torch.cat( |
|||
[faces[:, 0:1], |
|||
faces[:, 2:3], |
|||
faces[:, 1:2], ], dim=-1) |
|||
return verts, faces |
|||
|
|||
def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): |
|||
if indices is None: |
|||
indices = self.indices |
|||
verts, faces, tet_verts, tets = marching_tets_tetmesh( |
|||
v_deformed_nx3, sdf_n, indices, self.triangle_table, |
|||
self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, |
|||
num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) |
|||
faces = torch.cat( |
|||
[faces[:, 0:1], |
|||
faces[:, 2:3], |
|||
faces[:, 1:2], ], dim=-1) |
|||
return verts, faces, tet_verts, tets |
|||
|
|||
def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): |
|||
return_value = dict() |
|||
if self.render_type == 'neural_render': |
|||
# NOTE: NeuralRender.render_mesh returns eight values (including the normal map);
# unpack all of them so this call does not raise a ValueError.
tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
|||
mesh_v_nx3.unsqueeze(dim=0), |
|||
mesh_f_fx3.int(), |
|||
camera_mv_bx4x4, |
|||
mesh_v_nx3.unsqueeze(dim=0), |
|||
resolution=resolution, |
|||
device=self.device, |
|||
hierarchical_mask=hierarchical_mask |
|||
) |
|||
|
|||
return_value['tex_pos'] = tex_pos |
|||
return_value['mask'] = mask |
|||
return_value['hard_mask'] = hard_mask |
|||
return_value['rast'] = rast |
|||
return_value['v_pos_clip'] = v_pos_clip |
|||
return_value['mask_pyramid'] = mask_pyramid |
|||
return_value['depth'] = depth
return_value['normal'] = normal
|||
else: |
|||
raise NotImplementedError |
|||
|
|||
return return_value |
|||
|
|||
def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): |
|||
# We assume a batch of meshes (entries may have different mesh/geometry); for the other shapes, the batch size is 1
|||
v_list = [] |
|||
f_list = [] |
|||
n_batch = v_deformed_bxnx3.shape[0] |
|||
all_render_output = [] |
|||
for i_batch in range(n_batch): |
|||
verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) |
|||
v_list.append(verts_nx3) |
|||
f_list.append(faces_fx3) |
|||
render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) |
|||
all_render_output.append(render_output) |
|||
|
|||
# Collect the render outputs for each key
|||
return_keys = all_render_output[0].keys() |
|||
return_value = dict() |
|||
for k in return_keys: |
|||
value = [v[k] for v in all_render_output] |
|||
return_value[k] = value |
|||
# We can do concatenation outside of the render |
|||
return return_value |
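# Editor's sketch (not part of the original code): extracting a sphere mesh with
# DMTetGeometry. It assumes the tet grid file 'data/tets/64_compress.npz' is
# available, as required by the constructor above.
def _example_dmtet_sphere():
    geo = DMTetGeometry(grid_res=64, scale=2.0, device='cuda')
    sdf = 0.5 - geo.verts.norm(dim=-1)  # positive inside a radius-0.5 sphere ("occupied")
    verts, faces = geo.get_mesh(geo.verts, sdf)
    return verts, faces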
@ -0,0 +1,20 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
|
|||
|
|||
def get_center_boundary_index(verts): |
|||
length_ = torch.sum(verts ** 2, dim=-1) |
|||
center_idx = torch.argmin(length_) |
|||
boundary_neg = verts == verts.max() |
|||
boundary_pos = verts == verts.min() |
|||
boundary = torch.bitwise_or(boundary_pos, boundary_neg) |
|||
boundary = torch.sum(boundary.float(), dim=-1) |
|||
boundary_idx = torch.nonzero(boundary) |
|||
return center_idx, boundary_idx.squeeze(dim=-1) |
@ -0,0 +1,40 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import xatlas |
|||
import numpy as np |
|||
import nvdiffrast.torch as dr |
|||
|
|||
|
|||
# ============================================================================================== |
|||
def interpolate(attr, rast, attr_idx, rast_db=None): |
|||
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') |
|||
|
|||
|
|||
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): |
|||
vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) |
|||
|
|||
# Convert to tensors |
|||
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) |
|||
|
|||
uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) |
|||
mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) |
|||
# remap the UVs from [0, 1] to [-1, 1] clip space for rasterization
|||
uv_clip = uvs[None, ...] * 2.0 - 1.0 |
|||
|
|||
# pad to four component coordinate |
|||
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) |
|||
|
|||
# rasterize |
|||
rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) |
|||
|
|||
# Interpolate world space position |
|||
gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) |
|||
mask = rast[..., 3:4] > 0 |
|||
return uvs, mesh_tex_idx, gb_pos, mask |
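# Editor's sketch (not part of the original code): unwrapping a mesh with xatlas and
# baking world-space positions into a texture atlas. Creating the rasterization
# context assumes a CUDA device is present.
def _example_xatlas_uvmap(mesh_v, mesh_pos_idx, resolution=1024):
    ctx = dr.RasterizeCudaContext()
    return xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution)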
@ -0,0 +1,579 @@ |
|||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
import torch |
|||
from .tables import * |
|||
|
|||
__all__ = [ |
|||
'FlexiCubes' |
|||
] |
|||
|
|||
|
|||
class FlexiCubes: |
|||
""" |
|||
This class implements the FlexiCubes method for extracting meshes from scalar fields. |
|||
It maintains a series of lookup tables and indices to support the mesh extraction process. |
|||
FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances |
|||
the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting |
|||
the surface representation through gradient-based optimization. |
|||
|
|||
During instantiation, the class loads DMC tables from a file and transforms them into |
|||
PyTorch tensors on the specified device. |
|||
|
|||
Attributes: |
|||
device (str): Specifies the computational device (default is "cuda"). |
|||
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges |
|||
associated with each dual vertex in 256 Marching Cubes (MC) configurations. |
|||
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of |
|||
the 256 MC configurations. |
|||
check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 |
|||
of the DMC configurations. |
|||
tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. |
|||
quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles |
|||
along one diagonal. |
|||
quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into |
|||
two triangles along the other diagonal. |
|||
quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles |
|||
during training by connecting all edges to their midpoints. |
|||
cube_corners (torch.Tensor): Defines the positions of a standard unit cube's |
|||
eight corners in 3D space, ordered starting from the origin (0,0,0), |
|||
moving along the x-axis, then y-axis, and finally z-axis. |
|||
Used as a blueprint for generating a voxel grid. |
|||
cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used |
|||
to retrieve the case id. |
|||
cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. |
|||
Used to retrieve edge vertices in DMC. |
|||
edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with |
|||
their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the |
|||
first edge is oriented along the x-axis. |
|||
dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges |
|||
across four adjacent cubes to the shared faces of these cubes. For instance, |
|||
dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along |
|||
the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. |
|||
This tensor is only utilized during isosurface tetrahedralization. |
|||
adj_pairs (torch.Tensor): |
|||
A tensor containing index pairs that correspond to neighboring cubes that share the same edge. |
|||
qef_reg_scale (float): |
|||
The scaling factor applied to the regularization loss to prevent issues with singularity |
|||
when solving the QEF. This parameter is only used when a 'grad_func' is specified. |
|||
weight_scale (float): |
|||
The scale of weights in FlexiCubes. Should be between 0 and 1. |
|||
""" |
|||
|
|||
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): |
|||
|
|||
self.device = device |
|||
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) |
|||
self.num_vd_table = torch.tensor(num_vd_table, dtype=torch.long, device=device, requires_grad=False)
self.check_table = torch.tensor(check_table, dtype=torch.long, device=device, requires_grad=False)
|||
|
|||
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) |
|||
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) |
|||
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) |
|||
self.quad_split_train = torch.tensor( |
|||
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) |
|||
|
|||
self.cube_corners = torch.tensor(
    [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
     [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]],
    dtype=torch.float, device=device)
|||
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) |
|||
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, |
|||
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) |
|||
|
|||
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], |
|||
dtype=torch.long, device=device) |
|||
self.dir_faces_table = torch.tensor([ |
|||
[[5, 4], [3, 2], [4, 5], [2, 3]], |
|||
[[5, 4], [1, 0], [4, 5], [0, 1]], |
|||
[[3, 2], [1, 0], [2, 3], [0, 1]] |
|||
], dtype=torch.long, device=device) |
|||
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) |
|||
self.qef_reg_scale = qef_reg_scale |
|||
self.weight_scale = weight_scale |
|||
|
|||
def construct_voxel_grid(self, res): |
|||
""" |
|||
Generates a voxel grid based on the specified resolution. |
|||
|
|||
Args: |
|||
res (int or list[int]): The resolution of the voxel grid. If an integer |
|||
is provided, it is used for all three dimensions. If a list or tuple |
|||
of 3 integers is provided, they define the resolution for the x, |
|||
y, and z dimensions respectively. |
|||
|
|||
Returns: |
|||
(torch.Tensor, torch.Tensor): Returns the vertices and the indices of the |
|||
cube corners (index into vertices) of the constructed voxel grid. |
|||
The vertices are centered at the origin, with the length of each |
|||
dimension in the grid being one. |
|||
""" |
|||
base_cube_f = torch.arange(8).to(self.device) |
|||
if isinstance(res, int): |
|||
res = (res, res, res) |
|||
voxel_grid_template = torch.ones(res, device=self.device) |
|||
|
|||
res = torch.tensor([res], dtype=torch.float, device=self.device) |
|||
coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 |
|||
verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) |
|||
cubes = (base_cube_f.unsqueeze(0) + |
|||
torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) |
|||
|
|||
verts_rounded = torch.round(verts * 10**5) / (10**5) |
|||
verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) |
|||
cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) |
|||
|
|||
return verts_unique - 0.5, cubes |
|||
|
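# Editor's sketch (comment only, not part of the original code): a typical
# extraction loop with FlexiCubes; `my_sdf` is a hypothetical user-provided field.
#   fc = FlexiCubes(device='cuda')
#   x_nx3, cube_fx8 = fc.construct_voxel_grid(res=64)
#   s_n = my_sdf(x_nx3)
#   verts, faces, L_dev = fc(x_nx3, s_n, cube_fx8, res=64)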
|||
def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, |
|||
gamma_f=None, training=False, output_tetmesh=False, grad_func=None): |
|||
r""" |
|||
Main function for mesh extraction from scalar field using FlexiCubes. This function converts |
|||
discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, |
|||
to triangle or tetrahedral meshes using a differentiable operation as described in |
|||
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances |
|||
mesh quality and geometric fidelity by adjusting the surface representation based on gradient |
|||
optimization. The output surface is differentiable with respect to the input vertex positions, |
|||
scalar field values, and weight parameters. |
|||
|
|||
If you intend to extract a surface mesh from a fixed Signed Distance Field without the |
|||
optimization of parameters, it is suggested to provide the "grad_func" which should |
|||
return the surface gradient at any given 3D position. When grad_func is provided, the process |
|||
to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as |
|||
described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. |
|||
Please note, this approach is non-differentiable. |
|||
|
|||
For more details and example usage in optimization, refer to the |
|||
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. |
|||
|
|||
Args: |
|||
x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. |
|||
s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values |
|||
denote that the corresponding vertex resides inside the isosurface. This affects |
|||
the directions of the extracted triangle faces and volume to be tetrahedralized. |
|||
cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. |
|||
res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it |
|||
is used for all three dimensions. If a list or tuple of 3 integers is provided, they |
|||
specify the resolution for the x, y, and z dimensions respectively. |
|||
beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual |
|||
vertices positioning. Defaults to uniform value for all edges. |
|||
alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual |
|||
vertices positioning. Defaults to uniform value for all vertices. |
|||
gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of |
|||
quadrilaterals into triangles. Defaults to uniform value for all cubes. |
|||
training (bool, optional): If set to True, applies differentiable quad splitting for |
|||
training. Defaults to False. |
|||
output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, |
|||
outputs a triangular mesh. Defaults to False. |
|||
grad_func (callable, optional): A function to compute the surface gradient at specified |
|||
3D positions (input: Nx3 positions). The function should return gradients as an Nx3 |
|||
tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. |
|||
|
|||
Returns: |
|||
(torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: |
|||
- Vertices for the extracted triangular/tetrahedral mesh. |
|||
- Faces for the extracted triangular/tetrahedral mesh. |
|||
- Regularizer L_dev, computed per dual vertex. |
|||
|
|||
.. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: |
|||
https://research.nvidia.com/labs/toronto-ai/flexicubes/ |
|||
.. _Manifold Dual Contouring: |
|||
https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf |
|||
""" |
|||
|
|||
surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) |
|||
if surf_cubes.sum() == 0:
    # No cube straddles the isosurface: return empty tensors
    empty_verts = torch.zeros((0, 3), device=self.device)
    empty_faces = (torch.zeros((0, 4), dtype=torch.long, device=self.device)
                   if output_tetmesh else
                   torch.zeros((0, 3), dtype=torch.long, device=self.device))
    empty_reg = torch.zeros((0,), device=self.device)
    return empty_verts, empty_faces, empty_reg
|||
beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) |
|||
|
|||
case_ids = self._get_case_id(occ_fx8, surf_cubes, res) |
|||
|
|||
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) |
|||
|
|||
vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( |
|||
x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) |
|||
vertices, faces, s_edges, edge_indices = self._triangulate( |
|||
s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) |
|||
if not output_tetmesh: |
|||
return vertices, faces, L_dev |
|||
else: |
|||
vertices, tets = self._tetrahedralize( |
|||
x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, |
|||
surf_cubes, training) |
|||
return vertices, tets, L_dev |
|||
|
|||
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): |
|||
""" |
|||
Regularizer L_dev as in Equation 8 |
|||
""" |
|||
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) |
|||
mean_l2 = torch.zeros_like(vd[:, 0]) |
|||
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() |
|||
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() |
|||
return mad |
|||
|
|||
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): |
|||
""" |
|||
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. |
|||
""" |
|||
n_cubes = surf_cubes.shape[0] |
|||
|
|||
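        # Range note (assuming self.weight_scale = s in (0, 1), e.g. 0.5 as used in
        # FlexiCubesGeometry): tanh(w) * s + 1 maps beta/alpha into (1 - s, 1 + s),
        # keeping them positive; sigmoid(w) * s + (1 - s) / 2 maps gamma into
        # ((1 - s) / 2, (1 + s) / 2).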
if beta_fx12 is not None: |
|||
beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) |
|||
else: |
|||
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) |
|||
|
|||
if alpha_fx8 is not None: |
|||
alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) |
|||
else: |
|||
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) |
|||
|
|||
if gamma_f is not None: |
|||
gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 |
|||
else: |
|||
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) |
|||
|
|||
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] |
|||
|
|||
@torch.no_grad() |
|||
def _get_case_id(self, occ_fx8, surf_cubes, res): |
|||
""" |
|||
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the |
|||
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the |
|||
supplementary material. It should be noted that this function assumes a regular grid. |
|||
""" |
|||
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) |
|||
|
|||
problem_config = self.check_table.to(self.device)[case_ids] |
|||
to_check = problem_config[..., 0] == 1 |
|||
problem_config = problem_config[to_check] |
|||
if not isinstance(res, (list, tuple)): |
|||
res = [res, res, res] |
|||
|
|||
        # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with a default config for
        # non-surface cubes). This allows efficient checks on adjacent cubes.
|||
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) |
|||
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 |
|||
vol_idx_problem = vol_idx[surf_cubes][to_check] |
|||
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config |
|||
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] |
|||
|
|||
        within_range = (
            (vol_idx_problem_adj[..., 0] >= 0) & (vol_idx_problem_adj[..., 0] < res[0]) &
            (vol_idx_problem_adj[..., 1] >= 0) & (vol_idx_problem_adj[..., 1] < res[1]) &
            (vol_idx_problem_adj[..., 2] >= 0) & (vol_idx_problem_adj[..., 2] < res[2]))
|||
|
|||
vol_idx_problem = vol_idx_problem[within_range] |
|||
vol_idx_problem_adj = vol_idx_problem_adj[within_range] |
|||
problem_config = problem_config[within_range] |
|||
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], |
|||
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] |
|||
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. |
|||
to_invert = (problem_config_adj[..., 0] == 1) |
|||
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] |
|||
case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) |
|||
return case_ids |
|||
|
|||
@torch.no_grad() |
|||
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): |
|||
""" |
|||
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge |
|||
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge |
|||
and marks the cube edges with this index. |
|||
""" |
|||
occ_n = s_n < 0 |
|||
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) |
|||
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) |
|||
|
|||
unique_edges = unique_edges.long() |
|||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 |
|||
|
|||
surf_edges_mask = mask_edges[_idx_map] |
|||
counts = counts[_idx_map] |
|||
|
|||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 |
|||
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) |
|||
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index |
|||
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. |
|||
idx_map = mapping[_idx_map] |
|||
surf_edges = unique_edges[mask_edges] |
|||
return surf_edges, idx_map, counts, surf_edges_mask |
|||
|
|||
@torch.no_grad() |
|||
def _identify_surf_cubes(self, s_n, cube_fx8): |
|||
""" |
|||
Identifies grid cubes that intersect with the underlying surface by checking if the signs at |
|||
all corners are not identical. |
|||
""" |
|||
occ_n = s_n < 0 |
|||
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) |
|||
_occ_sum = torch.sum(occ_fx8, -1) |
|||
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) |
|||
return surf_cubes, occ_fx8 |
|||
|
|||
def _linear_interp(self, edges_weight, edges_x): |
|||
""" |
|||
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. |
|||
""" |
|||
edge_dim = edges_weight.dim() - 2 |
|||
assert edges_weight.shape[edge_dim] == 2 |
|||
        edges_weight = torch.cat([
            torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
            -torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim),
        ], edge_dim)
|||
denominator = edges_weight.sum(edge_dim) |
|||
ue = (edges_x * edges_weight).sum(edge_dim) / denominator |
|||
return ue |
|||
|
|||
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): |
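        """
        Solves a Tikhonov-regularized least-squares QEF per dual vertex: minimizes
        sum_i (n_i . v - n_i . p_i)^2 over the (up to 7) edge crossings, with an extra
        self.qef_reg_scale term pulling v towards the centroid c_bx3.
        """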
|||
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) |
|||
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) |
|||
c_bx3 = c_bx3.reshape(-1, 3) |
|||
A = norm_bxnx3 |
|||
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) |
|||
|
|||
A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) |
|||
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) |
|||
A = torch.cat([A, A_reg], 1) |
|||
B = torch.cat([B, B_reg], 1) |
|||
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) |
|||
return dual_verts |
|||
|
|||
def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): |
|||
""" |
|||
Computes the location of dual vertices as described in Section 4.2 |
|||
""" |
|||
alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) |
|||
surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) |
|||
surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) |
|||
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) |
|||
|
|||
idx_map = idx_map.reshape(-1, 12) |
|||
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) |
|||
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] |
|||
|
|||
total_num_vd = 0 |
|||
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) |
|||
if grad_func is not None: |
|||
normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) |
|||
vd = [] |
|||
for num in torch.unique(num_vd): |
|||
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) |
|||
curr_num_vd = cur_cubes.sum() * num |
|||
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) |
|||
curr_edge_group_to_vd = torch.arange( |
|||
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd |
|||
total_num_vd += curr_num_vd |
|||
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ |
|||
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) |
|||
|
|||
curr_mask = (curr_edge_group != -1) |
|||
edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) |
|||
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) |
|||
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) |
|||
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) |
|||
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) |
|||
|
|||
if grad_func is not None: |
|||
with torch.no_grad(): |
|||
cube_e_verts_idx = idx_map[cur_cubes] |
|||
curr_edge_group[~curr_mask] = 0 |
|||
|
|||
verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) |
|||
verts_group_idx[verts_group_idx == -1] = 0 |
|||
verts_group_pos = torch.index_select( |
|||
input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) |
|||
v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) |
|||
curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) |
|||
verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) |
|||
|
|||
normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( |
|||
-1, num.item(), 7, |
|||
3) |
|||
curr_mask = curr_mask.squeeze(2) |
|||
vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, |
|||
verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) |
|||
edge_group = torch.cat(edge_group) |
|||
edge_group_to_vd = torch.cat(edge_group_to_vd) |
|||
edge_group_to_cube = torch.cat(edge_group_to_cube) |
|||
vd_num_edges = torch.cat(vd_num_edges) |
|||
vd_gamma = torch.cat(vd_gamma) |
|||
|
|||
if grad_func is not None: |
|||
vd = torch.cat(vd) |
|||
L_dev = torch.zeros([1], device=self.device) |
|||
else: |
|||
vd = torch.zeros((total_num_vd, 3), device=self.device) |
|||
beta_sum = torch.zeros((total_num_vd, 1), device=self.device) |
|||
|
|||
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) |
|||
|
|||
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) |
|||
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) |
|||
|
|||
zero_crossing_group = torch.index_select( |
|||
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) |
|||
|
|||
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, |
|||
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) |
|||
ue_group = self._linear_interp(s_group * alpha_group, x_group) |
|||
|
|||
beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, |
|||
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) |
|||
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) |
|||
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum |
|||
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) |
|||
|
|||
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd |
|||
|
|||
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * |
|||
12 + edge_group, src=v_idx[edge_group_to_vd]) |
|||
|
|||
return vd, L_dev, vd_gamma, vd_idx_map |
|||
|
|||
def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): |
|||
""" |
|||
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into |
|||
triangles based on the gamma parameter, as described in Section 4.3. |
|||
""" |
|||
with torch.no_grad(): |
|||
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. |
|||
group = idx_map.reshape(-1)[group_mask] |
|||
vd_idx = vd_idx_map[group_mask] |
|||
edge_indices, indices = torch.sort(group, stable=True) |
|||
quad_vd_idx = vd_idx[indices].reshape(-1, 4) |
|||
|
|||
# Ensure all face directions point towards the positive SDF to maintain consistent winding. |
|||
s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) |
|||
flip_mask = s_edges[:, 0] > 0 |
|||
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], |
|||
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) |
|||
if grad_func is not None: |
|||
# when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. |
|||
with torch.no_grad(): |
|||
vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) |
|||
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) |
|||
gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) |
|||
gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) |
|||
else: |
|||
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) |
|||
gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( |
|||
0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) |
|||
gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( |
|||
1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) |
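        # Inference picks the diagonal with the larger gamma product; training instead
        # blends both diagonals through an extra, differentiably weighted center vertex.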
|||
if not training: |
|||
mask = (gamma_02 > gamma_13).squeeze(1) |
|||
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) |
|||
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] |
|||
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] |
|||
faces = faces.reshape(-1, 3) |
|||
else: |
|||
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) |
|||
vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + |
|||
torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 |
|||
vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + |
|||
torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 |
|||
weight_sum = (gamma_02 + gamma_13) + 1e-8 |
|||
vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / |
|||
weight_sum.unsqueeze(-1)).squeeze(1) |
|||
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] |
|||
vd = torch.cat([vd, vd_center]) |
|||
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) |
|||
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) |
|||
return vd, faces, s_edges, edge_indices |
|||
|
|||
def _tetrahedralize( |
|||
self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, |
|||
surf_cubes, training): |
|||
""" |
|||
Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. |
|||
""" |
|||
occ_n = s_n < 0 |
|||
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) |
|||
occ_sum = torch.sum(occ_fx8, -1) |
|||
|
|||
inside_verts = x_nx3[occ_n] |
|||
mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 |
|||
mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] |
|||
""" |
|||
For each grid edge connecting two grid vertices with different |
|||
signs, we first form a four-sided pyramid by connecting one |
|||
of the grid vertices with four mesh vertices that correspond |
|||
to the grid edge and then subdivide the pyramid into two tetrahedra |
|||
""" |
|||
inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ |
|||
s_edges < 0]] |
|||
if not training: |
|||
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) |
|||
else: |
|||
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) |
|||
|
|||
tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) |
|||
""" |
|||
For each grid edge connecting two grid vertices with the |
|||
same sign, the tetrahedron is formed by the two grid vertices |
|||
and two vertices in consecutive adjacent cells |
|||
""" |
|||
inside_cubes = (occ_sum == 8) |
|||
inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) |
|||
inside_cubes_center_idx = torch.arange( |
|||
inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] |
|||
|
|||
surface_n_inside_cubes = surf_cubes | inside_cubes |
|||
edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), |
|||
dtype=torch.long, device=x_nx3.device) * -1 |
|||
surf_cubes = surf_cubes[surface_n_inside_cubes] |
|||
inside_cubes = inside_cubes[surface_n_inside_cubes] |
|||
edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) |
|||
edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx |
|||
|
|||
all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) |
|||
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) |
|||
unique_edges = unique_edges.long() |
|||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 |
|||
mask = mask_edges[_idx_map] |
|||
counts = counts[_idx_map] |
|||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 |
|||
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) |
|||
idx_map = mapping[_idx_map] |
|||
|
|||
group_mask = (counts == 4) & mask |
|||
group = idx_map.reshape(-1)[group_mask] |
|||
edge_indices, indices = torch.sort(group) |
|||
cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, |
|||
device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] |
|||
edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( |
|||
0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] |
|||
# Identify the face shared by the adjacent cells. |
|||
cube_idx_4 = cube_idx[indices].reshape(-1, 4) |
|||
edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] |
|||
shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) |
|||
cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) |
|||
# Identify an edge of the face with different signs and |
|||
# select the mesh vertex corresponding to the identified edge. |
|||
case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 |
|||
case_ids_expand[surf_cubes] = case_ids |
|||
cases = case_ids_expand[cube_idx_4x2] |
|||
quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) |
|||
mask = (quad_edge == -1).sum(-1) == 0 |
|||
inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) |
|||
tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] |
|||
|
|||
tets = torch.cat([tets_surface, tets_inside]) |
|||
vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) |
|||
return vertices, tets |
@ -0,0 +1,120 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import numpy as np |
|||
import os |
|||
from . import Geometry |
|||
from .flexicubes import FlexiCubes # replace later |
|||
from .dmtet import sdf_reg_loss_batch |
|||
import torch.nn.functional as F |
|||
|
|||
def get_center_boundary_index(grid_res, device): |
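    """
    Returns the flat index of the vertex nearest the grid center, plus the flat indices
    of all vertices in the two outermost layers of the (grid_res + 1)^3 vertex grid;
    these are used below when fixing the boundary SDF.
    """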
|||
v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) |
|||
v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True |
|||
center_indices = torch.nonzero(v.reshape(-1)) |
|||
|
|||
v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False |
|||
v[:2, ...] = True |
|||
v[-2:, ...] = True |
|||
v[:, :2, ...] = True |
|||
v[:, -2:, ...] = True |
|||
v[:, :, :2] = True |
|||
v[:, :, -2:] = True |
|||
boundary_indices = torch.nonzero(v.reshape(-1)) |
|||
return center_indices, boundary_indices |
|||
|
|||
############################################################################### |
|||
# Geometry interface |
|||
############################################################################### |
|||
class FlexiCubesGeometry(Geometry): |
|||
def __init__( |
|||
self, grid_res=64, scale=2.0, device='cuda', renderer=None, |
|||
render_type='neural_render', args=None): |
|||
super(FlexiCubesGeometry, self).__init__() |
|||
self.grid_res = grid_res |
|||
self.device = device |
|||
self.args = args |
|||
self.fc = FlexiCubes(device, weight_scale=0.5) |
|||
self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) |
|||
if isinstance(scale, list): |
|||
self.verts[:, 0] = self.verts[:, 0] * scale[0] |
|||
self.verts[:, 1] = self.verts[:, 1] * scale[1] |
|||
            self.verts[:, 2] = self.verts[:, 2] * scale[2]
|||
else: |
|||
self.verts = self.verts * scale |
|||
|
|||
all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) |
|||
self.all_edges = torch.unique(all_edges, dim=0) |
|||
|
|||
        # Parameters used to fix the boundary SDF
|||
self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) |
|||
self.renderer = renderer |
|||
self.render_type = render_type |
|||
|
|||
def getAABB(self): |
|||
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values |
|||
|
|||
def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): |
|||
if indices is None: |
|||
indices = self.indices |
|||
|
|||
        if weight_n is None:
            # fc() falls back to uniform weights when beta/alpha/gamma are omitted
            # (per the optional-argument contract in the FlexiCubes docstring).
            verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
                                               training=is_training)
        else:
            verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
                                               beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
                                               gamma_f=weight_n[:, 20], training=is_training)
        return verts, faces, v_reg_loss
|||
|
|||
|
|||
def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): |
|||
return_value = dict() |
|||
if self.render_type == 'neural_render': |
|||
tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( |
|||
mesh_v_nx3.unsqueeze(dim=0), |
|||
mesh_f_fx3.int(), |
|||
camera_mv_bx4x4, |
|||
mesh_v_nx3.unsqueeze(dim=0), |
|||
resolution=resolution, |
|||
device=self.device, |
|||
hierarchical_mask=hierarchical_mask |
|||
) |
|||
|
|||
return_value['tex_pos'] = tex_pos |
|||
return_value['mask'] = mask |
|||
return_value['hard_mask'] = hard_mask |
|||
return_value['rast'] = rast |
|||
return_value['v_pos_clip'] = v_pos_clip |
|||
return_value['mask_pyramid'] = mask_pyramid |
|||
return_value['depth'] = depth |
|||
return_value['normal'] = normal |
|||
else: |
|||
raise NotImplementedError |
|||
|
|||
return return_value |
|||
|
|||
def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): |
|||
        # Assumes a batch of meshes (each element may have a different mesh and geometry);
        # for the other shapes, the batch size is 1.
|||
v_list = [] |
|||
f_list = [] |
|||
n_batch = v_deformed_bxnx3.shape[0] |
|||
all_render_output = [] |
|||
for i_batch in range(n_batch): |
|||
            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
|||
v_list.append(verts_nx3) |
|||
f_list.append(faces_fx3) |
|||
render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) |
|||
all_render_output.append(render_output) |
|||
|
|||
# Concatenate all render output |
|||
return_keys = all_render_output[0].keys() |
|||
return_value = dict() |
|||
for k in return_keys: |
|||
value = [v[k] for v in all_render_output] |
|||
return_value[k] = value |
|||
# We can do concatenation outside of the render |
|||
return return_value |
@ -0,0 +1,791 @@ |
|||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
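
# Dual Marching Cubes connectivity table (layout inferred from its use in flexicubes.py):
# one row per 256 corner-occupancy cases; each row lists up to 4 dual vertices, each
# given by up to 7 intersected cube-edge ids and padded with -1.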
|||
dmc_table = [ |
|||
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], |
|||
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] |
|||
] |
|||
num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, |
|||
2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, |
|||
1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, |
|||
1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, |
|||
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, |
|||
3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, |
|||
2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, |
|||
1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, |
|||
1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] |
|||
check_table = [ |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 1, 0, 0, 194], |
|||
[1, -1, 0, 0, 193], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 1, 0, 164], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, -1, 0, 161], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, 1, 152], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, 1, 145], |
|||
[1, 0, 0, 1, 144], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, -1, 137], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 1, 0, 133], |
|||
[1, 0, 1, 0, 132], |
|||
[1, 1, 0, 0, 131], |
|||
[1, 1, 0, 0, 130], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, 1, 100], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, 1, 98], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, 1, 96], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 1, 0, 88], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, -1, 0, 82], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 1, 0, 74], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 1, 0, 72], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, -1, 70], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, -1, 0, 0, 67], |
|||
[0, 0, 0, 0, 0], |
|||
[1, -1, 0, 0, 65], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 1, 0, 0, 56], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, -1, 0, 0, 52], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 1, 0, 0, 44], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 1, 0, 0, 40], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, -1, 38], |
|||
[1, 0, -1, 0, 37], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, -1, 0, 33], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, -1, 0, 0, 28], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, -1, 0, 26], |
|||
[1, 0, 0, -1, 25], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, -1, 0, 0, 20], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, -1, 0, 18], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, -1, 9], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[1, 0, 0, -1, 6], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0] |
|||
] |
|||
tet_table = [ |
|||
[-1, -1, -1, -1, -1, -1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 0, 0, 4, 4, -1], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[0, 4, 0, 4, 4, -1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, -1, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, -1, 2, 4, 4, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, 4, 4, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 4, 2, 4, 4, 2], |
|||
[0, 4, 0, 4, 4, 0], |
|||
[2, 0, 2, 0, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 5, 2, 5, 5, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, 0, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 1, 1, -1, 0, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[4, 1, 1, 4, 4, 1], |
|||
[0, 1, 1, 0, 0, 1], |
|||
[4, 0, 0, 4, 4, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[-1, 1, 1, 4, 4, 1], |
|||
[0, 1, 1, 4, 4, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[5, 1, 1, 5, 5, 1], |
|||
[0, 1, 1, 0, 0, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[8, 8, 8, 8, 8, 8], |
|||
[1, 1, 1, 4, 4, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 0, 0, 4, 4, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 4, 4, 1], |
|||
[0, 4, 0, 4, 4, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 5, 5, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[6, 6, 6, 6, 6, 6], |
|||
[6, -1, 0, 6, 0, 6], |
|||
[6, 0, 0, 6, 0, 6], |
|||
[6, 1, 1, 6, 1, 6], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 0, 0, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[6, 4, -1, 6, 4, 6], |
|||
[6, 4, 0, 6, 4, 6], |
|||
[6, 0, 0, 6, 0, 6], |
|||
[6, 1, 1, 6, 1, 6], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, 2, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 4, 2, 2, 4, 2], |
|||
[0, 4, 0, 4, 4, 0], |
|||
[2, 0, 2, 2, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[6, 1, 1, 6, -1, 6], |
|||
[6, 1, 1, 6, 0, 6], |
|||
[6, 0, 0, 6, 0, 6], |
|||
[6, 2, 2, 6, 2, 6], |
|||
[4, 1, 1, 4, 4, 1], |
|||
[0, 1, 1, 0, 0, 1], |
|||
[4, 0, 0, 4, 4, 4], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[6, 1, 1, 6, 4, 6], |
|||
[6, 1, 1, 6, 4, 6], |
|||
[6, 0, 0, 6, 0, 6], |
|||
[6, 2, 2, 6, 2, 6], |
|||
[5, 1, 1, 5, 5, 1], |
|||
[0, 1, 1, 0, 0, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[6, 6, 6, 6, 6, 6], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 4, 1], |
|||
[0, 4, 0, 4, 4, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 5, 0, 5, 0, 5], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[0, 5, 0, 5, 0, 5], |
|||
[-1, 5, 0, 5, 0, 5], |
|||
[1, 5, 1, 5, 1, 5], |
|||
[4, 5, -1, 5, 4, 5], |
|||
[0, 5, 0, 5, 0, 5], |
|||
[4, 5, 0, 5, 4, 5], |
|||
[1, 5, 1, 5, 1, 5], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[0, 4, 0, 4, 4, 4], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[6, 6, 6, 6, 6, 6], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 5, 2, 5, -1, 5], |
|||
[0, 5, 0, 5, 0, 5], |
|||
[2, 5, 2, 5, 0, 5], |
|||
[1, 5, 1, 5, 1, 5], |
|||
[2, 5, 2, 5, 4, 5], |
|||
[0, 5, 0, 5, 0, 5], |
|||
[2, 5, 2, 5, 4, 5], |
|||
[1, 5, 1, 5, 1, 5], |
|||
[2, 4, 2, 4, 4, 2], |
|||
[0, 4, 0, 4, 4, 4], |
|||
[2, 0, 2, 0, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 6, 2, 6, 6, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 0, 2, 0, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 1, 1, 1, 0, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[4, 1, 1, 1, 4, 1], |
|||
[0, 1, 1, 1, 0, 1], |
|||
[4, 0, 0, 4, 4, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[1, 1, 1, 1, 4, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 0, 0, 4, 4, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[6, 0, 0, 6, 0, 6], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[6, 6, 6, 6, 6, 6], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[5, 5, 0, 5, 0, 5], |
|||
[5, 5, 0, 5, 0, 5], |
|||
[5, 5, 1, 5, 1, 5], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 0, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[4, 4, 0, 4, 4, 4], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[8, 8, 8, 8, 8, 8], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 0, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 1, 1, 4, 4, 1], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[1, 1, 1, 1, 0, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 4, 2, 4, 4, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[2, 2, 2, 2, 2, 2], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[5, 5, 5, 5, 5, 5], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[4, 4, 4, 4, 4, 4], |
|||
[1, 1, 1, 1, 1, 1], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[0, 0, 0, 0, 0, 0], |
|||
[12, 12, 12, 12, 12, 12] |
|||
] |
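For orientation: these are FlexiCubes-style lookup tables, conventionally indexed by an 8-bit case id encoding the inside/outside signs of a grid cell's eight corners. A minimal sketch of that indexing, assuming the standard marching-cubes bit convention (the helper below is illustrative, not part of this file):

import torch

def cube_case_index(corner_sdf: torch.Tensor) -> torch.Tensor:
    # corner_sdf: (N, 8) SDF values at the eight corners of N grid cells
    occupancy = (corner_sdf < 0).long()             # 1 = corner inside the surface
    bit_weights = 2 ** torch.arange(8, device=corner_sdf.device)
    return (occupancy * bit_weights).sum(dim=-1)    # case id in [0, 255]

# e.g. dual vertices emitted per cell, under this assumed convention:
# n_vd = torch.tensor(num_vd_table)[cube_case_index(corner_sdf)]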
@ -0,0 +1,209 @@ |
|||
# Copyright (c) 2023, Zexin He |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# https://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import numpy as np |
|||
import torch |
|||
import torch.nn as nn |
|||
import mcubes |
|||
import nvdiffrast.torch as dr |
|||
from einops import rearrange, repeat |
|||
|
|||
from .encoder.dino_wrapper import DinoWrapper |
|||
from .decoder.transformer import TriplaneTransformer |
|||
from .renderer.synthesizer import TriplaneSynthesizer |
|||
from ..utils.mesh_util import xatlas_uvmap |
|||
|
|||
|
|||
class InstantNeRF(nn.Module): |
|||
""" |
|||
Full large reconstruction model: predicts a triplane NeRF from posed input images.
|||
""" |
|||
def __init__( |
|||
self, |
|||
encoder_freeze: bool = False, |
|||
encoder_model_name: str = 'facebook/dino-vitb16', |
|||
encoder_feat_dim: int = 768, |
|||
transformer_dim: int = 1024, |
|||
transformer_layers: int = 16, |
|||
transformer_heads: int = 16, |
|||
triplane_low_res: int = 32, |
|||
triplane_high_res: int = 64, |
|||
triplane_dim: int = 80, |
|||
rendering_samples_per_ray: int = 128, |
|||
): |
|||
super().__init__() |
|||
|
|||
# modules |
|||
self.encoder = DinoWrapper( |
|||
model_name=encoder_model_name, |
|||
freeze=encoder_freeze, |
|||
) |
|||
|
|||
self.transformer = TriplaneTransformer( |
|||
inner_dim=transformer_dim, |
|||
num_layers=transformer_layers, |
|||
num_heads=transformer_heads, |
|||
image_feat_dim=encoder_feat_dim, |
|||
triplane_low_res=triplane_low_res, |
|||
triplane_high_res=triplane_high_res, |
|||
triplane_dim=triplane_dim, |
|||
) |
|||
|
|||
self.synthesizer = TriplaneSynthesizer( |
|||
triplane_dim=triplane_dim, |
|||
samples_per_ray=rendering_samples_per_ray, |
|||
) |
|||
|
|||
def forward_planes(self, images, cameras): |
|||
# images: [B, V, C_img, H_img, W_img] |
|||
# cameras: [B, V, 16] |
|||
B = images.shape[0] |
|||
|
|||
# encode images |
|||
image_feats = self.encoder(images, cameras) |
|||
image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) |
|||
|
|||
# transformer generating planes |
|||
planes = self.transformer(image_feats) |
|||
|
|||
return planes |
|||
|
|||
def forward_synthesizer(self, planes, render_cameras, render_size: int): |
|||
render_results = self.synthesizer( |
|||
planes, |
|||
render_cameras, |
|||
render_size, |
|||
) |
|||
return render_results |
|||
|
|||
def forward(self, images, cameras, render_cameras, render_size: int): |
|||
# images: [B, V, C_img, H_img, W_img] |
|||
# cameras: [B, V, 16] |
|||
# render_cameras: [B, M, D_cam_render] |
|||
# render_size: int |
|||
B, M = render_cameras.shape[:2] |
|||
|
|||
planes = self.forward_planes(images, cameras) |
|||
|
|||
# render target views |
|||
render_results = self.synthesizer(planes, render_cameras, render_size) |
|||
|
|||
return { |
|||
'planes': planes, |
|||
**render_results, |
|||
} |
|||
|
|||
def get_texture_prediction(self, planes, tex_pos, hard_mask=None): |
|||
''' |
|||
Predict texture given triplanes
:param planes: the triplane feature map
:param tex_pos: positions at which to query the texture field
:param hard_mask: 2D silhouette of the rendered image
|||
''' |
|||
tex_pos = torch.cat(tex_pos, dim=0) |
|||
if hard_mask is not None:
|||
tex_pos = tex_pos * hard_mask.float() |
|||
batch_size = tex_pos.shape[0] |
|||
tex_pos = tex_pos.reshape(batch_size, -1, 3) |
|||
################### |
|||
# We use the mask to gather only the required texture locations (to save memory)
|||
if hard_mask is not None: |
|||
n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) |
|||
sample_tex_pose_list = [] |
|||
max_point = n_point_list.max() |
|||
expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 |
|||
for i in range(tex_pos.shape[0]): |
|||
tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) |
|||
if tex_pos_one_shape.shape[1] < max_point: |
|||
tex_pos_one_shape = torch.cat( |
|||
[tex_pos_one_shape, torch.zeros( |
|||
1, max_point - tex_pos_one_shape.shape[1], 3, |
|||
device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) |
|||
sample_tex_pose_list.append(tex_pos_one_shape) |
|||
tex_pos = torch.cat(sample_tex_pose_list, dim=0) |
|||
|
|||
tex_feat = torch.utils.checkpoint.checkpoint( |
|||
self.synthesizer.forward_points, |
|||
planes, |
|||
tex_pos, |
|||
use_reentrant=False, |
|||
)['rgb'] |
|||
|
|||
if hard_mask is not None: |
|||
final_tex_feat = torch.zeros( |
|||
planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) |
|||
expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 |
|||
for i in range(planes.shape[0]): |
|||
final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) |
|||
tex_feat = final_tex_feat |
|||
|
|||
return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) |
|||
|
|||
def extract_mesh( |
|||
self, |
|||
planes: torch.Tensor, |
|||
mesh_resolution: int = 256, |
|||
mesh_threshold: float = 10.0,
|||
use_texture_map: bool = False, |
|||
texture_resolution: int = 1024, |
|||
**kwargs, |
|||
): |
|||
''' |
|||
Extract a 3D mesh from the triplane NeRF. Only supports batch_size 1.
|||
:param planes: triplane features |
|||
:param mesh_resolution: marching cubes resolution |
|||
:param mesh_threshold: iso-surface threshold |
|||
:param use_texture_map: use texture map or vertex color |
|||
:param texture_resolution: the resolution of the texture map
|||
''' |
|||
assert planes.shape[0] == 1 |
|||
device = planes.device |
|||
|
|||
grid_out = self.synthesizer.forward_grid( |
|||
planes=planes, |
|||
grid_size=mesh_resolution, |
|||
) |
|||
|
|||
vertices, faces = mcubes.marching_cubes( |
|||
grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), |
|||
mesh_threshold, |
|||
) |
|||
vertices = vertices / (mesh_resolution - 1) * 2 - 1 |
|||
|
|||
if not use_texture_map: |
|||
# query vertex colors |
|||
vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) |
|||
vertices_colors = self.synthesizer.forward_points( |
|||
planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() |
|||
vertices_colors = (vertices_colors * 255).astype(np.uint8) |
|||
|
|||
return vertices, faces, vertices_colors |
|||
|
|||
# use x-atlas to get uv mapping for the mesh |
|||
vertices = torch.tensor(vertices, dtype=torch.float32, device=device) |
|||
faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) |
|||
|
|||
ctx = dr.RasterizeCudaContext(device=device) |
|||
uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( |
|||
ctx, vertices, faces, resolution=texture_resolution) |
|||
tex_hard_mask = tex_hard_mask.float() |
|||
|
|||
# query the texture field to get the RGB color for texture map |
|||
tex_feat = self.get_texture_prediction( |
|||
planes, [gb_pos], tex_hard_mask) |
|||
background_feature = torch.zeros_like(tex_feat) |
|||
img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) |
|||
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) |
|||
|
|||
return vertices, faces, uvs, mesh_tex_idx, texture_map |
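A hedged usage sketch for InstantNeRF (dummy shapes only; the 25-dim render-camera layout, a flattened 4x4 cam2world matrix followed by a flattened 3x3 intrinsics matrix, mirrors how the triplane synthesizer slices its camera input):

if __name__ == "__main__":
    model = InstantNeRF()
    images = torch.randn(1, 6, 3, 320, 320)     # B=1, V=6 input views
    cameras = torch.randn(1, 6, 16)             # flattened 4x4 input poses
    render_cameras = torch.randn(1, 4, 25)      # M=4 target views, 16 + 9 dims each
    out = model(images, cameras, render_cameras, render_size=192)
    # out contains 'planes' plus the synthesizer's rendered images/depths/weights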
@ -0,0 +1,382 @@ |
|||
# Copyright (c) 2023, Tencent Inc |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# https://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import numpy as np |
|||
import torch |
|||
import torch.nn as nn |
|||
import nvdiffrast.torch as dr |
|||
from einops import rearrange, repeat |
|||
|
|||
from .encoder.dino_wrapper import DinoWrapper |
|||
from .decoder.transformer import TriplaneTransformer |
|||
from .renderer.synthesizer_mesh import TriplaneSynthesizer |
|||
from .geometry.camera.perspective_camera import PerspectiveCamera |
|||
from .geometry.render.neural_render import NeuralRender |
|||
from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry |
|||
from ..utils.mesh_util import xatlas_uvmap |
|||
|
|||
|
|||
class InstantMesh(nn.Module): |
|||
""" |
|||
Full large reconstruction model: predicts a textured FlexiCubes mesh from posed input images.
|||
""" |
|||
def __init__( |
|||
self, |
|||
encoder_freeze: bool = False, |
|||
encoder_model_name: str = 'facebook/dino-vitb16', |
|||
encoder_feat_dim: int = 768, |
|||
transformer_dim: int = 1024, |
|||
transformer_layers: int = 16, |
|||
transformer_heads: int = 16, |
|||
triplane_low_res: int = 32, |
|||
triplane_high_res: int = 64, |
|||
triplane_dim: int = 80, |
|||
rendering_samples_per_ray: int = 128, |
|||
grid_res: int = 128, |
|||
grid_scale: float = 2.0, |
|||
): |
|||
super().__init__() |
|||
|
|||
# attributes |
|||
self.grid_res = grid_res |
|||
self.grid_scale = grid_scale |
|||
self.deformation_multiplier = 4.0 |
|||
|
|||
# modules |
|||
self.encoder = DinoWrapper( |
|||
model_name=encoder_model_name, |
|||
freeze=encoder_freeze, |
|||
) |
|||
|
|||
self.transformer = TriplaneTransformer( |
|||
inner_dim=transformer_dim, |
|||
num_layers=transformer_layers, |
|||
num_heads=transformer_heads, |
|||
image_feat_dim=encoder_feat_dim, |
|||
triplane_low_res=triplane_low_res, |
|||
triplane_high_res=triplane_high_res, |
|||
triplane_dim=triplane_dim, |
|||
) |
|||
|
|||
self.synthesizer = TriplaneSynthesizer( |
|||
triplane_dim=triplane_dim, |
|||
samples_per_ray=rendering_samples_per_ray, |
|||
) |
|||
|
|||
def init_flexicubes_geometry(self, device, fovy=50.0): |
|||
camera = PerspectiveCamera(fovy=fovy, device=device) |
|||
renderer = NeuralRender(device, camera_model=camera) |
|||
self.geometry = FlexiCubesGeometry( |
|||
grid_res=self.grid_res, |
|||
scale=self.grid_scale, |
|||
renderer=renderer, |
|||
render_type='neural_render', |
|||
device=device, |
|||
) |
|||
|
|||
def forward_planes(self, images, cameras): |
|||
# images: [B, V, C_img, H_img, W_img] |
|||
# cameras: [B, V, 16] |
|||
B = images.shape[0] |
|||
|
|||
# encode images |
|||
image_feats = self.encoder(images, cameras) |
|||
image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) |
|||
|
|||
# decode triplanes |
|||
planes = self.transformer(image_feats) |
|||
|
|||
return planes |
|||
|
|||
def get_sdf_deformation_prediction(self, planes): |
|||
''' |
|||
Predict SDF and deformation for the FlexiCubes grid vertices
|||
:param planes: triplane feature map for the geometry |
|||
''' |
|||
init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) |
|||
|
|||
# Step 1: predict the SDF and deformation |
|||
sdf, deformation, weight = torch.utils.checkpoint.checkpoint( |
|||
self.synthesizer.get_geometry_prediction, |
|||
planes, |
|||
init_position, |
|||
self.geometry.indices, |
|||
use_reentrant=False, |
|||
) |
|||
|
|||
# Step 2: Normalize the deformation to avoid flipped triangles.
|||
deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) |
|||
sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) |
|||
|
|||
#### |
|||
# Step 3: Fix the SDF if we observe an empty shape (fully positive or fully negative)
|||
sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)) |
|||
sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) |
|||
pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) |
|||
neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) |
|||
zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) |
|||
if torch.sum(zero_surface).item() > 0: |
|||
update_sdf = torch.zeros_like(sdf[0:1]) |
|||
max_sdf = sdf.max() |
|||
min_sdf = sdf.min() |
|||
update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero |
|||
update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero |
|||
new_sdf = torch.zeros_like(sdf) |
|||
for i_batch in range(zero_surface.shape[0]): |
|||
if zero_surface[i_batch]: |
|||
new_sdf[i_batch:i_batch + 1] += update_sdf |
|||
update_mask = (new_sdf == 0).float() |
|||
# Regularization here is used to push the SDF towards mixed signs (so it is not fully positive or fully negative)
|||
sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) |
|||
sdf_reg_loss = sdf_reg_loss * zero_surface.float() |
|||
sdf = sdf * update_mask + new_sdf * (1 - update_mask) |
|||
|
|||
# Step 4: Remove the gradient for bad SDFs (fully positive or fully negative)
|||
final_sdf = [] |
|||
final_def = [] |
|||
for i_batch in range(zero_surface.shape[0]): |
|||
if zero_surface[i_batch]: |
|||
final_sdf.append(sdf[i_batch: i_batch + 1].detach()) |
|||
final_def.append(deformation[i_batch: i_batch + 1].detach()) |
|||
else: |
|||
final_sdf.append(sdf[i_batch: i_batch + 1]) |
|||
final_def.append(deformation[i_batch: i_batch + 1]) |
|||
sdf = torch.cat(final_sdf, dim=0) |
|||
deformation = torch.cat(final_def, dim=0) |
|||
return sdf, deformation, sdf_reg_loss, weight |
|||
|
|||
def get_geometry_prediction(self, planes=None): |
|||
''' |
|||
Function to generate a mesh from the given triplanes
|||
:param planes: triplane features |
|||
''' |
|||
# Step 1: first get the SDF and deformation values for each vertex of the grid.
|||
sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes) |
|||
v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation |
|||
tets = self.geometry.indices |
|||
n_batch = planes.shape[0] |
|||
v_list = [] |
|||
f_list = [] |
|||
flexicubes_surface_reg_list = [] |
|||
|
|||
# Step 2: Run FlexiCubes to obtain the mesh
|||
for i_batch in range(n_batch): |
|||
verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( |
|||
v_deformed[i_batch], |
|||
sdf[i_batch].squeeze(dim=-1), |
|||
with_uv=False, |
|||
indices=tets, |
|||
weight_n=weight[i_batch].squeeze(dim=-1), |
|||
is_training=self.training, |
|||
) |
|||
flexicubes_surface_reg_list.append(flexicubes_surface_reg) |
|||
v_list.append(verts) |
|||
f_list.append(faces) |
|||
|
|||
flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() |
|||
flexicubes_weight_reg = (weight ** 2).mean() |
|||
|
|||
return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) |
|||
|
|||
def get_texture_prediction(self, planes, tex_pos, hard_mask=None): |
|||
''' |
|||
Predict texture given triplanes
:param planes: the triplane feature map
:param tex_pos: positions at which to query the texture field
:param hard_mask: 2D silhouette of the rendered image
|||
''' |
|||
tex_pos = torch.cat(tex_pos, dim=0) |
|||
if hard_mask is not None:
|||
tex_pos = tex_pos * hard_mask.float() |
|||
batch_size = tex_pos.shape[0] |
|||
tex_pos = tex_pos.reshape(batch_size, -1, 3) |
|||
################### |
|||
# We use the mask to gather only the required texture locations (to save memory)
|||
if hard_mask is not None: |
|||
n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) |
|||
sample_tex_pose_list = [] |
|||
max_point = n_point_list.max() |
|||
expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 |
|||
for i in range(tex_pos.shape[0]): |
|||
tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) |
|||
if tex_pos_one_shape.shape[1] < max_point: |
|||
tex_pos_one_shape = torch.cat( |
|||
[tex_pos_one_shape, torch.zeros( |
|||
1, max_point - tex_pos_one_shape.shape[1], 3, |
|||
device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) |
|||
sample_tex_pose_list.append(tex_pos_one_shape) |
|||
tex_pos = torch.cat(sample_tex_pose_list, dim=0) |
|||
|
|||
tex_feat = torch.utils.checkpoint.checkpoint( |
|||
self.synthesizer.get_texture_prediction, |
|||
planes, |
|||
tex_pos, |
|||
use_reentrant=False, |
|||
) |
|||
|
|||
if hard_mask is not None: |
|||
final_tex_feat = torch.zeros( |
|||
planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) |
|||
expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 |
|||
for i in range(planes.shape[0]): |
|||
final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) |
|||
tex_feat = final_tex_feat |
|||
|
|||
return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) |
|||
|
|||
def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): |
|||
''' |
|||
Function to render a generated mesh with nvdiffrast |
|||
:param mesh_v: List of vertices for the mesh |
|||
:param mesh_f: List of faces for the mesh |
|||
:param cam_mv: batch of 4x4 camera (modelview) matrices, one per rendered view
:return: antialiased mask, hard mask, texture positions, depth, and normal maps
|||
''' |
|||
return_value_list = [] |
|||
for i_mesh in range(len(mesh_v)): |
|||
return_value = self.geometry.render_mesh( |
|||
mesh_v[i_mesh], |
|||
mesh_f[i_mesh].int(), |
|||
cam_mv[i_mesh], |
|||
resolution=render_size, |
|||
hierarchical_mask=False |
|||
) |
|||
return_value_list.append(return_value) |
|||
|
|||
return_keys = return_value_list[0].keys() |
|||
return_value = dict() |
|||
for k in return_keys: |
|||
value = [v[k] for v in return_value_list] |
|||
return_value[k] = value |
|||
|
|||
mask = torch.cat(return_value['mask'], dim=0) |
|||
hard_mask = torch.cat(return_value['hard_mask'], dim=0) |
|||
tex_pos = return_value['tex_pos'] |
|||
depth = torch.cat(return_value['depth'], dim=0) |
|||
normal = torch.cat(return_value['normal'], dim=0) |
|||
return mask, hard_mask, tex_pos, depth, normal |
|||
|
|||
def forward_geometry(self, planes, render_cameras, render_size=256): |
|||
''' |
|||
Main function of our generator. It first generates a 3D mesh, then renders it into 2D images
with the given `render_cameras`.
|||
:param planes: triplane features |
|||
:param render_cameras: cameras to render generated 3D shape |
|||
''' |
|||
B, NV = render_cameras.shape[:2] |
|||
|
|||
# Generate 3D mesh first |
|||
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) |
|||
|
|||
# Render the mesh into 2D images (and obtain the 3D position for each image pixel)
|||
cam_mv = render_cameras |
|||
run_n_view = cam_mv.shape[1] |
|||
antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size) |
|||
|
|||
tex_hard_mask = hard_mask |
|||
tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] |
|||
tex_hard_mask = torch.cat( |
|||
[torch.cat( |
|||
[tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] |
|||
for i_view in range(run_n_view)], dim=2) |
|||
for i in range(planes.shape[0])], dim=0) |
|||
|
|||
# Querying the texture field to predict the texture feature for each pixel on the image |
|||
tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) |
|||
background_feature = torch.ones_like(tex_feat) # white background |
|||
|
|||
# Merge them together |
|||
img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) |
|||
|
|||
# Split it back into the original per-view image shape
|||
img_feat = torch.cat( |
|||
[torch.cat( |
|||
[img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] |
|||
for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) |
|||
|
|||
img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) |
|||
antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) |
|||
depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive |
|||
normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) |
|||
|
|||
out = { |
|||
'img': img, |
|||
'mask': antilias_mask, |
|||
'depth': depth, |
|||
'normal': normal, |
|||
'sdf': sdf, |
|||
'mesh_v': mesh_v, |
|||
'mesh_f': mesh_f, |
|||
'sdf_reg_loss': sdf_reg_loss, |
|||
} |
|||
return out |
|||
|
|||
def forward(self, images, cameras, render_cameras, render_size: int): |
|||
# images: [B, V, C_img, H_img, W_img] |
|||
# cameras: [B, V, 16] |
|||
# render_cameras: [B, M, D_cam_render] |
|||
# render_size: int |
|||
B, M = render_cameras.shape[:2] |
|||
|
|||
planes = self.forward_planes(images, cameras) |
|||
out = self.forward_geometry(planes, render_cameras, render_size=render_size) |
|||
|
|||
return { |
|||
'planes': planes, |
|||
**out |
|||
} |
|||
|
|||
def extract_mesh( |
|||
self, |
|||
planes: torch.Tensor, |
|||
use_texture_map: bool = False, |
|||
texture_resolution: int = 1024, |
|||
**kwargs, |
|||
): |
|||
''' |
|||
Extract a 3D mesh from FlexiCubes. Only supports batch_size 1.
:param planes: triplane features
:param use_texture_map: use texture map or vertex color
:param texture_resolution: the resolution of the texture map
|||
''' |
|||
assert planes.shape[0] == 1 |
|||
device = planes.device |
|||
|
|||
# predict geometry first |
|||
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) |
|||
vertices, faces = mesh_v[0], mesh_f[0] |
|||
|
|||
if not use_texture_map: |
|||
# query vertex colors |
|||
vertices_tensor = vertices.unsqueeze(0) |
|||
vertices_colors = self.synthesizer.get_texture_prediction( |
|||
planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy() |
|||
vertices_colors = (vertices_colors * 255).astype(np.uint8) |
|||
|
|||
return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors |
|||
|
|||
# use x-atlas to get uv mapping for the mesh |
|||
ctx = dr.RasterizeCudaContext(device=device) |
|||
uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( |
|||
ctx, vertices, faces, resolution=texture_resolution)
|||
tex_hard_mask = tex_hard_mask.float() |
|||
|
|||
# query the texture field to get the RGB color for texture map |
|||
tex_feat = self.get_texture_prediction( |
|||
planes, [gb_pos], tex_hard_mask) |
|||
background_feature = torch.zeros_like(tex_feat) |
|||
img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) |
|||
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) |
|||
|
|||
return vertices, faces, uvs, mesh_tex_idx, texture_map |
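A hedged usage sketch for InstantMesh (dummy shapes only; note that init_flexicubes_geometry must be called before forward(), and that the render cameras here are per-view 4x4 modelview matrices, as consumed by render_mesh):

if __name__ == "__main__":
    device = torch.device("cuda")               # FlexiCubes/nvdiffrast require CUDA
    model = InstantMesh(grid_res=64).to(device)
    model.init_flexicubes_geometry(device, fovy=50.0)
    images = torch.randn(1, 6, 3, 320, 320, device=device)
    cameras = torch.randn(1, 6, 16, device=device)
    render_cameras = torch.randn(1, 4, 4, 4, device=device)   # M=4 views
    out = model(images, cameras, render_cameras, render_size=256)
    # out keys: img, mask, depth, normal, sdf, mesh_v, mesh_f, sdf_reg_loss, planes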
@ -0,0 +1,9 @@ |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
|||
# property and proprietary rights in and to this material, related |
|||
# documentation and any modifications thereto. Any use, reproduction, |
|||
# disclosure or distribution of this material and related documentation |
|||
# without an express license agreement from NVIDIA CORPORATION or |
|||
# its affiliates is strictly prohibited. |
@ -0,0 +1,203 @@ |
|||
# ORIGINAL LICENSE |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# Modified by Jiale Xu |
|||
# The modifications are subject to the same license as the original. |
|||
|
|||
|
|||
import itertools |
|||
import torch |
|||
import torch.nn as nn |
|||
|
|||
from .utils.renderer import ImportanceRenderer |
|||
from .utils.ray_sampler import RaySampler |
|||
|
|||
|
|||
class OSGDecoder(nn.Module): |
|||
""" |
|||
Triplane decoder that gives RGB and sigma values from sampled features. |
|||
Using ReLU here instead of Softplus in the original implementation. |
|||
|
|||
Reference: |
|||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 |
|||
""" |
|||
def __init__(self, n_features: int, |
|||
hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): |
|||
super().__init__() |
|||
self.net = nn.Sequential( |
|||
nn.Linear(3 * n_features, hidden_dim), |
|||
activation(), |
|||
*itertools.chain(*[[ |
|||
nn.Linear(hidden_dim, hidden_dim), |
|||
activation(), |
|||
] for _ in range(num_layers - 2)]), |
|||
nn.Linear(hidden_dim, 1 + 3), |
|||
) |
|||
# init all bias to zero |
|||
for m in self.modules(): |
|||
if isinstance(m, nn.Linear): |
|||
nn.init.zeros_(m.bias) |
|||
|
|||
def forward(self, sampled_features, ray_directions): |
|||
# Aggregate features by mean |
|||
# sampled_features = sampled_features.mean(1) |
|||
# Aggregate features by concatenation |
|||
_N, n_planes, _M, _C = sampled_features.shape |
|||
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) |
|||
x = sampled_features |
|||
|
|||
N, M, C = x.shape |
|||
x = x.contiguous().view(N*M, C) |
|||
|
|||
x = self.net(x) |
|||
x = x.view(N, M, -1) |
|||
rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF |
|||
sigma = x[..., 0:1] |
|||
|
|||
return {'rgb': rgb, 'sigma': sigma} |
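# Illustrative shape check (a sketch, not used by the model): dummy features
# run through OSGDecoder, showing the expected shapes and the MipNeRF-style
# sigmoid output range of roughly (-0.001, 1.001).
def _osg_decoder_shape_demo():
    decoder = OSGDecoder(n_features=40)
    feats = torch.randn(2, 3, 100, 40)          # (N, n_planes, M, C)
    out = decoder(feats, ray_directions=None)   # directions are unused here
    assert out['rgb'].shape == (2, 100, 3)
    assert out['sigma'].shape == (2, 100, 1)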
|||
|
|||
|
|||
class TriplaneSynthesizer(nn.Module): |
|||
""" |
|||
Synthesizer that renders a triplane volume with planes and a camera. |
|||
|
|||
Reference: |
|||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 |
|||
""" |
|||
|
|||
DEFAULT_RENDERING_KWARGS = { |
|||
'ray_start': 'auto', |
|||
'ray_end': 'auto', |
|||
'box_warp': 2., |
|||
'white_back': True, |
|||
'disparity_space_sampling': False, |
|||
'clamp_mode': 'softplus', |
|||
'sampler_bbox_min': -1., |
|||
'sampler_bbox_max': 1., |
|||
} |
|||
|
|||
def __init__(self, triplane_dim: int, samples_per_ray: int): |
|||
super().__init__() |
|||
|
|||
# attributes |
|||
self.triplane_dim = triplane_dim |
|||
self.rendering_kwargs = { |
|||
**self.DEFAULT_RENDERING_KWARGS, |
|||
'depth_resolution': samples_per_ray // 2, |
|||
'depth_resolution_importance': samples_per_ray // 2, |
|||
} |
|||
|
|||
# renderings |
|||
self.renderer = ImportanceRenderer() |
|||
self.ray_sampler = RaySampler() |
|||
|
|||
# modules |
|||
self.decoder = OSGDecoder(n_features=triplane_dim) |
|||
|
|||
def forward(self, planes, cameras, render_size=128, crop_params=None): |
|||
# planes: (N, 3, D', H', W') |
|||
# cameras: (N, M, D_cam) |
|||
# render_size: int |
|||
assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" |
|||
N, M = cameras.shape[:2] |
|||
|
|||
cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) |
|||
intrinsics = cameras[..., 16:25].view(N, M, 3, 3) |
|||
|
|||
# Create a batch of rays for volume rendering |
|||
ray_origins, ray_directions = self.ray_sampler( |
|||
cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), |
|||
intrinsics=intrinsics.reshape(-1, 3, 3), |
|||
render_size=render_size, |
|||
) |
|||
assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" |
|||
assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" |
|||
|
|||
# Crop rays if crop_params is available |
|||
if crop_params is not None: |
|||
ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) |
|||
ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) |
|||
i, j, h, w = crop_params |
|||
ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) |
|||
ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) |
|||
|
|||
# Perform volume rendering |
|||
rgb_samples, depth_samples, weights_samples = self.renderer( |
|||
planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, |
|||
) |
|||
|
|||
# Reshape into 'raw' neural-rendered image |
|||
if crop_params is not None: |
|||
Himg, Wimg = crop_params[2:] |
|||
else: |
|||
Himg = Wimg = render_size |
|||
rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() |
|||
depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) |
|||
weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) |
|||
|
|||
out = { |
|||
'images_rgb': rgb_images, |
|||
'images_depth': depth_images, |
|||
'images_weight': weight_images, |
|||
} |
|||
return out |
|||
|
|||
def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): |
|||
# planes: (N, 3, D', H', W') |
|||
# grid_size: int |
|||
# aabb: (N, 2, 3) |
|||
if aabb is None: |
|||
aabb = torch.tensor([ |
|||
[self.rendering_kwargs['sampler_bbox_min']] * 3, |
|||
[self.rendering_kwargs['sampler_bbox_max']] * 3, |
|||
], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) |
|||
assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" |
|||
N = planes.shape[0] |
|||
|
|||
# create grid points for triplane query |
|||
grid_points = [] |
|||
for i in range(N): |
|||
grid_points.append(torch.stack(torch.meshgrid( |
|||
torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), |
|||
torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), |
|||
torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), |
|||
indexing='ij', |
|||
), dim=-1).reshape(-1, 3)) |
|||
cube_grid = torch.stack(grid_points, dim=0).to(planes.device) |
|||
|
|||
features = self.forward_points(planes, cube_grid) |
|||
|
|||
# reshape into grid |
|||
features = { |
|||
k: v.reshape(N, grid_size, grid_size, grid_size, -1) |
|||
for k, v in features.items() |
|||
} |
|||
return features |
|||
|
|||
def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): |
|||
# planes: (N, 3, D', H', W') |
|||
# points: (N, P, 3) |
|||
N, P = points.shape[:2] |
|||
|
|||
# query triplane in chunks |
|||
outs = [] |
|||
for i in range(0, points.shape[1], chunk_size): |
|||
chunk_points = points[:, i:i+chunk_size] |
|||
|
|||
# query triplane |
|||
chunk_out = self.renderer.run_model_activated( |
|||
planes=planes, |
|||
decoder=self.decoder, |
|||
sample_coordinates=chunk_points, |
|||
sample_directions=torch.zeros_like(chunk_points), |
|||
options=self.rendering_kwargs, |
|||
) |
|||
outs.append(chunk_out) |
|||
|
|||
# concatenate the outputs |
|||
point_features = { |
|||
k: torch.cat([out[k] for out in outs], dim=1) |
|||
for k in outs[0].keys() |
|||
} |
|||
return point_features |
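A hedged usage sketch (dummy data; the 25-dim camera layout, a flattened 4x4 cam2world matrix plus a flattened 3x3 intrinsics matrix, follows the slicing in forward()):

if __name__ == "__main__":
    synth = TriplaneSynthesizer(triplane_dim=40, samples_per_ray=96)
    planes = torch.randn(1, 3, 40, 64, 64)      # (N, 3, D', H', W')
    c2w = torch.eye(4).flatten()                # 16 values
    intrinsics = torch.eye(3).flatten()         # 9 values
    cameras = torch.cat([c2w, intrinsics]).unsqueeze(0).unsqueeze(0).repeat(1, 2, 1)  # N=1, M=2
    out = synth(planes, cameras, render_size=64)
    # out['images_rgb']: (1, 2, 3, 64, 64)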
@ -0,0 +1,141 @@ |
|||
# ORIGINAL LICENSE |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# Modified by Jiale Xu |
|||
# The modifications are subject to the same license as the original. |
|||
|
|||
import itertools |
|||
import torch |
|||
import torch.nn as nn |
|||
|
|||
from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes |
|||
|
|||
|
|||
class OSGDecoder(nn.Module): |
|||
""" |
|||
Triplane decoder that gives RGB and sigma values from sampled features. |
|||
Using ReLU here instead of Softplus in the original implementation. |
|||
|
|||
Reference: |
|||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 |
|||
""" |
|||
def __init__(self, n_features: int, |
|||
hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): |
|||
super().__init__() |
|||
|
|||
self.net_sdf = nn.Sequential( |
|||
nn.Linear(3 * n_features, hidden_dim), |
|||
activation(), |
|||
*itertools.chain(*[[ |
|||
nn.Linear(hidden_dim, hidden_dim), |
|||
activation(), |
|||
] for _ in range(num_layers - 2)]), |
|||
nn.Linear(hidden_dim, 1), |
|||
) |
|||
self.net_rgb = nn.Sequential( |
|||
nn.Linear(3 * n_features, hidden_dim), |
|||
activation(), |
|||
*itertools.chain(*[[ |
|||
nn.Linear(hidden_dim, hidden_dim), |
|||
activation(), |
|||
] for _ in range(num_layers - 2)]), |
|||
nn.Linear(hidden_dim, 3), |
|||
) |
|||
self.net_deformation = nn.Sequential( |
|||
nn.Linear(3 * n_features, hidden_dim), |
|||
activation(), |
|||
*itertools.chain(*[[ |
|||
nn.Linear(hidden_dim, hidden_dim), |
|||
activation(), |
|||
] for _ in range(num_layers - 2)]), |
|||
nn.Linear(hidden_dim, 3), |
|||
) |
|||
self.net_weight = nn.Sequential( |
|||
nn.Linear(8 * 3 * n_features, hidden_dim), |
|||
activation(), |
|||
*itertools.chain(*[[ |
|||
nn.Linear(hidden_dim, hidden_dim), |
|||
activation(), |
|||
] for _ in range(num_layers - 2)]), |
|||
nn.Linear(hidden_dim, 21), |
|||
) |
|||
|
|||
# init all bias to zero |
|||
for m in self.modules(): |
|||
if isinstance(m, nn.Linear): |
|||
nn.init.zeros_(m.bias) |
|||
|
|||
def get_geometry_prediction(self, sampled_features, flexicubes_indices): |
|||
_N, n_planes, _M, _C = sampled_features.shape |
|||
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) |
|||
|
|||
sdf = self.net_sdf(sampled_features) |
|||
deformation = self.net_deformation(sampled_features) |
|||
|
|||
grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) |
|||
grid_features = grid_features.reshape( |
|||
sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) |
|||
weight = self.net_weight(grid_features) * 0.1 |
|||
|
|||
return sdf, deformation, weight |
|||
|
|||
def get_texture_prediction(self, sampled_features): |
|||
_N, n_planes, _M, _C = sampled_features.shape |
|||
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) |
|||
|
|||
rgb = self.net_rgb(sampled_features) |
|||
rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF |
|||
|
|||
return rgb |
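# Note on dimensions (our reading, stated as an assumption): net_weight takes
# the concatenated features of a cell's eight corner vertices, hence the
# 8 * 3 * n_features input width, and emits the 21 per-cell FlexiCubes weights
# (8 alpha + 12 beta + 1 gamma in the FlexiCubes parameterization).
def _weight_head_shape_demo():
    decoder = OSGDecoder(n_features=80)
    feats = torch.randn(1, 3, 9, 80)               # (N, n_planes, M=9 vertices, C)
    cells = torch.zeros(2, 8, dtype=torch.long)    # 2 cells x 8 corner indices
    sdf, deformation, weight = decoder.get_geometry_prediction(feats, cells)
    assert sdf.shape == (1, 9, 1)
    assert deformation.shape == (1, 9, 3)
    assert weight.shape == (1, 2, 21)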
|||
|
|||
|
|||
class TriplaneSynthesizer(nn.Module): |
|||
""" |
|||
Synthesizer that renders a triplane volume with planes and a camera. |
|||
|
|||
Reference: |
|||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 |
|||
""" |
|||
|
|||
DEFAULT_RENDERING_KWARGS = { |
|||
'ray_start': 'auto', |
|||
'ray_end': 'auto', |
|||
'box_warp': 2., |
|||
'white_back': True, |
|||
'disparity_space_sampling': False, |
|||
'clamp_mode': 'softplus', |
|||
'sampler_bbox_min': -1., |
|||
'sampler_bbox_max': 1., |
|||
} |
|||
|
|||
def __init__(self, triplane_dim: int, samples_per_ray: int): |
|||
super().__init__() |
|||
|
|||
# attributes |
|||
self.triplane_dim = triplane_dim |
|||
self.rendering_kwargs = { |
|||
**self.DEFAULT_RENDERING_KWARGS, |
|||
'depth_resolution': samples_per_ray // 2, |
|||
'depth_resolution_importance': samples_per_ray // 2, |
|||
} |
|||
|
|||
# modules |
|||
self.plane_axes = generate_planes() |
|||
self.decoder = OSGDecoder(n_features=triplane_dim) |
|||
|
|||
def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): |
|||
plane_axes = self.plane_axes.to(planes.device) |
|||
sampled_features = sample_from_planes( |
|||
plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) |
|||
|
|||
sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) |
|||
return sdf, deformation, weight |
|||
|
|||
def get_texture_prediction(self, planes, sample_coordinates): |
|||
plane_axes = self.plane_axes.to(planes.device) |
|||
sampled_features = sample_from_planes( |
|||
plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) |
|||
|
|||
rgb = self.decoder.get_texture_prediction(sampled_features) |
|||
return rgb |
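A hedged point-query sketch (dummy data; the corner-index tensor is hypothetical and only the output shapes are of interest):

if __name__ == "__main__":
    synth = TriplaneSynthesizer(triplane_dim=80, samples_per_ray=128)
    planes = torch.randn(1, 3, 80, 64, 64)
    points = torch.rand(1, 1000, 3) * 2 - 1       # inside the [-1, 1] box
    cells = torch.randint(0, 1000, (50, 8))       # hypothetical corner indices
    sdf, deformation, weight = synth.get_geometry_prediction(planes, points, cells)
    rgb = synth.get_texture_prediction(planes, points)   # (1, 1000, 3)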
@ -0,0 +1,9 @@ |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
|||
# property and proprietary rights in and to this material, related |
|||
# documentation and any modifications thereto. Any use, reproduction, |
|||
# disclosure or distribution of this material and related documentation |
|||
# without an express license agreement from NVIDIA CORPORATION or |
|||
# its affiliates is strictly prohibited. |
@ -0,0 +1,118 @@ |
|||
# MIT License |
|||
|
|||
# Copyright (c) 2022 Petr Kellnhofer |
|||
|
|||
# Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
# of this software and associated documentation files (the "Software"), to deal |
|||
# in the Software without restriction, including without limitation the rights |
|||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
# copies of the Software, and to permit persons to whom the Software is |
|||
# furnished to do so, subject to the following conditions: |
|||
|
|||
# The above copyright notice and this permission notice shall be included in all |
|||
# copies or substantial portions of the Software. |
|||
|
|||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|||
# SOFTWARE. |
|||
|
|||
import torch |
|||
|
|||
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: |
|||
""" |
|||
Left-multiplies MxM @ NxM. Returns NxM. |
|||
""" |
|||
res = torch.matmul(vectors4, matrix.T) |
|||
return res |
|||
|
|||
|
|||
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: |
|||
""" |
|||
Normalize vector lengths. |
|||
""" |
|||
return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) |
|||
|
|||
def torch_dot(x: torch.Tensor, y: torch.Tensor): |
|||
""" |
|||
Dot product of two tensors. |
|||
""" |
|||
return (x * y).sum(-1) |
|||
|
|||
|
|||
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): |
|||
""" |
|||
Author: Petr Kellnhofer |
|||
Intersects rays with the [-1, 1] NDC volume. |
|||
Returns the min and max distances along the ray at which it enters and exits the box.
Returns tmin = -1 and tmax = -2 when there is no intersection.
|||
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection |
|||
""" |
|||
o_shape = rays_o.shape |
|||
rays_o = rays_o.detach().reshape(-1, 3) |
|||
rays_d = rays_d.detach().reshape(-1, 3) |
|||
|
|||
|
|||
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] |
|||
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] |
|||
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) |
|||
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) |
|||
|
|||
# Precompute inverse for stability. |
|||
invdir = 1 / rays_d |
|||
sign = (invdir < 0).long() |
|||
|
|||
# Intersect with YZ plane. |
|||
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] |
|||
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] |
|||
|
|||
# Intersect with XZ plane. |
|||
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] |
|||
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] |
|||
|
|||
# Resolve parallel rays. |
|||
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False |
|||
|
|||
# Use the shortest intersection. |
|||
tmin = torch.max(tmin, tymin) |
|||
tmax = torch.min(tmax, tymax) |
|||
|
|||
# Intersect with XY plane. |
|||
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] |
|||
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] |
|||
|
|||
# Resolve parallel rays. |
|||
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False |
|||
|
|||
# Use the shortest intersection. |
|||
tmin = torch.max(tmin, tzmin) |
|||
tmax = torch.min(tmax, tzmax) |
|||
|
|||
# Mark invalid. |
|||
tmin[torch.logical_not(is_valid)] = -1 |
|||
tmax[torch.logical_not(is_valid)] = -2 |
|||
|
|||
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) |
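# Quick sanity example (illustrative, not used elsewhere): a ray starting at
# z = -2 and pointing along +z enters the side-length-2 box at t = 1 and
# exits it at t = 3.
def _ray_box_demo():
    rays_o = torch.tensor([[0.0, 0.0, -2.0]])
    rays_d = torch.tensor([[0.0, 0.0, 1.0]])
    tmin, tmax = get_ray_limits_box(rays_o, rays_d, box_side_length=2.0)
    assert torch.allclose(tmin, torch.tensor([[1.0]]))
    assert torch.allclose(tmax, torch.tensor([[3.0]]))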
|||
|
|||
|
|||
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): |
|||
""" |
|||
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
|||
""" |
|||
# create a tensor of 'num' steps from 0 to 1 |
|||
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) |
|||
|
|||
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
# "cannot statically infer the expected size of a list in this context", hence the code below
|||
for i in range(start.ndim): |
|||
steps = steps.unsqueeze(-1) |
|||
|
|||
# the output starts at 'start' and increments until 'stop' in each dimension |
|||
out = start[None] + steps * (stop - start)[None] |
|||
|
|||
return out |
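# Example of the multi-dimensional behaviour (illustrative): broadcasting over
# a 2-element start/stop pair yields a [num, 2] tensor.
def _linspace_demo():
    out = linspace(torch.zeros(2), torch.tensor([1.0, 10.0]), num=3)
    expected = torch.tensor([[0.0, 0.0], [0.5, 5.0], [1.0, 10.0]])
    assert torch.allclose(out, expected)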
@ -0,0 +1,72 @@ |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
|||
# property and proprietary rights in and to this material, related |
|||
# documentation and any modifications thereto. Any use, reproduction, |
|||
# disclosure or distribution of this material and related documentation |
|||
# without an express license agreement from NVIDIA CORPORATION or |
|||
# its affiliates is strictly prohibited. |
|||
# |
|||
# Modified by Jiale Xu |
|||
# The modifications are subject to the same license as the original. |
|||
|
|||
|
|||
""" |
|||
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. |
|||
Based on the implementation in MipNeRF (though this version does not do any cone tracing). |
|||
""" |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
|
|||
|
|||
class MipRayMarcher2(nn.Module): |
|||
def __init__(self, activation_factory): |
|||
super().__init__() |
|||
self.activation_factory = activation_factory |
|||
|
|||
def run_forward(self, colors, densities, depths, rendering_options, normals=None): |
|||
dtype = colors.dtype |
|||
deltas = depths[:, :, 1:] - depths[:, :, :-1] |
|||
colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 |
|||
densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 |
|||
depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 |
|||
|
|||
# the density activation is selected via the factory according to rendering_options |
|||
densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) |
|||
|
|||
density_delta = densities_mid * deltas |
|||
|
|||
alpha = 1 - torch.exp(-density_delta).to(dtype) |
|||
|
|||
alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) |
|||
weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] |
|||
weights = weights.to(dtype) |
|||
|
|||
composite_rgb = torch.sum(weights * colors_mid, -2) |
|||
weight_total = weights.sum(2) |
|||
# composite_depth = torch.sum(weights * depths_mid, -2) / weight_total |
|||
composite_depth = torch.sum(weights * depths_mid, -2) |
|||
|
|||
# clip the composite to min/max range of depths |
|||
composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) |
|||
composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) |
|||
|
|||
if rendering_options.get('white_back', False): |
|||
composite_rgb = composite_rgb + 1 - weight_total |
|||
|
|||
# rendered value scale is 0-1, comment out original mipnerf scaling |
|||
# composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) |
|||
|
|||
return composite_rgb, composite_depth, weights |
|||
|
|||
|
|||
def forward(self, colors, densities, depths, rendering_options, normals=None): |
|||
if normals is not None: |
|||
composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) |
|||
return composite_rgb, composite_depth, composite_normals, weights |
|||
|
|||
composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) |
|||
return composite_rgb, composite_depth, weights |
@ -0,0 +1,141 @@ |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
|||
# property and proprietary rights in and to this material, related |
|||
# documentation and any modifications thereto. Any use, reproduction, |
|||
# disclosure or distribution of this material and related documentation |
|||
# without an express license agreement from NVIDIA CORPORATION or |
|||
# its affiliates is strictly prohibited. |
|||
# |
|||
# Modified by Jiale Xu |
|||
# The modifications are subject to the same license as the original. |
|||
|
|||
|
|||
""" |
|||
The ray sampler is a module that takes in camera matrices and a render resolution and produces batches of rays. |
|||
Expects cam2world matrices that use the OpenCV camera coordinate system conventions. |
|||
""" |
|||
|
|||
import torch |
|||
|
|||
class RaySampler(torch.nn.Module): |
|||
def __init__(self): |
|||
super().__init__() |
|||
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None |
|||
|
|||
|
|||
def forward(self, cam2world_matrix, intrinsics, render_size): |
|||
""" |
|||
Create batches of rays and return origins and directions. |
|||
|
|||
cam2world_matrix: (N, 4, 4) |
|||
intrinsics: (N, 3, 3) |
|||
render_size: int |
|||
|
|||
ray_origins: (N, M, 3) |
|||
ray_dirs: (N, M, 3) |
|||
""" |
|||
|
|||
dtype = cam2world_matrix.dtype |
|||
device = cam2world_matrix.device |
|||
N, M = cam2world_matrix.shape[0], render_size**2 |
|||
cam_locs_world = cam2world_matrix[:, :3, 3] |
|||
fx = intrinsics[:, 0, 0] |
|||
fy = intrinsics[:, 1, 1] |
|||
cx = intrinsics[:, 0, 2] |
|||
cy = intrinsics[:, 1, 2] |
|||
sk = intrinsics[:, 0, 1] |
|||
|
|||
uv = torch.stack(torch.meshgrid( |
|||
torch.arange(render_size, dtype=dtype, device=device), |
|||
torch.arange(render_size, dtype=dtype, device=device), |
|||
indexing='ij', |
|||
)) |
|||
uv = uv.flip(0).reshape(2, -1).transpose(1, 0) |
|||
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) |
|||
|
|||
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) |
|||
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) |
|||
z_cam = torch.ones((N, M), dtype=dtype, device=device) |
|||
|
|||
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam |
|||
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam |
|||
|
|||
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) |
|||
|
|||
_opencv2blender = torch.tensor([ |
|||
[1, 0, 0, 0], |
|||
[0, -1, 0, 0], |
|||
[0, 0, -1, 0], |
|||
[0, 0, 0, 1], |
|||
], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) |
|||
|
|||
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) |
|||
|
|||
world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] |
|||
|
|||
ray_dirs = world_rel_points - cam_locs_world[:, None, :] |
|||
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) |
|||
|
|||
ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) |
|||
|
|||
return ray_origins, ray_dirs |
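# --- Usage sketch (not in the original source): one identity camera with
# normalized intrinsics (fx = fy = 1, principal point at 0.5).
sampler = RaySampler()
cam2world = torch.eye(4).unsqueeze(0)                  # (1, 4, 4)
K = torch.tensor([[[1.0, 0.0, 0.5],
                   [0.0, 1.0, 0.5],
                   [0.0, 0.0, 1.0]]])                  # (1, 3, 3)
origins, dirs = sampler(cam2world, K, render_size=64)  # (1, 4096, 3) each, unit-length dirs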
|||
|
|||
|
|||
class OrthoRaySampler(torch.nn.Module): |
|||
def __init__(self): |
|||
super().__init__() |
|||
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None |
|||
|
|||
|
|||
def forward(self, cam2world_matrix, ortho_scale, render_size): |
|||
""" |
|||
Create batches of rays and return origins and directions. |
|||
|
|||
cam2world_matrix: (N, 4, 4) |
|||
ortho_scale: float |
|||
render_size: int |
|||
|
|||
ray_origins: (N, M, 3) |
|||
ray_dirs: (N, M, 3) |
|||
""" |
|||
|
|||
N, M = cam2world_matrix.shape[0], render_size**2 |
|||
|
|||
uv = torch.stack(torch.meshgrid( |
|||
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), |
|||
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), |
|||
indexing='ij', |
|||
)) |
|||
uv = uv.flip(0).reshape(2, -1).transpose(1, 0) |
|||
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) |
|||
|
|||
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) |
|||
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) |
|||
z_cam = torch.zeros((N, M), device=cam2world_matrix.device) |
|||
|
|||
x_lift = (x_cam - 0.5) * ortho_scale |
|||
y_lift = (y_cam - 0.5) * ortho_scale |
|||
|
|||
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) |
|||
|
|||
_opencv2blender = torch.tensor([ |
|||
[1, 0, 0, 0], |
|||
[0, -1, 0, 0], |
|||
[0, 0, -1, 0], |
|||
[0, 0, 0, 1], |
|||
], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) |
|||
|
|||
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) |
|||
|
|||
ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] |
|||
|
|||
ray_dirs_cam = torch.stack([ |
|||
torch.zeros((N, M), device=cam2world_matrix.device), |
|||
torch.zeros((N, M), device=cam2world_matrix.device), |
|||
torch.ones((N, M), device=cam2world_matrix.device), |
|||
], dim=-1) |
|||
ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) |
|||
|
|||
return ray_origins, ray_dirs |
@ -0,0 +1,323 @@ |
|||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary |
|||
# |
|||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
|||
# property and proprietary rights in and to this material, related |
|||
# documentation and any modifications thereto. Any use, reproduction, |
|||
# disclosure or distribution of this material and related documentation |
|||
# without an express license agreement from NVIDIA CORPORATION or |
|||
# its affiliates is strictly prohibited. |
|||
# |
|||
# Modified by Jiale Xu |
|||
# The modifications are subject to the same license as the original. |
|||
|
|||
|
|||
""" |
|||
The renderer is a module that takes in rays, decides where to sample along each |
|||
ray, and computes pixel colors using the volume rendering equation. |
|||
""" |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
|
|||
from .ray_marcher import MipRayMarcher2 |
|||
from . import math_utils |
|||
|
|||
|
|||
def generate_planes(): |
|||
""" |
|||
Defines planes by the three vectors that form the "axes" of the |
|||
plane. Should work with arbitrary number of planes and planes of |
|||
arbitrary orientation. |
|||
|
|||
Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 |
|||
""" |
|||
return torch.tensor([[[1, 0, 0], |
|||
[0, 1, 0], |
|||
[0, 0, 1]], |
|||
[[1, 0, 0], |
|||
[0, 0, 1], |
|||
[0, 1, 0]], |
|||
[[0, 0, 1], |
|||
[0, 1, 0], |
|||
[1, 0, 0]]], dtype=torch.float32) |
|||
|
|||
def project_onto_planes(planes, coordinates): |
|||
""" |
|||
Does a projection of a 3D point onto a batch of 2D planes, |
|||
returning 2D plane coordinates. |
|||
|
|||
Takes plane axes of shape (n_planes, 3, 3). |
|||
Takes coordinates of shape (N, M, 3). |
|||
Returns projections of shape (N*n_planes, M, 2). |
|||
""" |
|||
N, M, C = coordinates.shape |
|||
n_planes, _, _ = planes.shape |
|||
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) |
|||
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) |
|||
projections = torch.bmm(coordinates, inv_planes) |
|||
return projections[..., :2] |
|||
|
|||
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): |
|||
assert padding_mode == 'zeros' |
|||
N, n_planes, C, H, W = plane_features.shape |
|||
_, M, _ = coordinates.shape |
|||
plane_features = plane_features.view(N*n_planes, C, H, W) |
|||
dtype = plane_features.dtype |
|||
|
|||
coordinates = (2/box_warp) * coordinates # rescale from [-box_warp/2, box_warp/2] to [-1, 1] for grid_sample |
|||
|
|||
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) |
|||
output_features = torch.nn.functional.grid_sample( |
|||
plane_features, |
|||
projected_coordinates.to(dtype), |
|||
mode=mode, |
|||
padding_mode=padding_mode, |
|||
align_corners=False, |
|||
).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) |
|||
return output_features |
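# --- Usage sketch (not in the original source): sample C=32 features for M
# points from a batch of triplanes; box_warp=2 means the box spans [-1, 1].
plane_axes_demo = generate_planes()                      # (3, 3, 3)
feats = torch.randn(2, 3, 32, 64, 64)                    # (N, n_planes, C, H, W)
pts = torch.rand(2, 100, 3) * 2 - 1                      # (N, M, 3) in [-1, 1]
sampled = sample_from_planes(plane_axes_demo, feats, pts, box_warp=2.0)
# sampled: (2, 3, 100, 32) -- one feature vector per plane per point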
|||
|
|||
def sample_from_3dgrid(grid, coordinates): |
|||
""" |
|||
Expects coordinates in shape (batch_size, num_points_per_batch, 3) |
|||
Expects grid in shape (1, channels, H, W, D) |
|||
(Also works if grid has batch size) |
|||
Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) |
|||
""" |
|||
batch_size, n_coords, n_dims = coordinates.shape |
|||
sampled_features = torch.nn.functional.grid_sample( |
|||
grid.expand(batch_size, -1, -1, -1, -1), |
|||
coordinates.reshape(batch_size, 1, 1, -1, n_dims), |
|||
mode='bilinear', |
|||
padding_mode='zeros', |
|||
align_corners=False, |
|||
) |
|||
N, C, H, W, D = sampled_features.shape |
|||
sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) |
|||
return sampled_features |
|||
|
|||
class ImportanceRenderer(torch.nn.Module): |
|||
""" |
|||
Modified original version to filter out-of-box samples as TensoRF does. |
|||
|
|||
Reference: |
|||
TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 |
|||
""" |
|||
def __init__(self): |
|||
super().__init__() |
|||
self.activation_factory = self._build_activation_factory() |
|||
self.ray_marcher = MipRayMarcher2(self.activation_factory) |
|||
self.plane_axes = generate_planes() |
|||
|
|||
def _build_activation_factory(self): |
|||
def activation_factory(options: dict): |
|||
if options['clamp_mode'] == 'softplus': |
|||
return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better |
|||
else: |
|||
assert False, "Renderer only supports `clamp_mode`=`softplus`!" |
|||
return activation_factory |
|||
|
|||
def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, |
|||
planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): |
|||
""" |
|||
Additionally filters out samples that fall outside the bounding box. |
|||
Modifications made by Zexin He. |
|||
""" |
|||
|
|||
# context related variables |
|||
batch_size, num_rays, samples_per_ray, _ = depths.shape |
|||
device = depths.device |
|||
|
|||
# define sample points with depths |
|||
sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) |
|||
sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) |
|||
|
|||
# filter out-of-box samples |
|||
mask_inbox = \ |
|||
(rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ |
|||
(sample_coordinates <= rendering_options['sampler_bbox_max']) |
|||
mask_inbox = mask_inbox.all(-1) |
|||
|
|||
# forward model according to all samples |
|||
_out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) |
|||
|
|||
# set out-of-box samples to zeros(rgb) & -inf(sigma) |
|||
SAFE_GUARD = 3 |
|||
DATA_TYPE = _out['sigma'].dtype |
|||
colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) |
|||
densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD |
|||
colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] |
|||
|
|||
# reshape back |
|||
colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) |
|||
densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) |
|||
|
|||
return colors_pass, densities_pass |
|||
|
|||
def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): |
|||
# self.plane_axes = self.plane_axes.to(ray_origins.device) |
|||
|
|||
if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': |
|||
ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) |
|||
is_ray_valid = ray_end > ray_start |
|||
if torch.any(is_ray_valid).item(): |
|||
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() |
|||
ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() |
|||
depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) |
|||
else: |
|||
# Create stratified depth samples |
|||
depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) |
|||
|
|||
# Coarse Pass |
|||
colors_coarse, densities_coarse = self._forward_pass( |
|||
depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, |
|||
planes=planes, decoder=decoder, rendering_options=rendering_options) |
|||
|
|||
# Fine Pass |
|||
N_importance = rendering_options['depth_resolution_importance'] |
|||
if N_importance > 0: |
|||
_, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) |
|||
|
|||
depths_fine = self.sample_importance(depths_coarse, weights, N_importance) |
|||
|
|||
colors_fine, densities_fine = self._forward_pass( |
|||
depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, |
|||
planes=planes, decoder=decoder, rendering_options=rendering_options) |
|||
|
|||
all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, |
|||
depths_fine, colors_fine, densities_fine) |
|||
|
|||
rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) |
|||
else: |
|||
rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) |
|||
|
|||
return rgb_final, depth_final, weights.sum(2) |
|||
|
|||
def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): |
|||
plane_axes = self.plane_axes.to(planes.device) |
|||
sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) |
|||
|
|||
out = decoder(sampled_features, sample_directions) |
|||
if options.get('density_noise', 0) > 0: |
|||
out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] |
|||
return out |
|||
|
|||
def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): |
|||
out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) |
|||
out['sigma'] = self.activation_factory(options)(out['sigma']) |
|||
return out |
|||
|
|||
def sort_samples(self, all_depths, all_colors, all_densities): |
|||
_, indices = torch.sort(all_depths, dim=-2) |
|||
all_depths = torch.gather(all_depths, -2, indices) |
|||
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) |
|||
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) |
|||
return all_depths, all_colors, all_densities |
|||
|
|||
def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): |
|||
all_depths = torch.cat([depths1, depths2], dim=-2) |
|||
all_colors = torch.cat([colors1, colors2], dim=-2) |
|||
all_densities = torch.cat([densities1, densities2], dim=-2) |
|||
|
|||
if normals1 is not None and normals2 is not None: |
|||
all_normals = torch.cat([normals1, normals2], dim=-2) |
|||
else: |
|||
all_normals = None |
|||
|
|||
_, indices = torch.sort(all_depths, dim=-2) |
|||
all_depths = torch.gather(all_depths, -2, indices) |
|||
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) |
|||
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) |
|||
|
|||
if all_normals is not None: |
|||
all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) |
|||
return all_depths, all_colors, all_normals, all_densities |
|||
|
|||
return all_depths, all_colors, all_densities |
|||
|
|||
def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): |
|||
""" |
|||
Return depths of approximately uniformly spaced samples along rays. |
|||
""" |
|||
N, M, _ = ray_origins.shape |
|||
if disparity_space_sampling: |
|||
depths_coarse = torch.linspace(0, |
|||
1, |
|||
depth_resolution, |
|||
device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) |
|||
depth_delta = 1/(depth_resolution - 1) |
|||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta |
|||
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) |
|||
else: |
|||
if isinstance(ray_start, torch.Tensor): |
|||
depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) |
|||
depth_delta = (ray_end - ray_start) / (depth_resolution - 1) |
|||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] |
|||
else: |
|||
depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) |
|||
depth_delta = (ray_end - ray_start)/(depth_resolution - 1) |
|||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta |
|||
|
|||
return depths_coarse |
|||
|
|||
def sample_importance(self, z_vals, weights, N_importance): |
|||
""" |
|||
Return depths of importance sampled points along rays. See NeRF importance sampling for more. |
|||
""" |
|||
with torch.no_grad(): |
|||
batch_size, num_rays, samples_per_ray, _ = z_vals.shape |
|||
|
|||
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) |
|||
weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher |
|||
|
|||
# smooth weights |
|||
weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) |
|||
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() |
|||
weights = weights + 0.01 |
|||
|
|||
z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) |
|||
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], |
|||
N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) |
|||
return importance_z_vals |
|||
|
|||
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): |
|||
""" |
|||
Sample @N_importance samples from @bins with distribution defined by @weights. |
|||
Inputs: |
|||
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" |
|||
weights: (N_rays, N_samples_) |
|||
N_importance: the number of samples to draw from the distribution |
|||
det: deterministic or not |
|||
eps: a small number to prevent division by zero |
|||
Outputs: |
|||
samples: (N_rays, N_importance), the sampled depth values |
|||
""" |
|||
N_rays, N_samples_ = weights.shape |
|||
weights = weights + eps # prevent division by zero (don't do inplace op!) |
|||
pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) |
|||
cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function |
|||
cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) |
|||
# padded to 0~1 inclusive |
|||
|
|||
if det: |
|||
u = torch.linspace(0, 1, N_importance, device=bins.device) |
|||
u = u.expand(N_rays, N_importance) |
|||
else: |
|||
u = torch.rand(N_rays, N_importance, device=bins.device) |
|||
u = u.contiguous() |
|||
|
|||
inds = torch.searchsorted(cdf, u, right=True) |
|||
below = torch.clamp_min(inds-1, 0) |
|||
above = torch.clamp_max(inds, N_samples_) |
|||
|
|||
inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) |
|||
cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) |
|||
bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) |
|||
|
|||
denom = cdf_g[...,1]-cdf_g[...,0] |
|||
denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled |
|||
# anyway, therefore any value for it is fine (set to 1 here) |
|||
|
|||
samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0]) |
|||
return samples |
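# --- Usage sketch (not in the original source): deterministic inverse-CDF
# sampling of 8 depths from 4 bins weighted toward the middle of the ray.
renderer = ImportanceRenderer()
bin_edges = torch.linspace(0.0, 1.0, 5).unsqueeze(0)     # (1, 5)
bin_weights = torch.tensor([[0.1, 0.4, 0.4, 0.1]])       # (1, 4)
fine_depths = renderer.sample_pdf(bin_edges, bin_weights, N_importance=8, det=True)
# fine_depths: (1, 8), concentrated in the heavy bins over [0.25, 0.75]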
@ -0,0 +1,111 @@ |
|||
import torch |
|||
import torch.nn.functional as F |
|||
import numpy as np |
|||
|
|||
|
|||
def pad_camera_extrinsics_4x4(extrinsics): |
|||
if extrinsics.shape[-2] == 4: |
|||
return extrinsics |
|||
padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics) |
|||
if extrinsics.ndim == 3: |
|||
padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1) |
|||
extrinsics = torch.cat([extrinsics, padding], dim=-2) |
|||
return extrinsics |
|||
|
|||
|
|||
def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None): |
|||
""" |
|||
Create OpenGL camera extrinsics from camera locations and look-at position. |
|||
|
|||
camera_position: (M, 3) or (3,) |
|||
look_at: (3) |
|||
up_world: (3) |
|||
return: (M, 3, 4) or (3, 4) |
|||
""" |
|||
# by default, looking at the origin and world up is z-axis |
|||
if look_at is None: |
|||
look_at = torch.tensor([0, 0, 0], dtype=torch.float32) |
|||
if up_world is None: |
|||
up_world = torch.tensor([0, 0, 1], dtype=torch.float32) |
|||
if camera_position.ndim == 2: |
|||
look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) |
|||
up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) |
|||
|
|||
# OpenGL camera: z-backward, x-right, y-up |
|||
z_axis = camera_position - look_at |
|||
z_axis = F.normalize(z_axis, dim=-1).float() |
|||
x_axis = torch.linalg.cross(up_world, z_axis, dim=-1) |
|||
x_axis = F.normalize(x_axis, dim=-1).float() |
|||
y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1) |
|||
y_axis = F.normalize(y_axis, dim=-1).float() |
|||
|
|||
extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) |
|||
extrinsics = pad_camera_extrinsics_4x4(extrinsics) |
|||
return extrinsics |
|||
|
|||
|
|||
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): |
|||
azimuths = np.deg2rad(azimuths) |
|||
elevations = np.deg2rad(elevations) |
|||
|
|||
xs = radius * np.cos(elevations) * np.cos(azimuths) |
|||
ys = radius * np.cos(elevations) * np.sin(azimuths) |
|||
zs = radius * np.sin(elevations) |
|||
|
|||
cam_locations = np.stack([xs, ys, zs], axis=-1) |
|||
cam_locations = torch.from_numpy(cam_locations).float() |
|||
|
|||
c2ws = center_looking_at_camera_pose(cam_locations) |
|||
return c2ws |
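# --- Usage sketch (not in the original source): four cameras on a ring at
# 30 degrees elevation, all looking at the origin.
azims = np.array([0.0, 90.0, 180.0, 270.0])
elevs = np.full(4, 30.0)
ring_c2ws = spherical_camera_pose(azims, elevs, radius=2.5)   # (4, 4, 4)
# ring_c2ws[:, :3, 3] are the camera centers at distance 2.5 from the origin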
|||
|
|||
|
|||
def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): |
|||
# M: number of circular views |
|||
# radius: camera dist to center |
|||
# elevation: elevation degrees of the camera |
|||
# return: (M, 4, 4) |
|||
assert M > 0 and radius > 0 |
|||
|
|||
elevation = np.deg2rad(elevation) |
|||
|
|||
camera_positions = [] |
|||
for i in range(M): |
|||
azimuth = 2 * np.pi * i / M |
|||
x = radius * np.cos(elevation) * np.cos(azimuth) |
|||
y = radius * np.cos(elevation) * np.sin(azimuth) |
|||
z = radius * np.sin(elevation) |
|||
camera_positions.append([x, y, z]) |
|||
camera_positions = np.array(camera_positions) |
|||
camera_positions = torch.from_numpy(camera_positions).float() |
|||
extrinsics = center_looking_at_camera_pose(camera_positions) |
|||
return extrinsics |
|||
|
|||
|
|||
def FOV_to_intrinsics(fov, device='cpu'): |
|||
""" |
|||
Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. |
|||
Note that the intrinsics are normalized by image size rather than expressed in pixel units. |
|||
Assumes principal point is at image center. |
|||
""" |
|||
focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) |
|||
intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) |
|||
return intrinsics |
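# --- Usage sketch (not in the original source): a 30-degree FOV gives a
# normalized focal length of 0.5 / tan(15 deg).
K_demo = FOV_to_intrinsics(30.0)
# K_demo[0, 0] == K_demo[1, 1] ~= 1.866; principal point at (0.5, 0.5)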
|||
|
|||
|
|||
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): |
|||
""" |
|||
Get the camera parameters of the six Zero123++ input views. |
|||
""" |
|||
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) |
|||
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) |
|||
|
|||
c2ws = spherical_camera_pose(azimuths, elevations, radius) |
|||
c2ws = c2ws.float().flatten(-2) |
|||
|
|||
Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) |
|||
|
|||
extrinsics = c2ws[:, :12] |
|||
intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) |
|||
cameras = torch.cat([extrinsics, intrinsics], dim=-1) |
|||
|
|||
return cameras.unsqueeze(0).repeat(batch_size, 1, 1) |
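# --- Usage sketch (not in the original source): each of the six views is a
# flattened [12 extrinsics + 4 intrinsics] vector.
input_cams = get_zero123plus_input_cameras(batch_size=2, radius=4.0, fov=30.0)
# input_cams: (2, 6, 16)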
@ -0,0 +1,97 @@ |
|||
import os |
|||
import imageio |
|||
import rembg |
|||
import torch |
|||
import numpy as np |
|||
import PIL.Image |
|||
from PIL import Image |
|||
from typing import Any |
|||
|
|||
|
|||
def remove_background(image: PIL.Image.Image, |
|||
rembg_session: Any = None, |
|||
force: bool = False, |
|||
**rembg_kwargs, |
|||
) -> PIL.Image.Image: |
|||
do_remove = True |
|||
if image.mode == "RGBA" and image.getextrema()[3][0] < 255: |
|||
do_remove = False |
|||
do_remove = do_remove or force |
|||
if do_remove: |
|||
image = rembg.remove(image, session=rembg_session, **rembg_kwargs) |
|||
return image |
|||
|
|||
|
|||
def resize_foreground( |
|||
image: PIL.Image.Image, |
|||
ratio: float, |
|||
) -> PIL.Image.Image: |
|||
image = np.array(image) |
|||
assert image.shape[-1] == 4 |
|||
alpha = np.where(image[..., 3] > 0) |
|||
y1, y2, x1, x2 = ( |
|||
alpha[0].min(), |
|||
alpha[0].max(), |
|||
alpha[1].min(), |
|||
alpha[1].max(), |
|||
) |
|||
# crop the foreground |
|||
fg = image[y1:y2, x1:x2] |
|||
# pad to square |
|||
size = max(fg.shape[0], fg.shape[1]) |
|||
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 |
|||
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 |
|||
new_image = np.pad( |
|||
fg, |
|||
((ph0, ph1), (pw0, pw1), (0, 0)), |
|||
mode="constant", |
|||
constant_values=((0, 0), (0, 0), (0, 0)), |
|||
) |
|||
|
|||
# compute padding according to the ratio |
|||
new_size = int(new_image.shape[0] / ratio) |
|||
# pad to size, double side |
|||
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 |
|||
ph1, pw1 = new_size - size - ph0, new_size - size - pw0 |
|||
new_image = np.pad( |
|||
new_image, |
|||
((ph0, ph1), (pw0, pw1), (0, 0)), |
|||
mode="constant", |
|||
constant_values=((0, 0), (0, 0), (0, 0)), |
|||
) |
|||
new_image = PIL.Image.fromarray(new_image) |
|||
return new_image |
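# --- Usage sketch (not in the original source; 'input.png' is a hypothetical
# RGBA file with a transparent background): recenter the foreground so it
# fills ~85% of a square canvas.
demo = Image.open('input.png').convert('RGBA')
demo = resize_foreground(demo, ratio=0.85)   # returns a square PIL image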
|||
|
|||
|
|||
def images_to_video( |
|||
images: torch.Tensor, |
|||
output_path: str, |
|||
fps: int = 30, |
|||
) -> None: |
|||
# images: (N, C, H, W) |
|||
video_dir = os.path.dirname(output_path) |
|||
video_name = os.path.basename(output_path) |
|||
os.makedirs(video_dir, exist_ok=True) |
|||
|
|||
frames = [] |
|||
for i in range(len(images)): |
|||
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) |
|||
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ |
|||
f"Frame shape mismatch: {frame.shape} vs {images.shape}" |
|||
assert frame.min() >= 0 and frame.max() <= 255, \ |
|||
f"Frame value out of range: {frame.min()} ~ {frame.max()}" |
|||
frames.append(frame) |
|||
imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) |
|||
|
|||
|
|||
def save_video( |
|||
frames: torch.Tensor, |
|||
output_path: str, |
|||
fps: int = 30, |
|||
) -> None: |
|||
# frames: (N, C, H, W) |
|||
frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] |
|||
writer = imageio.get_writer(output_path, fps=fps) |
|||
for frame in frames: |
|||
writer.append_data(frame) |
|||
writer.close() |
@ -0,0 +1,165 @@ |
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
# |
|||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property |
|||
# and proprietary rights in and to this software, related documentation |
|||
# and any modifications thereto. Any use, reproduction, disclosure or |
|||
# distribution of this software and related documentation without an express |
|||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. |
|||
|
|||
import torch |
|||
import xatlas |
|||
import trimesh |
|||
import cv2 |
|||
import numpy as np |
|||
import nvdiffrast.torch as dr |
|||
from PIL import Image |
|||
|
|||
|
|||
def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fname): |
|||
mesh = trimesh.Trimesh( |
|||
vertices=pointnp_px3, |
|||
faces=facenp_fx3, |
|||
vertex_colors=colornp_px3, |
|||
) |
|||
mesh.export(fname, 'obj') |
|||
|
|||
|
|||
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): |
|||
import os |
|||
fol, na = os.path.split(fname) |
|||
na, _ = os.path.splitext(na) |
|||
|
|||
matname = '%s/%s.mtl' % (fol, na) |
|||
fid = open(matname, 'w') |
|||
fid.write('newmtl material_0\n') |
|||
fid.write('Kd 1 1 1\n') |
|||
fid.write('Ka 0 0 0\n') |
|||
fid.write('Ks 0.4 0.4 0.4\n') |
|||
fid.write('Ns 10\n') |
|||
fid.write('illum 2\n') |
|||
fid.write('map_Kd %s.png\n' % na) |
|||
fid.close() |
|||
#### |
|||
|
|||
fid = open(fname, 'w') |
|||
fid.write('mtllib %s.mtl\n' % na) |
|||
|
|||
for pidx, p in enumerate(pointnp_px3): |
|||
pp = p |
|||
fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) |
|||
|
|||
for pidx, p in enumerate(tcoords_px2): |
|||
pp = p |
|||
fid.write('vt %f %f\n' % (pp[0], pp[1])) |
|||
|
|||
fid.write('usemtl material_0\n') |
|||
for i, f in enumerate(facenp_fx3): |
|||
f1 = f + 1 |
|||
f2 = facetex_fx3[i] + 1 |
|||
fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) |
|||
fid.close() |
|||
|
|||
# save texture map |
|||
lo, hi = 0, 1 |
|||
img = np.asarray(texmap_hxwx3, dtype=np.float32) |
|||
img = (img - lo) * (255 / (hi - lo)) |
|||
img = img.clip(0, 255) |
|||
mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) |
|||
mask = (mask <= 3.0).astype(np.float32) |
|||
kernel = np.ones((3, 3), 'uint8') |
|||
dilate_img = cv2.dilate(img, kernel, iterations=1) |
|||
img = img * (1 - mask) + dilate_img * mask |
|||
img = img.clip(0, 255).astype(np.uint8) |
|||
Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') |
|||
|
|||
|
|||
def loadobj(meshfile): |
|||
v = [] |
|||
f = [] |
|||
meshfp = open(meshfile, 'r') |
|||
for line in meshfp.readlines(): |
|||
data = line.strip().split(' ') |
|||
data = [da for da in data if len(da) > 0] |
|||
if len(data) != 4: |
|||
continue |
|||
if data[0] == 'v': |
|||
v.append([float(d) for d in data[1:]]) |
|||
if data[0] == 'f': |
|||
data = [da.split('/')[0] for da in data] |
|||
f.append([int(d) for d in data[1:]]) |
|||
meshfp.close() |
|||
|
|||
# torch needs int64 face indices |
|||
facenp_fx3 = np.array(f, dtype=np.int64) - 1 |
|||
pointnp_px3 = np.array(v, dtype=np.float32) |
|||
return pointnp_px3, facenp_fx3 |
|||
|
|||
|
|||
def loadobjtex(meshfile): |
|||
v = [] |
|||
vt = [] |
|||
f = [] |
|||
ft = [] |
|||
meshfp = open(meshfile, 'r') |
|||
for line in meshfp.readlines(): |
|||
data = line.strip().split(' ') |
|||
data = [da for da in data if len(da) > 0] |
|||
if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): |
|||
continue |
|||
if data[0] == 'v': |
|||
assert len(data) == 4 |
|||
|
|||
v.append([float(d) for d in data[1:]]) |
|||
if data[0] == 'vt': |
|||
if len(data) == 3 or len(data) == 4: |
|||
vt.append([float(d) for d in data[1:3]]) |
|||
if data[0] == 'f': |
|||
data = [da.split('/') for da in data] |
|||
if len(data) == 4: |
|||
f.append([int(d[0]) for d in data[1:]]) |
|||
ft.append([int(d[1]) for d in data[1:]]) |
|||
elif len(data) == 5: |
|||
idx1 = [1, 2, 3] |
|||
data1 = [data[i] for i in idx1] |
|||
f.append([int(d[0]) for d in data1]) |
|||
ft.append([int(d[1]) for d in data1]) |
|||
idx2 = [1, 3, 4] |
|||
data2 = [data[i] for i in idx2] |
|||
f.append([int(d[0]) for d in data2]) |
|||
ft.append([int(d[1]) for d in data2]) |
|||
meshfp.close() |
|||
|
|||
# torch needs int64 face indices |
|||
facenp_fx3 = np.array(f, dtype=np.int64) - 1 |
|||
ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 |
|||
pointnp_px3 = np.array(v, dtype=np.float32) |
|||
uvs = np.array(vt, dtype=np.float32) |
|||
return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 |
|||
|
|||
|
|||
# ============================================================================================== |
|||
def interpolate(attr, rast, attr_idx, rast_db=None): |
|||
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') |
|||
|
|||
|
|||
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): |
|||
vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) |
|||
|
|||
# Convert to tensors |
|||
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) |
|||
|
|||
uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) |
|||
mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) |
|||
# map the texture UVs from [0, 1] into clip space [-1, 1] for rasterization |
|||
uv_clip = uvs[None, ...] * 2.0 - 1.0 |
|||
|
|||
# pad to four component coordinate |
|||
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) |
|||
|
|||
# rasterize |
|||
rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) |
|||
|
|||
# Interpolate world space position |
|||
gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) |
|||
mask = rast[..., 3:4] > 0 |
|||
return uvs, mesh_tex_idx, gb_pos, mask |
@ -0,0 +1,26 @@ |
|||
import importlib |
|||
|
|||
|
|||
def count_params(model, verbose=False): |
|||
total_params = sum(p.numel() for p in model.parameters()) |
|||
if verbose: |
|||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") |
|||
return total_params |
|||
|
|||
|
|||
def instantiate_from_config(config): |
|||
if not "target" in config: |
|||
if config == '__is_first_stage__': |
|||
return None |
|||
elif config == "__is_unconditional__": |
|||
return None |
|||
raise KeyError("Expected key `target` to instantiate.") |
|||
return get_obj_from_str(config["target"])(**config.get("params", dict())) |
|||
|
|||
|
|||
def get_obj_from_str(string, reload=False): |
|||
module, cls = string.rsplit(".", 1) |
|||
if reload: |
|||
module_imp = importlib.import_module(module) |
|||
importlib.reload(module_imp) |
|||
return getattr(importlib.import_module(module, package=None), cls) |
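# --- Usage sketch (not in the original source): `target` is a dotted import
# path and `params` become constructor kwargs.
cfg = {'target': 'datetime.timedelta', 'params': {'hours': 1}}
obj = instantiate_from_config(cfg)   # datetime.timedelta(seconds=3600)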
@ -0,0 +1,286 @@ |
|||
import os, sys |
|||
import argparse |
|||
import shutil |
|||
import subprocess |
|||
from omegaconf import OmegaConf |
|||
|
|||
from pytorch_lightning import seed_everything |
|||
from pytorch_lightning.trainer import Trainer |
|||
from pytorch_lightning.strategies import DDPStrategy |
|||
from pytorch_lightning.callbacks import Callback |
|||
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn |
|||
|
|||
from src.utils.train_util import instantiate_from_config |
|||
|
|||
|
|||
@rank_zero_only |
|||
def rank_zero_print(*args): |
|||
print(*args) |
|||
|
|||
|
|||
def get_parser(**parser_kwargs): |
|||
def str2bool(v): |
|||
if isinstance(v, bool): |
|||
return v |
|||
if v.lower() in ("yes", "true", "t", "y", "1"): |
|||
return True |
|||
elif v.lower() in ("no", "false", "f", "n", "0"): |
|||
return False |
|||
else: |
|||
raise argparse.ArgumentTypeError("Boolean value expected.") |
|||
|
|||
parser = argparse.ArgumentParser(**parser_kwargs) |
|||
parser.add_argument( |
|||
"-r", |
|||
"--resume", |
|||
type=str, |
|||
default=None, |
|||
help="resume from checkpoint", |
|||
) |
|||
parser.add_argument( |
|||
"--resume_weights_only", |
|||
action="store_true", |
|||
help="only resume model weights", |
|||
) |
|||
parser.add_argument( |
|||
"-b", |
|||
"--base", |
|||
type=str, |
|||
default="base_config.yaml", |
|||
help="path to base configs", |
|||
) |
|||
parser.add_argument( |
|||
"-n", |
|||
"--name", |
|||
type=str, |
|||
default="", |
|||
help="experiment name", |
|||
) |
|||
parser.add_argument( |
|||
"--num_nodes", |
|||
type=int, |
|||
default=1, |
|||
help="number of nodes to use", |
|||
) |
|||
parser.add_argument( |
|||
"--gpus", |
|||
type=str, |
|||
default="0,", |
|||
help="gpu ids to use", |
|||
) |
|||
parser.add_argument( |
|||
"-s", |
|||
"--seed", |
|||
type=int, |
|||
default=42, |
|||
help="seed for seed_everything", |
|||
) |
|||
parser.add_argument( |
|||
"-l", |
|||
"--logdir", |
|||
type=str, |
|||
default="logs", |
|||
help="directory for logging data", |
|||
) |
|||
return parser |
|||
|
|||
|
|||
class SetupCallback(Callback): |
|||
def __init__(self, resume, logdir, ckptdir, cfgdir, config): |
|||
super().__init__() |
|||
self.resume = resume |
|||
self.logdir = logdir |
|||
self.ckptdir = ckptdir |
|||
self.cfgdir = cfgdir |
|||
self.config = config |
|||
|
|||
def on_fit_start(self, trainer, pl_module): |
|||
if trainer.global_rank == 0: |
|||
# Create logdirs and save configs |
|||
os.makedirs(self.logdir, exist_ok=True) |
|||
os.makedirs(self.ckptdir, exist_ok=True) |
|||
os.makedirs(self.cfgdir, exist_ok=True) |
|||
|
|||
rank_zero_print("Project config") |
|||
rank_zero_print(OmegaConf.to_yaml(self.config)) |
|||
OmegaConf.save(self.config, |
|||
os.path.join(self.cfgdir, "project.yaml")) |
|||
|
|||
|
|||
class CodeSnapshot(Callback): |
|||
""" |
|||
Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60 |
|||
""" |
|||
def __init__(self, savedir): |
|||
self.savedir = savedir |
|||
|
|||
def get_file_list(self): |
|||
return [ |
|||
b.decode() |
|||
for b in set( |
|||
subprocess.check_output( |
|||
'git ls-files -- ":!:configs/*"', shell=True |
|||
).splitlines() |
|||
) |
|||
| set( # hard code, TODO: use config to exclude folders or files |
|||
subprocess.check_output( |
|||
"git ls-files --others --exclude-standard", shell=True |
|||
).splitlines() |
|||
) |
|||
] |
|||
|
|||
@rank_zero_only |
|||
def save_code_snapshot(self): |
|||
os.makedirs(self.savedir, exist_ok=True) |
|||
for f in self.get_file_list(): |
|||
if not os.path.exists(f) or os.path.isdir(f): |
|||
continue |
|||
os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) |
|||
shutil.copyfile(f, os.path.join(self.savedir, f)) |
|||
|
|||
def on_fit_start(self, trainer, pl_module): |
|||
try: |
|||
self.save_code_snapshot() |
|||
except Exception: |
|||
rank_zero_warn( |
|||
"Code snapshot is not saved. Please make sure you have git installed and are in a git repository." |
|||
) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
# add cwd for convenience and to make classes in this file available when |
|||
# running as `python main.py` |
|||
sys.path.append(os.getcwd()) |
|||
|
|||
parser = get_parser() |
|||
opt, unknown = parser.parse_known_args() |
|||
|
|||
cfg_fname = os.path.split(opt.base)[-1] |
|||
cfg_name = os.path.splitext(cfg_fname)[0] |
|||
exp_name = "-" + opt.name if opt.name != "" else "" |
|||
logdir = os.path.join(opt.logdir, cfg_name+exp_name) |
|||
|
|||
ckptdir = os.path.join(logdir, "checkpoints") |
|||
cfgdir = os.path.join(logdir, "configs") |
|||
codedir = os.path.join(logdir, "code") |
|||
seed_everything(opt.seed) |
|||
|
|||
# init configs |
|||
config = OmegaConf.load(opt.base) |
|||
lightning_config = config.lightning |
|||
trainer_config = lightning_config.trainer |
|||
|
|||
trainer_config["accelerator"] = "gpu" |
|||
rank_zero_print(f"Running on GPUs {opt.gpus}") |
|||
ngpu = len(opt.gpus.strip(",").split(',')) |
|||
trainer_config['devices'] = ngpu |
|||
|
|||
trainer_opt = argparse.Namespace(**trainer_config) |
|||
lightning_config.trainer = trainer_config |
|||
|
|||
# model |
|||
model = instantiate_from_config(config.model) |
|||
if opt.resume and opt.resume_weights_only: |
|||
model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params) |
|||
|
|||
model.logdir = logdir |
|||
|
|||
# trainer and callbacks |
|||
trainer_kwargs = dict() |
|||
|
|||
# logger |
|||
default_logger_cfg = { |
|||
"target": "pytorch_lightning.loggers.TensorBoardLogger", |
|||
"params": { |
|||
"name": "tensorboard", |
|||
"save_dir": logdir, |
|||
"version": "0", |
|||
} |
|||
} |
|||
logger_cfg = OmegaConf.merge(default_logger_cfg) |
|||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) |
|||
|
|||
# model checkpoint |
|||
default_modelckpt_cfg = { |
|||
"target": "pytorch_lightning.callbacks.ModelCheckpoint", |
|||
"params": { |
|||
"dirpath": ckptdir, |
|||
"filename": "{step:08}", |
|||
"verbose": True, |
|||
"save_last": True, |
|||
"every_n_train_steps": 5000, |
|||
"save_top_k": -1, # save all checkpoints |
|||
} |
|||
} |
|||
|
|||
if "modelcheckpoint" in lightning_config: |
|||
modelckpt_cfg = lightning_config.modelcheckpoint |
|||
else: |
|||
modelckpt_cfg = OmegaConf.create() |
|||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) |
|||
|
|||
# callbacks |
|||
default_callbacks_cfg = { |
|||
"setup_callback": { |
|||
"target": "train.SetupCallback", |
|||
"params": { |
|||
"resume": opt.resume, |
|||
"logdir": logdir, |
|||
"ckptdir": ckptdir, |
|||
"cfgdir": cfgdir, |
|||
"config": config, |
|||
} |
|||
}, |
|||
"learning_rate_logger": { |
|||
"target": "pytorch_lightning.callbacks.LearningRateMonitor", |
|||
"params": { |
|||
"logging_interval": "step", |
|||
} |
|||
}, |
|||
"code_snapshot": { |
|||
"target": "train.CodeSnapshot", |
|||
"params": { |
|||
"savedir": codedir, |
|||
} |
|||
}, |
|||
} |
|||
default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg |
|||
|
|||
if "callbacks" in lightning_config: |
|||
callbacks_cfg = lightning_config.callbacks |
|||
else: |
|||
callbacks_cfg = OmegaConf.create() |
|||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) |
|||
|
|||
trainer_kwargs["callbacks"] = [ |
|||
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] |
|||
|
|||
trainer_kwargs['precision'] = '32-true' |
|||
trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=True) |
|||
|
|||
# trainer |
|||
trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes) |
|||
trainer.logdir = logdir |
|||
|
|||
# data |
|||
data = instantiate_from_config(config.data) |
|||
data.prepare_data() |
|||
data.setup("fit") |
|||
|
|||
# configure learning rate |
|||
base_lr = config.model.base_learning_rate |
|||
if 'accumulate_grad_batches' in lightning_config.trainer: |
|||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches |
|||
else: |
|||
accumulate_grad_batches = 1 |
|||
rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}") |
|||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches |
|||
model.learning_rate = base_lr |
|||
rank_zero_print("++++ NOT USING LR SCALING ++++") |
|||
rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}") |
|||
|
|||
# run training loop |
|||
if opt.resume and not opt.resume_weights_only: |
|||
trainer.fit(model, data, ckpt_path=opt.resume) |
|||
else: |
|||
trainer.fit(model, data) |
@ -0,0 +1,406 @@ |
|||
from typing import Any, Dict, Optional |
|||
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|||
from diffusers.schedulers import KarrasDiffusionSchedulers |
|||
|
|||
import numpy |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.utils.checkpoint |
|||
import torch.distributed |
|||
import transformers |
|||
from collections import OrderedDict |
|||
from PIL import Image |
|||
from torchvision import transforms |
|||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|||
|
|||
import diffusers |
|||
from diffusers import ( |
|||
AutoencoderKL, |
|||
DDPMScheduler, |
|||
DiffusionPipeline, |
|||
EulerAncestralDiscreteScheduler, |
|||
UNet2DConditionModel, |
|||
ImagePipelineOutput |
|||
) |
|||
from diffusers.image_processor import VaeImageProcessor |
|||
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 |
|||
from diffusers.utils.import_utils import is_xformers_available |
|||
|
|||
|
|||
def to_rgb_image(maybe_rgba: Image.Image): |
|||
if maybe_rgba.mode == 'RGB': |
|||
return maybe_rgba |
|||
elif maybe_rgba.mode == 'RGBA': |
|||
rgba = maybe_rgba |
|||
img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8) # constant 255: composite onto a white background |
|||
img = Image.fromarray(img, 'RGB') |
|||
img.paste(rgba, mask=rgba.getchannel('A')) |
|||
return img |
|||
else: |
|||
raise ValueError("Unsupported image type.", maybe_rgba.mode) |
|||
|
|||
|
|||
class ReferenceOnlyAttnProc(torch.nn.Module): |
|||
def __init__( |
|||
self, |
|||
chained_proc, |
|||
enabled=False, |
|||
name=None |
|||
) -> None: |
|||
super().__init__() |
|||
self.enabled = enabled |
|||
self.chained_proc = chained_proc |
|||
self.name = name |
|||
|
|||
def __call__( |
|||
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, |
|||
mode="w", ref_dict: dict = None, is_cfg_guidance = False |
|||
) -> Any: |
|||
if encoder_hidden_states is None: |
|||
encoder_hidden_states = hidden_states |
|||
if self.enabled and is_cfg_guidance: |
|||
res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask) |
|||
hidden_states = hidden_states[1:] |
|||
encoder_hidden_states = encoder_hidden_states[1:] |
|||
if self.enabled: |
|||
if mode == 'w': |
|||
ref_dict[self.name] = encoder_hidden_states |
|||
elif mode == 'r': |
|||
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) |
|||
elif mode == 'm': |
|||
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) |
|||
else: |
|||
assert False, mode |
|||
res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) |
|||
if self.enabled and is_cfg_guidance: |
|||
res = torch.cat([res0, res]) |
|||
return res |
|||
|
|||
|
|||
class RefOnlyNoisedUNet(torch.nn.Module): |
|||
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None: |
|||
super().__init__() |
|||
self.unet = unet |
|||
self.train_sched = train_sched |
|||
self.val_sched = val_sched |
|||
|
|||
unet_lora_attn_procs = dict() |
|||
for name, _ in unet.attn_processors.items(): |
|||
if torch.__version__ >= '2.0': |
|||
default_attn_proc = AttnProcessor2_0() |
|||
elif is_xformers_available(): |
|||
default_attn_proc = XFormersAttnProcessor() |
|||
else: |
|||
default_attn_proc = AttnProcessor() |
|||
unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( |
|||
default_attn_proc, enabled=name.endswith("attn1.processor"), name=name |
|||
) |
|||
unet.set_attn_processor(unet_lora_attn_procs) |
|||
|
|||
def __getattr__(self, name: str): |
|||
try: |
|||
return super().__getattr__(name) |
|||
except AttributeError: |
|||
return getattr(self.unet, name) |
|||
|
|||
def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): |
|||
if is_cfg_guidance: |
|||
encoder_hidden_states = encoder_hidden_states[1:] |
|||
class_labels = class_labels[1:] |
|||
self.unet( |
|||
noisy_cond_lat, timestep, |
|||
encoder_hidden_states=encoder_hidden_states, |
|||
class_labels=class_labels, |
|||
cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), |
|||
**kwargs |
|||
) |
|||
|
|||
def forward( |
|||
self, sample, timestep, encoder_hidden_states, class_labels=None, |
|||
*args, cross_attention_kwargs, |
|||
down_block_res_samples=None, mid_block_res_sample=None, |
|||
**kwargs |
|||
): |
|||
cond_lat = cross_attention_kwargs['cond_lat'] |
|||
is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) |
|||
noise = torch.randn_like(cond_lat) |
|||
if self.training: |
|||
noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) |
|||
noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) |
|||
else: |
|||
noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) |
|||
noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) |
|||
ref_dict = {} |
|||
self.forward_cond( |
|||
noisy_cond_lat, timestep, |
|||
encoder_hidden_states, class_labels, |
|||
ref_dict, is_cfg_guidance, **kwargs |
|||
) |
|||
weight_dtype = self.unet.dtype |
|||
return self.unet( |
|||
sample, timestep, |
|||
encoder_hidden_states, *args, |
|||
class_labels=class_labels, |
|||
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), |
|||
down_block_additional_residuals=[ |
|||
sample.to(dtype=weight_dtype) for sample in down_block_res_samples |
|||
] if down_block_res_samples is not None else None, |
|||
mid_block_additional_residual=( |
|||
mid_block_res_sample.to(dtype=weight_dtype) |
|||
if mid_block_res_sample is not None else None |
|||
), |
|||
**kwargs |
|||
) |
|||
|
|||
|
|||
def scale_latents(latents): |
|||
latents = (latents - 0.22) * 0.75 |
|||
return latents |
|||
|
|||
|
|||
def unscale_latents(latents): |
|||
latents = latents / 0.75 + 0.22 |
|||
return latents |
|||
|
|||
|
|||
def scale_image(image): |
|||
image = image * 0.5 / 0.8 |
|||
return image |
|||
|
|||
|
|||
def unscale_image(image): |
|||
image = image / 0.5 * 0.8 |
|||
return image |
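# --- Sanity sketch (not in the original source): each scale/unscale pair is an
# exact inverse, so encoding and decoding round-trip cleanly.
lat = torch.randn(1, 4, 8, 8)
assert torch.allclose(unscale_latents(scale_latents(lat)), lat, atol=1e-6)
img_t = torch.rand(1, 3, 8, 8)
assert torch.allclose(unscale_image(scale_image(img_t)), img_t, atol=1e-6)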
|||
|
|||
|
|||
class DepthControlUNet(torch.nn.Module): |
|||
def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None: |
|||
super().__init__() |
|||
self.unet = unet |
|||
if controlnet is None: |
|||
self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet) |
|||
else: |
|||
self.controlnet = controlnet |
|||
DefaultAttnProc = AttnProcessor2_0 |
|||
if is_xformers_available(): |
|||
DefaultAttnProc = XFormersAttnProcessor |
|||
self.controlnet.set_attn_processor(DefaultAttnProc()) |
|||
self.conditioning_scale = conditioning_scale |
|||
|
|||
def __getattr__(self, name: str): |
|||
try: |
|||
return super().__getattr__(name) |
|||
except AttributeError: |
|||
return getattr(self.unet, name) |
|||
|
|||
def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs): |
|||
cross_attention_kwargs = dict(cross_attention_kwargs) |
|||
control_depth = cross_attention_kwargs.pop('control_depth') |
|||
down_block_res_samples, mid_block_res_sample = self.controlnet( |
|||
sample, |
|||
timestep, |
|||
encoder_hidden_states=encoder_hidden_states, |
|||
controlnet_cond=control_depth, |
|||
conditioning_scale=self.conditioning_scale, |
|||
return_dict=False, |
|||
) |
|||
return self.unet( |
|||
sample, |
|||
timestep, |
|||
encoder_hidden_states=encoder_hidden_states, |
|||
down_block_res_samples=down_block_res_samples, |
|||
mid_block_res_sample=mid_block_res_sample, |
|||
cross_attention_kwargs=cross_attention_kwargs |
|||
) |
|||
|
|||
|
|||
class ModuleListDict(torch.nn.Module): |
|||
def __init__(self, procs: dict) -> None: |
|||
super().__init__() |
|||
self.keys = sorted(procs.keys()) |
|||
self.values = torch.nn.ModuleList(procs[k] for k in self.keys) |
|||
|
|||
def __getitem__(self, key): |
|||
return self.values[self.keys.index(key)] |
|||
|
|||
|
|||
class SuperNet(torch.nn.Module): |
|||
def __init__(self, state_dict: Dict[str, torch.Tensor]): |
|||
super().__init__() |
|||
state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys())) |
|||
self.layers = torch.nn.ModuleList(state_dict.values()) |
|||
self.mapping = dict(enumerate(state_dict.keys())) |
|||
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} |
|||
|
|||
# .processor for unet, .self_attn for text encoder |
|||
self.split_keys = [".processor", ".self_attn"] |
|||
|
|||
# we add a hook to state_dict() and load_state_dict() so that the |
|||
# naming fits with `unet.attn_processors` |
|||
def map_to(module, state_dict, *args, **kwargs): |
|||
new_state_dict = {} |
|||
for key, value in state_dict.items(): |
|||
num = int(key.split(".")[1]) # 0 is always "layers" |
|||
new_key = key.replace(f"layers.{num}", module.mapping[num]) |
|||
new_state_dict[new_key] = value |
|||
|
|||
return new_state_dict |
|||
|
|||
def remap_key(key, state_dict): |
|||
for k in self.split_keys: |
|||
if k in key: |
|||
return key.split(k)[0] + k |
|||
return key.split('.')[0] |
|||
|
|||
def map_from(module, state_dict, *args, **kwargs): |
|||
all_keys = list(state_dict.keys()) |
|||
for key in all_keys: |
|||
replace_key = remap_key(key, state_dict) |
|||
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") |
|||
state_dict[new_key] = state_dict[key] |
|||
del state_dict[key] |
|||
|
|||
self._register_state_dict_hook(map_to) |
|||
self._register_load_state_dict_pre_hook(map_from, with_module=True) |


class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Map a PIL depth image to a tensor in [-1, 1], matching the UNet's input range.
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        # Call DiffusionPipeline.__init__ directly, bypassing
        # StableDiffusionPipeline's __init__ (which expects a safety checker).
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
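
    # Worked example (an assumption about the shipped VAE, not asserted by this
    # file): a Stable Diffusion VAE has 4 entries in `block_out_channels`, so
    # vae_scale_factor = 2 ** (4 - 1) = 8, and the default 640x960 output is
    # denoised as an 80x120 latent.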

    def prepare(self):
        # Wrap the raw UNet (once) so reference attention can inject the
        # condition latent during sampling.
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        # Return a SuperNet so the ControlNet weights can be saved/loaded
        # under their canonical key names.
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
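
    # Hypothetical usage sketch (file names and the scale value are illustrative):
    #
    #   ctrl = pipeline.add_controlnet(conditioning_scale=0.75)
    #   torch.save(ctrl.state_dict(), "depth_controlnet.pt")  # canonical keys
    #   grid = pipeline(cond_image, depth_image=depth_map).images[0]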

    def encode_condition_image(self, image: torch.Tensor):
        # Encode the condition image into a (sampled) VAE latent.
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Image.Image] = None,
        prompt="",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Optional[Image.Image] = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs
    ):
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings is not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        # Two preprocessed copies of the input: one sized for the VAE encoder,
        # one for the CLIP vision encoder.
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            # Classifier-free guidance: pair the condition latent with an
            # all-black negative latent.
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        if hasattr(self, "encode_prompt"):
            encoder_hidden_states = self.encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )[0]
        else:
            # Older diffusers versions expose `_encode_prompt` instead.
            encoder_hidden_states = self._encode_prompt(
                prompt,
                self.device,
                num_images_per_prompt,
                False
            )
        # Blend the global CLIP image embedding into the text states using the
        # learned per-token ramping coefficients.
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            cak['control_depth'] = depth_image
        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
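
# Hypothetical end-to-end sketch (the repo id and custom_pipeline target are
# assumptions, not pinned by this file):
#
#   pipeline = DiffusionPipeline.from_pretrained(
#       "sudo-ai/zero123plus-v1.2", custom_pipeline="zero123plus",
#       torch_dtype=torch.float16,
#   ).to("cuda")
#   cond = Image.open("input.png")
#   grid = pipeline(cond, num_inference_steps=28).images[0]  # 640x960 multi-view grid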