Browse Source

gui added

main
Cailean 6 months ago
parent
commit
a6160ab3ef
  1. 6
      shap_e/examples/gc/ShapeGenerator.py
  2. 86
      shap_e/examples/gc/app.py

6
shap_e/examples/gc/ShapeGenerator.py

@ -29,6 +29,7 @@ class ShapeGenerator:
self.load_models()
print("Finished Loading Models!")
def load_models(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.xm = load_model('transmitter', device=self.device)
@ -43,6 +44,7 @@ class ShapeGenerator:
random_latents = torch.randn(batch_size, latent_dim).to(self.model.device)
print(random_latents.shape)
model_kwargs = {}
model_kwargs = dict(texts=[prompt] * self.batch_size)
self.latents = sample_latents(
@ -62,7 +64,8 @@ class ShapeGenerator:
device = self.model.device,
)
self.export_model(prompt)
mesh = self.export_model(prompt)
return mesh, prompt
def export_model(self, prompt):
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@ -76,6 +79,7 @@ class ShapeGenerator:
final_mesh = self.construct_mesh(obj_filepath)
o3d.io.write_triangle_mesh(output_filepath, final_mesh)
self.iterations += 1
return final_mesh
def construct_mesh(self, obj_fp):
mesh = o3d.io.read_triangle_mesh(obj_fp)

86
shap_e/examples/gc/app.py

@ -2,6 +2,10 @@ import argparse
import threading
import time
import random
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from ShapeGenerator import ShapeGenerator
from TextGenerator import TextGenerator
@ -9,12 +13,13 @@ from TextGenerator import TextGenerator
# python app.py --output_dir /mnt/c/Users/caile/Desktop/output
class GCApp:
def __init__(self, output_dir, batch_size, step_size, guidance_scale):
def __init__(self, output_dir, batch_size, step_size, guidance_scale, vis):
self.output_dir = output_dir
self.obj_gen = ShapeGenerator(self.output_dir, batch_size, step_size, guidance_scale)
self.running = False
self.stop_event = threading.Event()
self.thread = None
self.vis = vis
self.waste_items = [
"Plastic bottle", "Aluminum can", "Glass bottle", "Food wrapper",
"Cardboard box", "Paper bag", "Plastic bag", "Electronics",
@ -61,9 +66,16 @@ class GCApp:
def get_random_item_prompt(self):
return random.choice(self.waste_items)
def add_to_scene(self, mesh, prompt):
scene = self.vis.get_scene()
self.vis.update_label(prompt)
scene.scene.clear_geometry()
scene.scene.add_geometry("name", mesh, rendering.MaterialRecord())
def _generate_objects(self):
while not self.stop_event.is_set():
self.obj_gen.generate_object(self.get_random_item_prompt())
mesh, prompt = self.obj_gen.generate_object(self.get_random_item_prompt())
self.add_to_scene(mesh, prompt)
time.sleep(1)
def run(self):
@ -84,17 +96,81 @@ class GCApp:
elif command.lower() == 'stop':
print("Stopping continuous generation.")
self.stop_generation()
elif command.startswith('generate '):
shape = command[len('generate '):]
mesh, prompt = self.obj_gen.generate_object(shape)
self.add_to_scene(mesh, prompt)
else:
print("Unknown command.")
class VisApp:
def __init__(self):
self._id = 0
self.window = gui.Application.instance.create_window(
"garbage collector", 1024, 768)
self.obj_label = None
self.layout = None
self.scene = None
self.create_scene()
self.create_layout()
def create_layout(self):
self.layout = gui.Vert(0, gui.Margins(10, 10, 10, 10))
self.obj_label = gui.Label("None")
self.layout.add_child(self.obj_label)
self.window.set_on_layout(self._on_layout)
self.window.add_child(self.layout)
def create_scene(self):
self.scene = gui.SceneWidget()
self.scene.scene = rendering.Open3DScene(self.window.renderer)
self.scene.scene.set_background([1, 1, 1, 1])
self.scene.scene.scene.set_sun_light(
[-1, -1, -1], # direction
[1, 1, 1], # color
100000) # intensity
self.scene.scene.scene.enable_sun_light(True)
bbox = o3d.geometry.AxisAlignedBoundingBox([-5, -5, -5],
[5, 5, 5])
self.scene.setup_camera(20, bbox, [0, 0, 0])
self.window.add_child(self.scene)
def _on_layout(self, layout_context):
# The on_layout callback should set the frame (position + size) of every
# child correctly. After the callback is done the window will layout
# the grandchildren.
r = self.window.content_rect
self.scene.frame = r
width = 17 * layout_context.theme.font_size
height = min(
r.height,
self.layout.calc_preferred_size(
layout_context, gui.Widget.Constraints()).height)
self.layout.frame = gui.Rect(r.get_right() - width, r.y, width,
height)
def get_scene(self):
return self.scene
def update_label(self, name):
self.obj_label.text = "Object: " + name
def main(output_dir, batch_size, step_size, guidance_scale):
app = GCApp(output_dir, batch_size, step_size, guidance_scale)
app.run()
gui.Application.instance.initialize()
vis = VisApp()
app = GCApp(output_dir, batch_size, step_size, guidance_scale, vis)
threading.Thread(target=app.run, daemon=True).start()
gui.Application.instance.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.")
parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.")
parser.add_argument("--batch_size", type=int, default=2, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--batch_size", type=int, default=4, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).")

Loading…
Cancel
Save