Browse Source

update dataloader

main
bluestyle97 9 months ago
parent
commit
6f7c4f97c6
  1. 20
      src/data/objaverse.py

20
src/data/objaverse.py

@ -1,9 +1,11 @@
import os
import os, sys
import math
import json
import importlib
from pathlib import Path
import cv2
import random
import numpy as np
from PIL import Image
import webdataset as wds
@ -20,7 +22,7 @@ from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
center_looking_at_camera_pose,
get_surrounding_views,
get_circular_camera_poses,
)
@ -97,7 +99,7 @@ class ObjaverseData(Dataset):
paths = filtered_dict['good_objs']
self.paths = paths
self.depth_scale = 4.0
self.depth_scale = 6.0
total_objects = len(self.paths)
print('============= length of dataset %d =============' % len(self.paths))
@ -222,7 +224,7 @@ class ObjaverseData(Dataset):
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
'input_c2ws': c2ws[:self.input_view_num], # (6, 4, 4)
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
# lrm generator input and supervision
@ -240,8 +242,8 @@ class ValidationData(Dataset):
def __init__(self,
root_dir='objaverse/',
input_view_num=6,
input_image_size=256,
fov=50,
input_image_size=320,
fov=30,
):
self.root_dir = Path(root_dir)
self.input_view_num = input_view_num
@ -251,9 +253,9 @@ class ValidationData(Dataset):
self.paths = sorted(os.listdir(self.root_dir))
print('============= length of dataset %d =============' % len(self.paths))
cam_distance = 2.5
cam_distance = 4.0
azimuths = np.array([30, 90, 150, 210, 270, 330])
elevations = np.array([30, -20, 30, -20, 30, -20])
elevations = np.array([20, -10, 20, -10, 20, -10])
azimuths = np.deg2rad(azimuths)
elevations = np.deg2rad(elevations)
@ -267,7 +269,7 @@ class ValidationData(Dataset):
self.c2ws = c2ws.float()
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
self.render_c2ws = render_c2ws.float()
self.render_Ks = render_Ks.float()

Loading…
Cancel
Save