diff --git a/shap_e/examples/gc/ShapeGenerator.py b/shap_e/examples/gc/ShapeGenerator.py index 30304ae..9a82062 100644 --- a/shap_e/examples/gc/ShapeGenerator.py +++ b/shap_e/examples/gc/ShapeGenerator.py @@ -29,6 +29,7 @@ class ShapeGenerator: self.load_models() print("Finished Loading Models!") + def load_models(self): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.xm = load_model('transmitter', device=self.device) @@ -43,6 +44,7 @@ class ShapeGenerator: random_latents = torch.randn(batch_size, latent_dim).to(self.model.device) print(random_latents.shape) model_kwargs = {} + model_kwargs = dict(texts=[prompt] * self.batch_size) self.latents = sample_latents( @@ -62,7 +64,8 @@ class ShapeGenerator: device = self.model.device, ) - self.export_model(prompt) + mesh = self.export_model(prompt) + return mesh, prompt def export_model(self, prompt): timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -76,6 +79,7 @@ class ShapeGenerator: final_mesh = self.construct_mesh(obj_filepath) o3d.io.write_triangle_mesh(output_filepath, final_mesh) self.iterations += 1 + return final_mesh def construct_mesh(self, obj_fp): mesh = o3d.io.read_triangle_mesh(obj_fp) diff --git a/shap_e/examples/gc/app.py b/shap_e/examples/gc/app.py index 4967c69..166f76c 100644 --- a/shap_e/examples/gc/app.py +++ b/shap_e/examples/gc/app.py @@ -2,6 +2,10 @@ import argparse import threading import time import random +import open3d as o3d +import open3d.visualization.gui as gui +import open3d.visualization.rendering as rendering + from ShapeGenerator import ShapeGenerator from TextGenerator import TextGenerator @@ -9,12 +13,13 @@ from TextGenerator import TextGenerator # python app.py --output_dir /mnt/c/Users/caile/Desktop/output class GCApp: - def __init__(self, output_dir, batch_size, step_size, guidance_scale): + def __init__(self, output_dir, batch_size, step_size, guidance_scale, vis): self.output_dir = output_dir self.obj_gen = ShapeGenerator(self.output_dir, batch_size, step_size, guidance_scale) self.running = False self.stop_event = threading.Event() self.thread = None + self.vis = vis self.waste_items = [ "Plastic bottle", "Aluminum can", "Glass bottle", "Food wrapper", "Cardboard box", "Paper bag", "Plastic bag", "Electronics", @@ -61,9 +66,16 @@ class GCApp: def get_random_item_prompt(self): return random.choice(self.waste_items) + def add_to_scene(self, mesh, prompt): + scene = self.vis.get_scene() + self.vis.update_label(prompt) + scene.scene.clear_geometry() + scene.scene.add_geometry("name", mesh, rendering.MaterialRecord()) + def _generate_objects(self): while not self.stop_event.is_set(): - self.obj_gen.generate_object(self.get_random_item_prompt()) + mesh, prompt = self.obj_gen.generate_object(self.get_random_item_prompt()) + self.add_to_scene(mesh, prompt) time.sleep(1) def run(self): @@ -84,17 +96,81 @@ class GCApp: elif command.lower() == 'stop': print("Stopping continuous generation.") self.stop_generation() + elif command.startswith('generate '): + shape = command[len('generate '):] + mesh, prompt = self.obj_gen.generate_object(shape) + self.add_to_scene(mesh, prompt) else: print("Unknown command.") +class VisApp: + + def __init__(self): + self._id = 0 + self.window = gui.Application.instance.create_window( + "garbage collector", 1024, 768) + + self.obj_label = None + self.layout = None + self.scene = None + + self.create_scene() + self.create_layout() + + def create_layout(self): + self.layout = gui.Vert(0, gui.Margins(10, 10, 10, 10)) + self.obj_label = gui.Label("None") + self.layout.add_child(self.obj_label) + self.window.set_on_layout(self._on_layout) + self.window.add_child(self.layout) + + def create_scene(self): + self.scene = gui.SceneWidget() + self.scene.scene = rendering.Open3DScene(self.window.renderer) + self.scene.scene.set_background([1, 1, 1, 1]) + self.scene.scene.scene.set_sun_light( + [-1, -1, -1], # direction + [1, 1, 1], # color + 100000) # intensity + self.scene.scene.scene.enable_sun_light(True) + bbox = o3d.geometry.AxisAlignedBoundingBox([-5, -5, -5], + [5, 5, 5]) + self.scene.setup_camera(20, bbox, [0, 0, 0]) + self.window.add_child(self.scene) + + def _on_layout(self, layout_context): + # The on_layout callback should set the frame (position + size) of every + # child correctly. After the callback is done the window will layout + # the grandchildren. + r = self.window.content_rect + self.scene.frame = r + width = 17 * layout_context.theme.font_size + height = min( + r.height, + self.layout.calc_preferred_size( + layout_context, gui.Widget.Constraints()).height) + self.layout.frame = gui.Rect(r.get_right() - width, r.y, width, + height) + + def get_scene(self): + return self.scene + + def update_label(self, name): + self.obj_label.text = "Object: " + name + def main(output_dir, batch_size, step_size, guidance_scale): - app = GCApp(output_dir, batch_size, step_size, guidance_scale) - app.run() + gui.Application.instance.initialize() + + vis = VisApp() + app = GCApp(output_dir, batch_size, step_size, guidance_scale, vis) + + threading.Thread(target=app.run, daemon=True).start() + gui.Application.instance.run() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.") parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.") - parser.add_argument("--batch_size", type=int, default=2, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.") + parser.add_argument("--batch_size", type=int, default=4, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.") parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.") parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).")