From acc21209d1f81cf368262bd420b3865c950cb525 Mon Sep 17 00:00:00 2001
From: bluestyle97
Date: Thu, 11 Apr 2024 20:38:18 +0800
Subject: [PATCH] add dataloader

---
 .gitignore            |   1 -
 src/data/__init__.py  |   0
 src/data/objaverse.py | 326 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 326 insertions(+), 1 deletion(-)
 create mode 100644 src/data/__init__.py
 create mode 100644 src/data/objaverse.py

diff --git a/.gitignore b/.gitignore
index 3198cbc..afc5096 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,7 +30,6 @@ MANIFEST
 tools/objaverse_rendering/blender-3.2.2-linux-x64/
 tools/objaverse_rendering/output/
 ckpts/
-data/
 lightning_logs/
 logs/
 .trash/
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/data/objaverse.py b/src/data/objaverse.py
new file mode 100644
index 0000000..be7c12b
--- /dev/null
+++ b/src/data/objaverse.py
@@ -0,0 +1,326 @@
+import os
+import math
+import json
+from pathlib import Path
+
+import cv2
+import numpy as np
+from PIL import Image
+import webdataset as wds
+import pytorch_lightning as pl
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torchvision import transforms
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    center_looking_at_camera_pose,
+    get_circular_camera_poses,
+)
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size=8,
+        num_workers=4,
+        train=None,
+        validation=None,
+        test=None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.dataset_configs = dict()
+        if train is not None:
+            self.dataset_configs['train'] = train
+        if validation is not None:
+            self.dataset_configs['validation'] = validation
+        if test is not None:
+            self.dataset_configs['test'] = test
+
+    def setup(self, stage):
+        if stage in ['fit']:
+            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        else:
+            raise NotImplementedError
+
+    def train_dataloader(self):
+        sampler = DistributedSampler(self.datasets['train'])
+        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def val_dataloader(self):
+        sampler = DistributedSampler(self.datasets['validation'])
+        return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def test_dataloader(self):
+        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
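+
+
+# A minimal usage sketch of this data module (assumed config shape, following
+# the usual `target`/`params` convention of instantiate_from_config; the real
+# values live in the repo's training configs):
+#
+#   data_module = DataModuleFromConfig(
+#       batch_size=8,
+#       num_workers=4,
+#       train={
+#           'target': 'src.data.objaverse.ObjaverseData',
+#           'params': {'root_dir': 'objaverse/', 'input_view_num': 6,
+#                      'target_view_num': 2},
+#       },
+#   )
+#   data_module.setup(stage='fit')
+#   loader = data_module.train_dataloader()
+#
+# Note: train_dataloader/val_dataloader build a DistributedSampler, so this
+# sketch assumes a DDP run with an initialized process group.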
+
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        meta_fname='valid_paths.json',
+        input_image_dir='rendering_random_32views',
+        target_image_dir='rendering_random_32views',
+        input_view_num=6,
+        target_view_num=2,
+        total_view_n=32,
+        fov=50,
+        camera_rotation=True,
+        validation=False,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_image_dir = input_image_dir
+        self.target_image_dir = target_image_dir
+
+        self.input_view_num = input_view_num
+        self.target_view_num = target_view_num
+        self.total_view_n = total_view_n
+        self.fov = fov
+        self.camera_rotation = camera_rotation
+
+        with open(os.path.join(root_dir, meta_fname)) as f:
+            filtered_dict = json.load(f)
+        self.paths = filtered_dict['good_objs']
+
+        # depth PNGs are normalized by 255 and then multiplied by this factor
+        self.depth_scale = 5.0
+
+        print('============= length of dataset %d =============' % len(self.paths))
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        '''
+        Load an RGBA rendering and composite it over the given background color.
+        '''
+        pil_img = Image.open(path)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        alpha = image[:, :, 3:]
+        image = image[:, :, :3] * alpha + color * (1 - alpha)
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        # retry with a randomly resampled object if the current one fails to load
+        while True:
+            input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
+            target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
+
+            # sample disjoint input and target view indices
+            indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
+            input_indices = indices[:self.input_view_num]
+            target_indices = indices[self.input_view_num:]
+
+            # background colors: white for RGB renders, black for normal maps
+            bg_white = [1., 1., 1.]
+            bg_black = [0., 0., 0.]
+
+            image_list = []
+            alpha_list = []
+            depth_list = []
+            normal_list = []
+            pose_list = []
+
+            try:
+                input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
+                for idx in input_indices:
+                    image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = input_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+                target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
+                for idx in target_indices:
+                    image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = target_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+            except Exception as e:
+                print(e)
+                index = np.random.randint(0, len(self.paths))
+                continue
+
+            break
+
+        images = torch.stack(image_list, dim=0).float()               # (6+V, 3, H, W)
+        alphas = torch.stack(alpha_list, dim=0).float()               # (6+V, 1, H, W)
+        depths = torch.stack(depth_list, dim=0).float()               # (6+V, 1, H, W)
+        normals = torch.stack(normal_list, dim=0).float()             # (6+V, 3, H, W)
+        w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float()  # (6+V, 4, 4)
+        c2ws = torch.linalg.inv(w2cs).float()
+
+        # renormalize the unit normals, then mask the background by alpha
+        normals = normals * 2.0 - 1.0
+        normals = F.normalize(normals, dim=1)
+        normals = (normals + 1.0) / 2.0
+        normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random rotation around the z axis
+        if self.camera_rotation:
+            angle = np.random.uniform(0, math.pi * 2)  # radians
+            rot = torch.tensor([
+                [np.cos(angle), -np.sin(angle), 0, 0],
+                [np.sin(angle), np.cos(angle), 0, 0],
+                [0, 0, 1, 0],
+                [0, 0, 0, 1],
+            ]).unsqueeze(0).float()
+            c2ws = torch.matmul(rot, c2ws)
+
+            # rotate the world-space normals with the same matrix
+            N, _, H, W = normals.shape
+            normals = normals * 2.0 - 1.0
+            normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
+            normals = F.normalize(normals, dim=1)
+            normals = (normals + 1.0) / 2.0
+            normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random scaling of the scene: shrink camera distances and depths together
+        if np.random.rand() < 0.5:
+            scale = np.random.uniform(0.8, 1.0)
+            c2ws[:, :3, 3] *= scale
+            depths *= scale
+
+        # intrinsics of perspective cameras
+        K = FOV_to_intrinsics(self.fov)
+        Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
+
+        data = {
+            'input_images': images[:self.input_view_num],       # (6, 3, H, W)
+            'input_alphas': alphas[:self.input_view_num],       # (6, 1, H, W)
+            'input_depths': depths[:self.input_view_num],       # (6, 1, H, W)
+            'input_normals': normals[:self.input_view_num],     # (6, 3, H, W)
+            'input_c2ws': c2ws[:self.input_view_num],           # (6, 4, 4)
+            'input_Ks': Ks[:self.input_view_num],               # (6, 3, 3)
+
+            # lrm generator input and supervision
+            'target_images': images[self.input_view_num:],      # (V, 3, H, W)
+            'target_alphas': alphas[self.input_view_num:],      # (V, 1, H, W)
+            'target_depths': depths[self.input_view_num:],      # (V, 1, H, W)
+            'target_normals': normals[self.input_view_num:],    # (V, 3, H, W)
+            'target_c2ws': c2ws[self.input_view_num:],          # (V, 4, 4)
+            'target_Ks': Ks[self.input_view_num:],              # (V, 3, 3)
+
+            'depth_available': 1,
+        }
+        return data
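+
+
+# Usage sketch (hypothetical layout; assumes each object folder under
+# <root_dir>/<input_image_dir>/ holds 000.png..031.png renders together with
+# 000_normal.png, 000_depth.png and a cameras.npz of world-to-camera poses,
+# as the loading code above expects):
+#
+#   dataset = ObjaverseData(root_dir='objaverse/', input_view_num=6, target_view_num=2)
+#   sample = dataset[0]
+#   assert sample['input_images'].shape[:2] == (6, 3)
+#   assert sample['target_c2ws'].shape == (2, 4, 4)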
+
+
+class ValidationData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        input_view_num=6,
+        input_image_size=256,
+        fov=30,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_view_num = input_view_num
+        self.input_image_size = input_image_size
+        self.fov = fov
+
+        self.paths = sorted(os.listdir(self.root_dir))
+        print('============= length of dataset %d =============' % len(self.paths))
+
+        # six fixed input cameras, alternating between 20 and -10 degrees elevation
+        cam_distance = 4.0
+        azimuths = np.array([30, 90, 150, 210, 270, 330])
+        elevations = np.array([20, -10, 20, -10, 20, -10])
+        azimuths = np.deg2rad(azimuths)
+        elevations = np.deg2rad(elevations)
+
+        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
+        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
+        z = cam_distance * np.sin(elevations)
+
+        cam_locations = np.stack([x, y, z], axis=-1)
+        cam_locations = torch.from_numpy(cam_locations).float()
+        c2ws = center_looking_at_camera_pose(cam_locations)
+        self.c2ws = c2ws.float()
+        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
+
+        # novel-view cameras on a circle, used for turntable-style renderings
+        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
+        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
+        self.render_c2ws = render_c2ws.float()
+        self.render_Ks = render_Ks.float()
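+
+    # Worked example of the spherical-to-Cartesian placement above: the first
+    # camera sits at azimuth 30 deg, elevation 20 deg, distance 4.0, i.e.
+    #   x = 4.0 * cos(20 deg) * cos(30 deg) ~ 3.255
+    #   y = 4.0 * cos(20 deg) * sin(30 deg) ~ 1.879
+    #   z = 4.0 * sin(20 deg)               ~ 1.368
+    # center_looking_at_camera_pose then orients each camera toward the origin.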
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        '''
+        Load a rendering, resize it, and composite any alpha channel over the
+        given background color.
+        '''
+        pil_img = Image.open(path)
+        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        if image.shape[-1] == 4:
+            alpha = image[:, :, 3:]
+            image = image[:, :, :3] * alpha + color * (1 - alpha)
+        else:
+            # RGB inputs carry no alpha channel; treat everything as foreground
+            alpha = np.ones_like(image[:, :, :1])
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        input_image_path = os.path.join(self.root_dir, self.paths[index])
+
+        # background color, default: white
+        bkg_color = [1.0, 1.0, 1.0]
+
+        image_list = []
+        alpha_list = []
+
+        for idx in range(self.input_view_num):
+            image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
+            image_list.append(image)
+            alpha_list.append(alpha)
+
+        images = torch.stack(image_list, dim=0).float()
+        alphas = torch.stack(alpha_list, dim=0).float()
+
+        data = {
+            'input_images': images,
+            'input_alphas': alphas,
+            'input_c2ws': self.c2ws,
+            'input_Ks': self.Ks,
+
+            'render_c2ws': self.render_c2ws,
+            'render_Ks': self.render_Ks,
+        }
+        return data
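+
+
+# Usage sketch (hypothetical directory of per-object folders, each containing
+# six input views named 000.png .. 005.png):
+#
+#   val_set = ValidationData(root_dir='val_objects/', input_view_num=6)
+#   batch = val_set[0]
+#   # batch['input_c2ws']  -> (6, 4, 4) fixed input poses
+#   # batch['render_c2ws'] -> (8, 4, 4) circular novel-view poses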