Browse Source

update dataloader

main
bluestyle97 9 months ago
parent
commit
6f7c4f97c6
  1. 20
      src/data/objaverse.py

20
src/data/objaverse.py

@ -1,9 +1,11 @@
import os import os, sys
import math import math
import json import json
import importlib
from pathlib import Path from pathlib import Path
import cv2 import cv2
import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import webdataset as wds import webdataset as wds
@ -20,7 +22,7 @@ from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import ( from src.utils.camera_util import (
FOV_to_intrinsics, FOV_to_intrinsics,
center_looking_at_camera_pose, center_looking_at_camera_pose,
get_surrounding_views, get_circular_camera_poses,
) )
@ -97,7 +99,7 @@ class ObjaverseData(Dataset):
paths = filtered_dict['good_objs'] paths = filtered_dict['good_objs']
self.paths = paths self.paths = paths
self.depth_scale = 4.0 self.depth_scale = 6.0
total_objects = len(self.paths) total_objects = len(self.paths)
print('============= length of dataset %d =============' % len(self.paths)) print('============= length of dataset %d =============' % len(self.paths))
@ -222,7 +224,7 @@ class ObjaverseData(Dataset):
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W) 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
'input_depths': depths[:self.input_view_num], # (6, 1, H, W) 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
'input_normals': normals[:self.input_view_num], # (6, 3, H, W) 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4) 'input_c2ws': c2ws[:self.input_view_num], # (6, 4, 4)
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3) 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
# lrm generator input and supervision # lrm generator input and supervision
@ -240,8 +242,8 @@ class ValidationData(Dataset):
def __init__(self, def __init__(self,
root_dir='objaverse/', root_dir='objaverse/',
input_view_num=6, input_view_num=6,
input_image_size=256, input_image_size=320,
fov=50, fov=30,
): ):
self.root_dir = Path(root_dir) self.root_dir = Path(root_dir)
self.input_view_num = input_view_num self.input_view_num = input_view_num
@ -251,9 +253,9 @@ class ValidationData(Dataset):
self.paths = sorted(os.listdir(self.root_dir)) self.paths = sorted(os.listdir(self.root_dir))
print('============= length of dataset %d =============' % len(self.paths)) print('============= length of dataset %d =============' % len(self.paths))
cam_distance = 2.5 cam_distance = 4.0
azimuths = np.array([30, 90, 150, 210, 270, 330]) azimuths = np.array([30, 90, 150, 210, 270, 330])
elevations = np.array([30, -20, 30, -20, 30, -20]) elevations = np.array([20, -10, 20, -10, 20, -10])
azimuths = np.deg2rad(azimuths) azimuths = np.deg2rad(azimuths)
elevations = np.deg2rad(elevations) elevations = np.deg2rad(elevations)
@ -267,7 +269,7 @@ class ValidationData(Dataset):
self.c2ws = c2ws.float() self.c2ws = c2ws.float()
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float() self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
render_c2ws = get_surrounding_views(M=8, radius=cam_distance) render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1) render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
self.render_c2ws = render_c2ws.float() self.render_c2ws = render_c2ws.float()
self.render_Ks = render_Ks.float() self.render_Ks = render_Ks.float()

Loading…
Cancel
Save