From 0d335d074d355358f13314e3edbd73afb00ba142 Mon Sep 17 00:00:00 2001
From: bluestyle97
Date: Thu, 11 Apr 2024 21:12:54 +0800
Subject: [PATCH] add training config

---
 configs/instant-mesh-large-train.yaml | 68 +++++++++++++++++++++++++++
 configs/instant-nerf-large-train.yaml | 66 ++++++++++++++++++++++++++
 src/data/objaverse.py                 | 23 ++++-----
 src/model.py                          | 11 +----
 4 files changed, 146 insertions(+), 22 deletions(-)
 create mode 100644 configs/instant-mesh-large-train.yaml
 create mode 100644 configs/instant-nerf-large-train.yaml

diff --git a/configs/instant-mesh-large-train.yaml b/configs/instant-mesh-large-train.yaml
new file mode 100644
index 0000000..b5d99a2
--- /dev/null
+++ b/configs/instant-mesh-large-train.yaml
@@ -0,0 +1,68 @@
+model:
+  base_learning_rate: 4.0e-05
+  target: src.model_mesh.MVRecon
+  params:
+    init_ckpt: logs/instant-mesh-large/checkpoints/last.ckpt
+    input_size: 320
+    render_size: 512
+
+    lrm_generator_config:
+      target: src.models.lrm_mesh.InstantMesh
+      params:
+        camera_embed_dim: 1024
+        rendering_samples_per_ray: 128
+        transformer_dim: 1024
+        transformer_layers: 16
+        transformer_heads: 16
+        triplane_low_res: 32
+        triplane_high_res: 64
+        triplane_dim: 80
+        encoder_feat_dim: 768
+        encoder_freeze: false
+        encoder_model_name: facebook/dino-vitb16
+        grid_res: 128
+        grid_scale: 2.1
+
+
+data:
+  target: src.data.objaverse.DataModuleFromConfig
+  params:
+    batch_size: 2
+    num_workers: 8
+    train:
+      target: src.data.objaverse.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: filtered_obj_name.json
+        input_image_dir: rendering_random_32views
+        target_image_dir: rendering_random_32views
+        input_view_num: 6
+        target_view_num: 4
+        total_view_n: 32
+        fov: 50
+        camera_rotation: true
+        validation: false
+    validation:
+      target: src.data.objaverse.ValidationData
+      params:
+        root_dir: data/valid_samples
+        input_view_num: 6
+        input_image_size: 320
+        fov: 30
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 2000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
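Note: every `target:`/`params:` block in these configs is resolved at runtime by `instantiate_from_config` (imported from `src.utils.train_util` in the dataset module below). A minimal sketch of how such a block is typically resolved, assuming the usual ldm-style helper; the repo's actual implementation may differ in details:

```python
import importlib

def get_obj_from_str(string: str):
    # "src.models.lrm_mesh.InstantMesh" -> (module path, class name)
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # Instantiate the class named by `target`, passing `params` as kwargs.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
```

For example, `instantiate_from_config(cfg["model"]["params"]["lrm_generator_config"])` would build the `InstantMesh` generator with the parameters listed above.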
diff --git a/configs/instant-nerf-large-train.yaml b/configs/instant-nerf-large-train.yaml
new file mode 100644
index 0000000..3544f8e
--- /dev/null
+++ b/configs/instant-nerf-large-train.yaml
@@ -0,0 +1,66 @@
+model:
+  base_learning_rate: 4.0e-04
+  target: src.model.MVRecon
+  params:
+    input_size: 320
+    render_size: 192
+
+    lrm_generator_config:
+      target: src.models.lrm.InstantNeRF
+      params:
+        camera_embed_dim: 1024
+        rendering_samples_per_ray: 128
+        transformer_dim: 1024
+        transformer_layers: 16
+        transformer_heads: 16
+        triplane_low_res: 32
+        triplane_high_res: 64
+        triplane_dim: 80
+        encoder_feat_dim: 768
+        encoder_freeze: false
+        encoder_model_name: facebook/dino-vitb16
+
+
+data:
+  target: src.data.objaverse.DataModuleFromConfig
+  params:
+    batch_size: 2
+    num_workers: 8
+    train:
+      target: src.data.objaverse.ObjaverseData
+      params:
+        root_dir: data/objaverse
+        meta_fname: filtered_obj_name.json
+        input_image_dir: rendering_random_32views
+        target_image_dir: rendering_random_32views
+        input_view_num: 6
+        target_view_num: 4
+        total_view_n: 32
+        fov: 50
+        camera_rotation: true
+        validation: false
+    validation:
+      target: src.data.objaverse.ValidationData
+      params:
+        root_dir: data/valid_samples
+        input_view_num: 6
+        input_image_size: 320
+        fov: 30
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 1000
+      save_top_k: -1
+      save_last: true
+  callbacks: {}
+
+  trainer:
+    benchmark: true
+    max_epochs: -1
+    gradient_clip_val: 1.0
+    val_check_interval: 1000
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    check_val_every_n_epoch: null  # if this is not set, validation does not run
diff --git a/src/data/objaverse.py b/src/data/objaverse.py
index be7c12b..37491f0 100644
--- a/src/data/objaverse.py
+++ b/src/data/objaverse.py
@@ -20,7 +20,7 @@ from src.utils.train_util import instantiate_from_config
 from src.utils.camera_util import (
     FOV_to_intrinsics,
     center_looking_at_camera_pose,
-    get_circular_camera_poses,
+    get_surrounding_views,
 )
 
 
@@ -76,7 +76,7 @@ class ObjaverseData(Dataset):
         input_image_dir='rendering_random_32views',
         target_image_dir='rendering_random_32views',
         input_view_num=6,
-        target_view_num=2,
+        target_view_num=4,
         total_view_n=32,
         fov=50,
         camera_rotation=True,
@@ -97,7 +97,7 @@ class ObjaverseData(Dataset):
             paths = filtered_dict['good_objs']
         self.paths = paths
 
-        self.depth_scale = 5.0
+        self.depth_scale = 4.0
 
         total_objects = len(self.paths)
         print('============= length of dataset %d =============' % len(self.paths))
@@ -120,7 +120,6 @@ class ObjaverseData(Dataset):
         return image, alpha
 
     def __getitem__(self, index):
-        # load data
         while True:
             input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
             target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
@@ -210,7 +209,7 @@ class ObjaverseData(Dataset):
 
         # random scaling
         if np.random.rand() < 0.5:
-            scale = np.random.uniform(0.8, 1.0)
+            scale = np.random.uniform(0.7, 1.1)
             c2ws[:, :3, 3] *= scale
             depths *= scale
 
@@ -219,22 +218,20 @@ class ObjaverseData(Dataset):
         Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
 
         data = {
-            'input_images': images[:self.input_view_num],   # (6, 3, H, W)
+            'input_images': images[:self.input_view_num],       # (6, 3, H, W)
             'input_alphas': alphas[:self.input_view_num],       # (6, 1, H, W)
             'input_depths': depths[:self.input_view_num],       # (6, 1, H, W)
             'input_normals': normals[:self.input_view_num],     # (6, 3, H, W)
-            'input_c2ws': c2ws[:self.input_view_num],           # (6, 4, 4)
+            'input_c2ws': c2ws_input[:self.input_view_num],     # (6, 4, 4)
             'input_Ks': Ks[:self.input_view_num],               # (6, 3, 3)
 
             # lrm generator input and supervision
             'target_images': images[self.input_view_num:],      # (V, 3, H, W)
             'target_alphas': alphas[self.input_view_num:],      # (V, 1, H, W)
             'target_depths': depths[self.input_view_num:],      # (V, 1, H, W)
             'target_normals': normals[self.input_view_num:],    # (V, 3, H, W)
             'target_c2ws': c2ws[self.input_view_num:],          # (V, 4, 4)
             'target_Ks': Ks[self.input_view_num:],              # (V, 3, 3)
-
-            'depth_available': 1,
         }
 
         return data
@@ -244,7 +241,7 @@ class ValidationData(Dataset):
         root_dir='objaverse/',
         input_view_num=6,
         input_image_size=256,
-        fov=30,
+        fov=50,
     ):
         self.root_dir = Path(root_dir)
         self.input_view_num = input_view_num
@@ -254,9 +251,9 @@ class ValidationData(Dataset):
         self.paths = sorted(os.listdir(self.root_dir))
         print('============= length of dataset %d =============' % len(self.paths))
 
-        cam_distance = 4.0
+        cam_distance = 2.5
         azimuths = np.array([30, 90, 150, 210, 270, 330])
-        elevations = np.array([20, -10, 20, -10, 20, -10])
+        elevations = np.array([30, -20, 30, -20, 30, -20])
         azimuths = np.deg2rad(azimuths)
         elevations = np.deg2rad(elevations)
 
@@ -270,7 +267,7 @@ class ValidationData(Dataset):
         self.c2ws = c2ws.float()
         self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
 
-        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
+        render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
         render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
         self.render_c2ws = render_c2ws.float()
         self.render_Ks = render_Ks.float()
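Note: the dataset changes above retune the validation cameras (a closer orbit at `cam_distance = 2.5`, steeper alternating elevations) and swap `get_circular_camera_poses` for `get_surrounding_views`. A minimal standalone sketch of what the six-view setup computes, assuming a Z-up look-at convention; `spherical_camera_pose` is illustrative only, not the repo's `center_looking_at_camera_pose`:

```python
import numpy as np
import torch

def spherical_camera_pose(azimuth_deg: float, elevation_deg: float, radius: float) -> torch.Tensor:
    # Place the camera on a sphere around the origin and aim it at the origin.
    az, el = np.deg2rad(azimuth_deg), np.deg2rad(elevation_deg)
    eye = torch.tensor([
        radius * np.cos(el) * np.cos(az),
        radius * np.cos(el) * np.sin(az),
        radius * np.sin(el),
    ], dtype=torch.float32)
    forward = eye / torch.linalg.norm(eye)   # OpenGL-style: camera -Z looks back at the origin
    right = torch.linalg.cross(torch.tensor([0.0, 0.0, 1.0]), forward)
    right = right / torch.linalg.norm(right)
    up = torch.linalg.cross(forward, right)
    c2w = torch.eye(4)
    c2w[:3, 0], c2w[:3, 1], c2w[:3, 2], c2w[:3, 3] = right, up, forward, eye
    return c2w

# The six validation viewpoints built in ValidationData above: (6, 4, 4)
c2ws = torch.stack([
    spherical_camera_pose(az, el, radius=2.5)
    for az, el in zip([30, 90, 150, 210, 270, 330], [30, -20, 30, -20, 30, -20])
])
```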
diff --git a/src/model.py b/src/model.py
index 584a6dc..2e8fcac 100644
--- a/src/model.py
+++ b/src/model.py
@@ -295,16 +295,9 @@ class MVRecon(pl.LightningModule):
 
         params = []
 
-        lrm_params_fast, lrm_params_slow = [], []
-        for n, p in self.lrm_generator.named_parameters():
-            if 'adaLN_modulation' in n or 'camera_embedder' in n:
-                lrm_params_fast.append(p)
-            else:
-                lrm_params_slow.append(p)
-        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
-        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
+        params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })
 
         optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)
 
         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
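Note: the `configure_optimizers` change above drops the fast/slow parameter-group split and trains all generator parameters at a single rate. A self-contained sketch of the resulting optimizer/scheduler pair, using a stand-in `nn.Linear` for `self.lrm_generator`:

```python
import torch

model = torch.nn.Linear(8, 8)   # stand-in for self.lrm_generator
lr = 4.0e-05                    # base_learning_rate from instant-mesh-large-train.yaml

optimizer = torch.optim.AdamW(
    [{"params": model.parameters(), "lr": lr, "weight_decay": 0.01}],
    lr=lr, betas=(0.90, 0.95),
)
# Cosine annealing with warm restarts: the LR decays from lr toward lr/10
# over a 3000-step period, then snaps back to lr and repeats.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=3000, eta_min=lr / 10,
)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```

With these configs in place, training would presumably be launched through the repo's Lightning entry point, e.g. `python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1`; the command is assumed from the ldm-style `train.py` convention, so check the project README.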