a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
3.3 KiB

3 months ago
import contextlib
import datetime
import os
import tempfile

import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_mesh
from tqdm import tqdm
import pygltflib
from pygltflib import GLTF2
import trimesh
import open3d as o3d
class ShapeGenerator:
    """Generate meshes from text prompts with Shap-E and export them as glTF.

    Call ``run()`` (or ``load_models()``) before ``generate_object()``, which
    relies on the loaded models.
    """

    def __init__(self, output_path, batch_size, step_size, guidance):
        """Store generation settings; models are loaded separately by ``run()``.

        Args:
            output_path: Directory the exported ``.gltf`` files are written to.
            batch_size: Number of latents sampled (and meshes exported) per prompt.
            step_size: Passed to ``sample_latents`` as ``karras_steps``.
            guidance: Passed to ``sample_latents`` as ``guidance_scale``.
        """
        self.device = None
        self.xm = None  # 'transmitter' model, used to decode latents into meshes
        self.model = None  # 'text300M' model, used to sample latents from text
        self.diffusion = None
        self.iterations = 0  # number of meshes exported so far
        self.latents = None  # latents from the most recent generate_object()
        self.output_path = output_path
        self.batch_size = batch_size
        self.step_size = step_size
        self.guidance = guidance

    def run(self):
        """Load all models, printing a message before and after."""
        print("Loading Models..")
        self.load_models()
        print("Finished Loading Models!")

    def load_models(self):
        """Load the transmitter and text300M models (CUDA if available, else CPU) and the diffusion."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))

    def generate_object(self, prompt):
        """Sample ``self.batch_size`` latents for *prompt* and export them as meshes.

        Requires ``load_models()`` to have been called.

        Returns:
            ``(mesh, prompt)`` where ``mesh`` is the value returned by
            ``export_model`` (the last exported mesh).
        """
        # One copy of the prompt per sample in the batch.
        model_kwargs = dict(texts=[prompt] * self.batch_size)
        self.latents = sample_latents(
            batch_size=self.batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=self.guidance,
            model_kwargs=model_kwargs,
            progress=True,  # let sample_latents report sampling progress
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=self.step_size,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
            device=self.model.device,
        )
        mesh = self.export_model(prompt)
        return mesh, prompt

    @staticmethod
    def _filename_safe(prompt):
        """Return *prompt* with path separators replaced by ``_`` so it forms one file name."""
        separators = {"/", os.sep}
        if os.altsep:
            separators.add(os.altsep)
        return "".join("_" if ch in separators else ch for ch in prompt)

    def _output_filepath(self, prompt, timestamp, index, count):
        """Build the ``.gltf`` path for latent *index* out of *count* exported together."""
        stem = f"{self._filename_safe(prompt)}-{timestamp}"
        if count > 1:
            # Without the index, every latent in a batch would overwrite the same file.
            stem = f"{stem}-{index}"
        return os.path.join(self.output_path, f"{stem}.gltf")

    def export_model(self, prompt):
        """Decode every latent in ``self.latents`` and write each one as a glTF file.

        Files are named ``<prompt>-<timestamp>.gltf``. When several latents are
        exported, a ``-<index>`` suffix is added to each name.
        ``self.iterations`` is incremented once per exported mesh.

        Returns:
            The last processed mesh (as returned by ``construct_mesh``).

        Raises:
            RuntimeError: if there are no latents to export.
        """
        if self.latents is None or len(self.latents) == 0:
            raise RuntimeError("No latents to export; call generate_object() first.")
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        count = len(self.latents)
        final_mesh = None
        for i, latent in enumerate(self.latents):
            output_filepath = self._output_filepath(prompt, timestamp, i, count)
            print(output_filepath)
            tri_mesh = decode_latent_mesh(self.xm, latent).tri_mesh()
            # Unique temp file instead of a fixed name in the CWD; the .obj
            # suffix lets open3d pick the right reader.
            fd, obj_filepath = tempfile.mkstemp(suffix=".obj")
            try:
                with os.fdopen(fd, 'w') as f:
                    tri_mesh.write_obj(f)
                final_mesh = self.construct_mesh(obj_filepath)
            finally:
                # construct_mesh normally deletes the file; this covers failures.
                with contextlib.suppress(FileNotFoundError):
                    os.remove(obj_filepath)
            o3d.io.write_triangle_mesh(output_filepath, final_mesh)
            self.iterations += 1
        return final_mesh

    def construct_mesh(self, obj_fp, decimation_factor=3, smoothing_iterations=5):
        """Load an OBJ file, decimate and smooth it, and delete the file.

        Args:
            obj_fp: Path of the OBJ file; it is removed after reading, even if
                reading fails.
            decimation_factor: The triangle count is reduced to
                ``original // decimation_factor`` via quadric decimation.
            smoothing_iterations: Iterations of simple (averaging) smoothing.

        Returns:
            The processed open3d triangle mesh with vertex normals computed.

        Raises:
            ValueError: if ``decimation_factor`` is less than 1.
        """
        if decimation_factor < 1:
            raise ValueError("decimation_factor must be >= 1")
        try:
            mesh = o3d.io.read_triangle_mesh(obj_fp)
        finally:
            with contextlib.suppress(FileNotFoundError):
                os.remove(obj_fp)
        original_triangle_count = len(mesh.triangles)
        target_triangle_count = original_triangle_count // decimation_factor
        decimated_mesh = mesh.simplify_quadric_decimation(
            target_number_of_triangles=target_triangle_count)
        filtered_mesh = decimated_mesh.filter_smooth_simple(
            number_of_iterations=smoothing_iterations)
        filtered_mesh.compute_vertex_normals()
        return filtered_mesh