
add training config

branch: main
bluestyle97, 6 months ago
commit: 0d335d074d
  1. configs/instant-mesh-large-train.yaml (68 lines changed)
  2. configs/instant-nerf-large-train.yaml (66 lines changed)
  3. src/data/objaverse.py (23 lines changed)
  4. src/model.py (11 lines changed)

configs/instant-mesh-large-train.yaml

@@ -0,0 +1,68 @@
model:
  base_learning_rate: 4.0e-05
  target: src.model_mesh.MVRecon
  params:
    init_ckpt: logs/instant-mesh-large/checkpoints/last.ckpt
    input_size: 320
    render_size: 512

    lrm_generator_config:
      target: src.models.lrm_mesh.InstantMesh
      params:
        camera_embed_dim: 1024
        rendering_samples_per_ray: 128
        transformer_dim: 1024
        transformer_layers: 16
        transformer_heads: 16
        triplane_low_res: 32
        triplane_high_res: 64
        triplane_dim: 80
        encoder_feat_dim: 768
        encoder_freeze: false
        encoder_model_name: facebook/dino-vitb16
        grid_res: 128
        grid_scale: 2.1

data:
  target: src.data.objaverse.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 8
    train:
      target: src.data.objaverse.ObjaverseData
      params:
        root_dir: data/objaverse
        meta_fname: filtered_obj_name.json
        input_image_dir: rendering_random_32views
        target_image_dir: rendering_random_32views
        input_view_num: 6
        target_view_num: 4
        total_view_n: 32
        fov: 50
        camera_rotation: true
        validation: false
    validation:
      target: src.data.objaverse.ValidationData
      params:
        root_dir: data/valid_samples
        input_view_num: 6
        input_image_size: 320
        fov: 30

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      save_last: true
  callbacks: {}
  trainer:
    benchmark: true
    max_epochs: -1
    val_check_interval: 1000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    check_val_every_n_epoch: null  # if this is not set, validation does not run
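For orientation, a config in this layout is usually consumed by loading the YAML and instantiating each target class with its params, which is the pattern behind the instantiate_from_config helper imported in src/data/objaverse.py (see the diff further down). The snippet below is a minimal, self-contained sketch of that pattern assuming an OmegaConf-based loader; it re-implements the helper generically and is not the repo's actual train entry point.

import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Resolve the dotted "target" path to a class and construct it with "params".
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

cfg = OmegaConf.load("configs/instant-mesh-large-train.yaml")
model = instantiate_from_config(cfg.model)       # builds src.model_mesh.MVRecon with the params above
datamodule = instantiate_from_config(cfg.data)   # builds src.data.objaverse.DataModuleFromConfig
# Assumption: base_learning_rate is attached to the module before training,
# as in other latent-diffusion-style training scripts.
model.learning_rate = cfg.model.base_learning_rate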

configs/instant-nerf-large-train.yaml

@@ -0,0 +1,66 @@
model:
  base_learning_rate: 4.0e-04
  target: src.model.MVRecon
  params:
    input_size: 320
    render_size: 192

    lrm_generator_config:
      target: src.models.lrm.InstantNeRF
      params:
        camera_embed_dim: 1024
        rendering_samples_per_ray: 128
        transformer_dim: 1024
        transformer_layers: 16
        transformer_heads: 16
        triplane_low_res: 32
        triplane_high_res: 64
        triplane_dim: 80
        encoder_feat_dim: 768
        encoder_freeze: false
        encoder_model_name: facebook/dino-vitb16

data:
  target: src.data.objaverse.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 8
    train:
      target: src.data.objaverse.ObjaverseData
      params:
        root_dir: data/objaverse
        meta_fname: filtered_obj_name.json
        input_image_dir: rendering_random_32views
        target_image_dir: rendering_random_32views
        input_view_num: 6
        target_view_num: 4
        total_view_n: 32
        fov: 50
        camera_rotation: true
        validation: false
    validation:
      target: src.data.objaverse.ValidationData
      params:
        root_dir: data/valid_samples
        input_view_num: 6
        input_image_size: 320
        fov: 30

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1000
      save_top_k: -1
      save_last: true
  callbacks: {}
  trainer:
    benchmark: true
    max_epochs: -1
    gradient_clip_val: 1.0
    val_check_interval: 1000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    check_val_every_n_epoch: null  # if this is not set, validation does not run
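The lightning block in both configs mirrors the keyword arguments of a PyTorch Lightning Trainer plus a ModelCheckpoint callback. A hedged sketch of that mapping follows; the actual wiring lives in the repo's training script, which is not part of this commit.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# lightning.modelcheckpoint.params: step-based checkpointing, keep every checkpoint.
checkpoint_cb = ModelCheckpoint(every_n_train_steps=1000, save_top_k=-1, save_last=True)

# lightning.trainer: Trainer kwargs. check_val_every_n_epoch=None lets the
# step-based val_check_interval fire across epoch boundaries, which is what the
# "validation does not run" comment in the config refers to.
trainer = pl.Trainer(
    benchmark=True,
    max_epochs=-1,
    gradient_clip_val=1.0,
    val_check_interval=1000,
    num_sanity_val_steps=0,
    accumulate_grad_batches=1,
    check_val_every_n_epoch=None,
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, datamodule=datamodule)  # using the objects from the earlier sketch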

src/data/objaverse.py

@@ -20,7 +20,7 @@ from src.utils.train_util import instantiate_from_config
 from src.utils.camera_util import (
     FOV_to_intrinsics,
     center_looking_at_camera_pose,
-    get_circular_camera_poses,
+    get_surrounding_views,
 )

@@ -76,7 +76,7 @@ class ObjaverseData(Dataset):
         input_image_dir='rendering_random_32views',
         target_image_dir='rendering_random_32views',
         input_view_num=6,
-        target_view_num=2,
+        target_view_num=4,
         total_view_n=32,
         fov=50,
         camera_rotation=True,

@@ -97,7 +97,7 @@ class ObjaverseData(Dataset):
         paths = filtered_dict['good_objs']
         self.paths = paths
-        self.depth_scale = 5.0
+        self.depth_scale = 4.0

         total_objects = len(self.paths)
         print('============= length of dataset %d =============' % len(self.paths))

@@ -120,7 +120,6 @@ class ObjaverseData(Dataset):
         return image, alpha

     def __getitem__(self, index):
-        # load data
         while True:
             input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
             target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])

@@ -210,7 +209,7 @@ class ObjaverseData(Dataset):
         # random scaling
         if np.random.rand() < 0.5:
-            scale = np.random.uniform(0.8, 1.0)
+            scale = np.random.uniform(0.7, 1.1)
             c2ws[:, :3, 3] *= scale
             depths *= scale

@@ -219,11 +218,11 @@ class ObjaverseData(Dataset):
         Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()

         data = {
             'input_images': images[:self.input_view_num],      # (6, 3, H, W)
             'input_alphas': alphas[:self.input_view_num],      # (6, 1, H, W)
             'input_depths': depths[:self.input_view_num],      # (6, 1, H, W)
             'input_normals': normals[:self.input_view_num],    # (6, 3, H, W)
-            'input_c2ws': c2ws[:self.input_view_num],           # (6, 4, 4)
+            'input_c2ws': c2ws_input[:self.input_view_num],     # (6, 4, 4)
             'input_Ks': Ks[:self.input_view_num],               # (6, 3, 3)

             # lrm generator input and supervision

@@ -233,8 +232,6 @@ class ObjaverseData(Dataset):
             'target_normals': normals[self.input_view_num:],    # (V, 3, H, W)
             'target_c2ws': c2ws[self.input_view_num:],          # (V, 4, 4)
             'target_Ks': Ks[self.input_view_num:],              # (V, 3, 3)
-
-            'depth_available': 1,
         }
         return data

@@ -244,7 +241,7 @@ class ValidationData(Dataset):
         root_dir='objaverse/',
         input_view_num=6,
         input_image_size=256,
-        fov=30,
+        fov=50,
     ):
         self.root_dir = Path(root_dir)
         self.input_view_num = input_view_num

@@ -254,9 +251,9 @@ class ValidationData(Dataset):
         self.paths = sorted(os.listdir(self.root_dir))
         print('============= length of dataset %d =============' % len(self.paths))

-        cam_distance = 4.0
+        cam_distance = 2.5
         azimuths = np.array([30, 90, 150, 210, 270, 330])
-        elevations = np.array([20, -10, 20, -10, 20, -10])
+        elevations = np.array([30, -20, 30, -20, 30, -20])
         azimuths = np.deg2rad(azimuths)
         elevations = np.deg2rad(elevations)

@@ -270,7 +267,7 @@ class ValidationData(Dataset):
         self.c2ws = c2ws.float()
         self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()

-        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
+        render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
         render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
         self.render_c2ws = render_c2ws.float()
         self.render_Ks = render_Ks.float()
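The data dictionary assembled in the hunks above keeps the first input_view_num views as conditioning for the LRM generator and the remaining target_view_num views purely as rendering supervision. Below is a tiny shape-only illustration with dummy tensors; the shapes come from the comments in the diff, and nothing here is repo code.

import torch

input_view_num, target_view_num = 6, 4   # matches the training configs above
H = W = 320

views = torch.rand(input_view_num + target_view_num, 3, H, W)       # all rendered views of one object
c2ws = torch.eye(4).repeat(input_view_num + target_view_num, 1, 1)  # dummy camera-to-world poses

batch = {
    'input_images': views[:input_view_num],    # (6, 3, H, W): fed to the generator
    'input_c2ws': c2ws[:input_view_num],       # (6, 4, 4)
    'target_images': views[input_view_num:],   # (4, 3, H, W): used only in the loss
    'target_c2ws': c2ws[input_view_num:],      # (4, 4, 4)
}
print({k: tuple(v.shape) for k, v in batch.items()})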

src/model.py

@@ -295,16 +295,9 @@ class MVRecon(pl.LightningModule):
         params = []

-        lrm_params_fast, lrm_params_slow = [], []
-        for n, p in self.lrm_generator.named_parameters():
-            if 'adaLN_modulation' in n or 'camera_embedder' in n:
-                lrm_params_fast.append(p)
-            else:
-                lrm_params_slow.append(p)
-        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
-        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
+        params.append({"params": self.lrm_generator.parameters(), "lr": lr, "weight_decay": 0.01 })

         optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/10)

         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
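On the optimizer side, the hunk collapses the fast/slow parameter split into a single AdamW group and lowers the cosine floor from lr/4 to lr/10. The standalone sketch below only illustrates what that scheduler choice does over a couple of restart cycles; the Linear module is a stand-in for the LRM generator and is not repo code.

import torch

lr = 4.0e-04                                   # base_learning_rate from the NeRF config
model = torch.nn.Linear(8, 8)                  # stand-in for self.lrm_generator
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=lr / 10)

for step in range(6001):
    optimizer.step()
    scheduler.step()
    if step % 1500 == 0:
        # decays from lr towards lr/10 within each 3000-step cycle, then restarts
        print(step, round(scheduler.get_last_lr()[0], 6))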
