Browse Source

gui added

main
Cailean 3 months ago
parent
commit
a6160ab3ef
  1. 6
      shap_e/examples/gc/ShapeGenerator.py
  2. 86
      shap_e/examples/gc/app.py

6
shap_e/examples/gc/ShapeGenerator.py

@ -29,6 +29,7 @@ class ShapeGenerator:
self.load_models() self.load_models()
print("Finished Loading Models!") print("Finished Loading Models!")
def load_models(self): def load_models(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.xm = load_model('transmitter', device=self.device) self.xm = load_model('transmitter', device=self.device)
@ -43,6 +44,7 @@ class ShapeGenerator:
random_latents = torch.randn(batch_size, latent_dim).to(self.model.device) random_latents = torch.randn(batch_size, latent_dim).to(self.model.device)
print(random_latents.shape) print(random_latents.shape)
model_kwargs = {} model_kwargs = {}
model_kwargs = dict(texts=[prompt] * self.batch_size)
self.latents = sample_latents( self.latents = sample_latents(
@ -62,7 +64,8 @@ class ShapeGenerator:
device = self.model.device, device = self.model.device,
) )
self.export_model(prompt) mesh = self.export_model(prompt)
return mesh, prompt
def export_model(self, prompt): def export_model(self, prompt):
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@ -76,6 +79,7 @@ class ShapeGenerator:
final_mesh = self.construct_mesh(obj_filepath) final_mesh = self.construct_mesh(obj_filepath)
o3d.io.write_triangle_mesh(output_filepath, final_mesh) o3d.io.write_triangle_mesh(output_filepath, final_mesh)
self.iterations += 1 self.iterations += 1
return final_mesh
def construct_mesh(self, obj_fp): def construct_mesh(self, obj_fp):
mesh = o3d.io.read_triangle_mesh(obj_fp) mesh = o3d.io.read_triangle_mesh(obj_fp)

86
shap_e/examples/gc/app.py

@ -2,6 +2,10 @@ import argparse
import threading import threading
import time import time
import random import random
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from ShapeGenerator import ShapeGenerator from ShapeGenerator import ShapeGenerator
from TextGenerator import TextGenerator from TextGenerator import TextGenerator
@ -9,12 +13,13 @@ from TextGenerator import TextGenerator
# python app.py --output_dir /mnt/c/Users/caile/Desktop/output # python app.py --output_dir /mnt/c/Users/caile/Desktop/output
class GCApp: class GCApp:
def __init__(self, output_dir, batch_size, step_size, guidance_scale): def __init__(self, output_dir, batch_size, step_size, guidance_scale, vis):
self.output_dir = output_dir self.output_dir = output_dir
self.obj_gen = ShapeGenerator(self.output_dir, batch_size, step_size, guidance_scale) self.obj_gen = ShapeGenerator(self.output_dir, batch_size, step_size, guidance_scale)
self.running = False self.running = False
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.thread = None self.thread = None
self.vis = vis
self.waste_items = [ self.waste_items = [
"Plastic bottle", "Aluminum can", "Glass bottle", "Food wrapper", "Plastic bottle", "Aluminum can", "Glass bottle", "Food wrapper",
"Cardboard box", "Paper bag", "Plastic bag", "Electronics", "Cardboard box", "Paper bag", "Plastic bag", "Electronics",
@ -61,9 +66,16 @@ class GCApp:
def get_random_item_prompt(self): def get_random_item_prompt(self):
return random.choice(self.waste_items) return random.choice(self.waste_items)
def add_to_scene(self, mesh, prompt):
scene = self.vis.get_scene()
self.vis.update_label(prompt)
scene.scene.clear_geometry()
scene.scene.add_geometry("name", mesh, rendering.MaterialRecord())
def _generate_objects(self): def _generate_objects(self):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self.obj_gen.generate_object(self.get_random_item_prompt()) mesh, prompt = self.obj_gen.generate_object(self.get_random_item_prompt())
self.add_to_scene(mesh, prompt)
time.sleep(1) time.sleep(1)
def run(self): def run(self):
@ -84,17 +96,81 @@ class GCApp:
elif command.lower() == 'stop': elif command.lower() == 'stop':
print("Stopping continuous generation.") print("Stopping continuous generation.")
self.stop_generation() self.stop_generation()
elif command.startswith('generate '):
shape = command[len('generate '):]
mesh, prompt = self.obj_gen.generate_object(shape)
self.add_to_scene(mesh, prompt)
else: else:
print("Unknown command.") print("Unknown command.")
class VisApp:
def __init__(self):
self._id = 0
self.window = gui.Application.instance.create_window(
"garbage collector", 1024, 768)
self.obj_label = None
self.layout = None
self.scene = None
self.create_scene()
self.create_layout()
def create_layout(self):
self.layout = gui.Vert(0, gui.Margins(10, 10, 10, 10))
self.obj_label = gui.Label("None")
self.layout.add_child(self.obj_label)
self.window.set_on_layout(self._on_layout)
self.window.add_child(self.layout)
def create_scene(self):
self.scene = gui.SceneWidget()
self.scene.scene = rendering.Open3DScene(self.window.renderer)
self.scene.scene.set_background([1, 1, 1, 1])
self.scene.scene.scene.set_sun_light(
[-1, -1, -1], # direction
[1, 1, 1], # color
100000) # intensity
self.scene.scene.scene.enable_sun_light(True)
bbox = o3d.geometry.AxisAlignedBoundingBox([-5, -5, -5],
[5, 5, 5])
self.scene.setup_camera(20, bbox, [0, 0, 0])
self.window.add_child(self.scene)
def _on_layout(self, layout_context):
# The on_layout callback should set the frame (position + size) of every
# child correctly. After the callback is done the window will layout
# the grandchildren.
r = self.window.content_rect
self.scene.frame = r
width = 17 * layout_context.theme.font_size
height = min(
r.height,
self.layout.calc_preferred_size(
layout_context, gui.Widget.Constraints()).height)
self.layout.frame = gui.Rect(r.get_right() - width, r.y, width,
height)
def get_scene(self):
return self.scene
def update_label(self, name):
self.obj_label.text = "Object: " + name
def main(output_dir, batch_size, step_size, guidance_scale): def main(output_dir, batch_size, step_size, guidance_scale):
app = GCApp(output_dir, batch_size, step_size, guidance_scale) gui.Application.instance.initialize()
app.run()
vis = VisApp()
app = GCApp(output_dir, batch_size, step_size, guidance_scale, vis)
threading.Thread(target=app.run, daemon=True).start()
gui.Application.instance.run()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.") parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.")
parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.") parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.")
parser.add_argument("--batch_size", type=int, default=2, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.") parser.add_argument("--batch_size", type=int, default=4, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.") parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).") parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).")

Loading…
Cancel
Save