From 6f7c4f97c661444e7132a735fcdf7e7516d64905 Mon Sep 17 00:00:00 2001 From: bluestyle97 Date: Fri, 12 Apr 2024 12:26:31 +0800 Subject: [PATCH] update dataloader --- src/data/objaverse.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/data/objaverse.py b/src/data/objaverse.py index 37491f0..7329a0f 100644 --- a/src/data/objaverse.py +++ b/src/data/objaverse.py @@ -1,9 +1,11 @@ -import os +import os, sys import math import json +import importlib from pathlib import Path import cv2 +import random import numpy as np from PIL import Image import webdataset as wds @@ -20,7 +22,7 @@ from src.utils.train_util import instantiate_from_config from src.utils.camera_util import ( FOV_to_intrinsics, center_looking_at_camera_pose, - get_surrounding_views, + get_circular_camera_poses, ) @@ -97,7 +99,7 @@ class ObjaverseData(Dataset): paths = filtered_dict['good_objs'] self.paths = paths - self.depth_scale = 4.0 + self.depth_scale = 6.0 total_objects = len(self.paths) print('============= length of dataset %d =============' % len(self.paths)) @@ -222,7 +224,7 @@ class ObjaverseData(Dataset): 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W) 'input_depths': depths[:self.input_view_num], # (6, 1, H, W) 'input_normals': normals[:self.input_view_num], # (6, 3, H, W) - 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4) + 'input_c2ws': c2ws[:self.input_view_num], # (6, 4, 4) 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3) # lrm generator input and supervision @@ -240,8 +242,8 @@ class ValidationData(Dataset): def __init__(self, root_dir='objaverse/', input_view_num=6, - input_image_size=256, - fov=50, + input_image_size=320, + fov=30, ): self.root_dir = Path(root_dir) self.input_view_num = input_view_num @@ -251,9 +253,9 @@ class ValidationData(Dataset): self.paths = sorted(os.listdir(self.root_dir)) print('============= length of dataset %d =============' % len(self.paths)) - cam_distance = 2.5 + cam_distance = 4.0 azimuths = np.array([30, 90, 150, 210, 270, 330]) - elevations = np.array([30, -20, 30, -20, 30, -20]) + elevations = np.array([20, -10, 20, -10, 20, -10]) azimuths = np.deg2rad(azimuths) elevations = np.deg2rad(elevations) @@ -267,7 +269,7 @@ class ValidationData(Dataset): self.c2ws = c2ws.float() self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float() - render_c2ws = get_surrounding_views(M=8, radius=cam_distance) + render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0) render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1) self.render_c2ws = render_c2ws.float() self.render_Ks = render_Ks.float()