import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_mesh
from tqdm import tqdm
import pygltflib
from pygltflib import GLTF2
import trimesh
import open3d as o3d
import os
import datetime
import tempfile


class ShapeGenerator:
    """Turn text prompts into meshes with Shap-E and export them as glTF.

    Typical use: construct, call :meth:`run` once to load the models, then
    call :meth:`generate_object` for each prompt.
    """

    # Characters replaced in prompts before they are used in file names
    # (path separators plus characters rejected by common filesystems).
    _UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'

    def __init__(self, output_path, batch_size, step_size, guidance):
        """Store the generation settings; no model is loaded here.

        Args:
            output_path: Directory that exported ``.gltf`` files are written
                to. An empty value means the current working directory.
            batch_size: Number of latents sampled per prompt.
            step_size: Passed to ``sample_latents`` as ``karras_steps``.
            guidance: Passed to ``sample_latents`` as ``guidance_scale``.
        """
        # Populated by load_models().
        self.device = None
        self.xm = None
        self.model = None
        self.diffusion = None
        # Counts completed export_model() calls.
        self.iterations = 0
        # Most recent output of sample_latents().
        self.latents = None
        self.output_path = output_path
        self.batch_size = batch_size
        self.step_size = step_size
        self.guidance = guidance

    def run(self):
        """Load the models, printing a message before and after."""
        print("Loading Models..")
        self.load_models()
        print("Finished Loading Models!")

    def load_models(self):
        """Pick a device (CUDA if available, else CPU) and load the models.

        Loads the 'transmitter' and 'text300M' models onto that device and
        builds the diffusion object from the 'diffusion' config.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))

    def generate_object(self, prompt):
        """Sample latents for ``prompt`` and export them as meshes.

        The sampled latents are kept in ``self.latents``.

        Args:
            prompt: Text description; repeated ``batch_size`` times as the
                ``texts`` conditioning.

        Returns:
            A ``(mesh, prompt)`` tuple, where ``mesh`` is the value returned
            by :meth:`export_model` (the mesh built from the last latent).

        Raises:
            RuntimeError: If the models have not been loaded yet.
        """
        if self.model is None or self.diffusion is None:
            raise RuntimeError(
                "Models are not loaded; call run() or load_models() first."
            )

        # NOTE: an earlier version drew an unused random tensor here and
        # printed its shape; it had no effect on the sampled latents other
        # than consuming CPU random state, so it was removed.
        model_kwargs = dict(texts=[prompt] * self.batch_size)

        self.latents = sample_latents(
            batch_size=self.batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=self.guidance,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=self.step_size,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
            device=self.model.device,
        )

        mesh = self.export_model(prompt)
        return mesh, prompt

    def export_model(self, prompt):
        """Decode ``self.latents`` into meshes and write each one as glTF.

        Each latent is decoded, written to a temporary ``.obj`` file,
        post-processed by :meth:`construct_mesh` and saved under
        ``output_path``. With a single latent the file is named
        ``<prompt>-<timestamp>.gltf``; with several, an index suffix
        (``-0``, ``-1``, ...) is appended so the files do not overwrite one
        another. ``self.iterations`` is incremented once per call.

        Args:
            prompt: Text used as the base of the output file names. Unsafe
                file-name characters are replaced with ``_``.

        Returns:
            The post-processed mesh built from the last latent.

        Raises:
            ValueError: If there are no latents to export.
        """
        # Explicit len(): truthiness of a multi-element tensor is ambiguous.
        if self.latents is None or len(self.latents) == 0:
            raise ValueError(
                "No latents to export; call generate_object() first."
            )

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        safe_prompt = self._safe_filename(prompt)
        output_dir = self.output_path or ''
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        use_index_suffix = len(self.latents) > 1
        final_mesh = None

        # Intermediate .obj files live in a scratch directory so they never
        # clutter (or collide in) the current working directory.
        with tempfile.TemporaryDirectory() as scratch_dir:
            for i, latent in enumerate(self.latents):
                suffix = f'-{i}' if use_index_suffix else ''
                output_filepath = os.path.join(
                    output_dir, f'{safe_prompt}-{timestamp}{suffix}.gltf'
                )
                print(output_filepath)

                obj_filepath = os.path.join(
                    scratch_dir, f'{safe_prompt}-{self.iterations}-{i}.obj'
                )
                t = decode_latent_mesh(self.xm, latent).tri_mesh()
                with open(obj_filepath, 'w') as f:
                    t.write_obj(f)

                final_mesh = self.construct_mesh(obj_filepath)
                # write_triangle_mesh reports failure through its return
                # value; surface it instead of dropping it silently.
                if not o3d.io.write_triangle_mesh(output_filepath, final_mesh):
                    print(f"Warning: failed to write {output_filepath}")

        self.iterations += 1
        return final_mesh

    def construct_mesh(self, obj_fp):
        """Load an ``.obj`` file, delete it, and return a simplified mesh.

        The mesh is decimated to roughly one third of its triangle count,
        smoothed with five iterations of the simple filter, and given vertex
        normals.

        Args:
            obj_fp: Path of the ``.obj`` file to read. The file is removed
                after it has been read.

        Returns:
            The decimated, smoothed mesh with vertex normals computed.
        """
        mesh = o3d.io.read_triangle_mesh(obj_fp)

        # The .obj is only an intermediate artifact; drop it once loaded.
        if os.path.exists(obj_fp):
            os.remove(obj_fp)

        original_triangle_count = len(mesh.triangles)
        # Keep the target at least 1 so a mesh with fewer than three
        # triangles is not asked to decimate down to zero triangles.
        target_triangle_count = max(1, original_triangle_count // 3)

        decimated_mesh = mesh.simplify_quadric_decimation(
            target_number_of_triangles=target_triangle_count
        )
        filtered_mesh = decimated_mesh.filter_smooth_simple(number_of_iterations=5)
        filtered_mesh.compute_vertex_normals()
        return filtered_mesh

    @classmethod
    def _safe_filename(cls, text):
        """Return ``text`` with file-name-unsafe characters replaced by ``_``."""
        return ''.join(
            '_' if (ch in cls._UNSAFE_FILENAME_CHARS or ord(ch) < 32) else ch
            for ch in str(text)
        )