You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.3 KiB
100 lines
3.3 KiB
import torch
|
|
from shap_e.diffusion.sample import sample_latents
|
|
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
|
from shap_e.models.download import load_model, load_config
|
|
from shap_e.util.notebooks import decode_latent_mesh
|
|
from tqdm import tqdm
|
|
import pygltflib
|
|
from pygltflib import GLTF2
|
|
import trimesh
|
|
import open3d as o3d
|
|
import os
|
|
import datetime
|
|
|
|
class ShapeGenerator:
    """Generate 3D meshes from text prompts with Shap-E and export them as glTF.

    Typical use: construct an instance, call ``run()`` once to load the
    models, then call ``generate_object(prompt)`` for each prompt.

    Attributes:
        output_path: Directory the exported ``.gltf`` files are written to.
        batch_size: Number of latents sampled per prompt.
        step_size: Passed to ``sample_latents`` as ``karras_steps``.
        guidance: Passed to ``sample_latents`` as ``guidance_scale``.
        iterations: Number of completed ``export_model`` calls.
        latents: Latents produced by the most recent ``generate_object`` call.
    """

    def __init__(self, output_path, batch_size, step_size, guidance):
        """Store the sampling/export settings; no models are loaded here."""
        # Populated by load_models(); they stay None until then.
        self.device = None
        self.xm = None
        self.model = None
        self.diffusion = None

        self.iterations = 0
        self.latents = None

        self.output_path = output_path
        self.batch_size = batch_size
        self.step_size = step_size
        self.guidance = guidance

    def run(self):
        """Load all models, printing progress messages before and after."""
        print("Loading Models..")
        self.load_models()
        print("Finished Loading Models!")

    def load_models(self):
        """Pick a device (CUDA when available, else CPU) and load the models.

        Sets ``self.device``, ``self.xm`` (the 'transmitter' model, used later
        to decode latents into meshes), ``self.model`` (the 'text300M' model)
        and ``self.diffusion`` (built from the 'diffusion' config).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))

    def generate_object(self, prompt):
        """Sample latents for *prompt*, export them, and return the result.

        Models are loaded on demand if ``run()``/``load_models()`` has not
        been called yet.

        Args:
            prompt: Text description of the object to generate.

        Returns:
            A ``(mesh, prompt)`` tuple, where ``mesh`` is the value returned
            by ``export_model`` (the processed mesh of the last latent).
        """
        if self.model is None or self.diffusion is None or self.xm is None:
            self.load_models()

        self.latents = sample_latents(
            batch_size=self.batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=self.guidance,
            # One copy of the prompt per sample in the batch.
            model_kwargs={"texts": [prompt] * self.batch_size},
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=self.step_size,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
            device=self.model.device,
        )

        return self.export_model(prompt), prompt

    def export_model(self, prompt):
        """Decode every latent in ``self.latents`` and write each as glTF.

        Each latent is decoded to a triangle mesh, dumped to a temporary OBJ
        file, post-processed by ``construct_mesh`` and written into
        ``self.output_path``. With a single latent the file is named
        ``<prompt>-<timestamp>.gltf``; with several, each file gets an extra
        ``-<index>`` suffix so that no mesh overwrites another.

        Args:
            prompt: Prompt text, used as the stem of the output file names.

        Returns:
            The processed mesh of the last latent.

        Raises:
            RuntimeError: If there are no latents to export.
        """
        import tempfile  # Local import: only this method needs it.

        if self.latents is None or len(self.latents) == 0:
            raise RuntimeError(
                "No latents to export; call generate_object() first."
            )

        if self.output_path:
            os.makedirs(self.output_path, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        stem = f"{self._safe_stem(prompt)}-{timestamp}"
        latent_count = len(self.latents)

        final_mesh = None
        # Intermediate OBJ files live in a scratch directory, so they never
        # collide with user files and are cleaned up even if decoding fails.
        with tempfile.TemporaryDirectory() as scratch_dir:
            for index, latent in enumerate(self.latents):
                suffix = "" if latent_count == 1 else f"-{index}"
                output_filepath = os.path.join(
                    self.output_path, f"{stem}{suffix}.gltf"
                )
                print(output_filepath)

                obj_filepath = os.path.join(scratch_dir, f"mesh-{index}.obj")
                tri_mesh = decode_latent_mesh(self.xm, latent).tri_mesh()
                with open(obj_filepath, 'w') as obj_file:
                    tri_mesh.write_obj(obj_file)

                final_mesh = self.construct_mesh(obj_filepath)
                written = o3d.io.write_triangle_mesh(output_filepath, final_mesh)
                if not written:
                    print(f"Warning: failed to write {output_filepath}")

        # Counted once per export call, not once per latent.
        self.iterations += 1
        return final_mesh

    def construct_mesh(self, obj_fp, reduction_factor=3, smooth_iterations=5):
        """Load an OBJ file, delete it, and return a simplified, smoothed mesh.

        Args:
            obj_fp: Path of the OBJ file to read. The file is removed after
                it has been loaded.
            reduction_factor: The triangle count is divided by this value to
                get the decimation target (default 3).
            smooth_iterations: Number of simple-smoothing iterations applied
                after decimation (default 5).

        Returns:
            The decimated and smoothed mesh with vertex normals computed.

        Raises:
            ValueError: If ``reduction_factor`` is smaller than 1.
        """
        if reduction_factor < 1:
            raise ValueError("reduction_factor must be >= 1")

        mesh = o3d.io.read_triangle_mesh(obj_fp)

        # The OBJ is only an intermediate artifact; a missing file is fine.
        try:
            os.remove(obj_fp)
        except FileNotFoundError:
            pass

        triangle_count = len(mesh.triangles)
        # Never ask the decimator for zero triangles on very small meshes.
        target_triangle_count = max(1, int(triangle_count // reduction_factor))

        decimated_mesh = mesh.simplify_quadric_decimation(
            target_number_of_triangles=target_triangle_count)
        filtered_mesh = decimated_mesh.filter_smooth_simple(
            number_of_iterations=smooth_iterations)
        filtered_mesh.compute_vertex_normals()
        return filtered_mesh

    @staticmethod
    def _safe_stem(prompt):
        """Return *prompt* with path separators replaced by underscores."""
        cleaned = str(prompt)
        for separator in ("/", "\\"):
            cleaned = cleaned.replace(separator, "_")
        return cleaned
|
|
|
|
|
|
|
|
|
|
|