A fork of Shap-E for GC (the "garbage collector" app).
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

181 lines
7.7 KiB

import argparse
import threading
import time
import random
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from ShapeGenerator import ShapeGenerator
from TextGenerator import TextGenerator
# Example command:
# python app.py --output_dir /mnt/c/Users/caile/Desktop/output
class GCApp:
    """Console front end that generates meshes with ShapeGenerator and shows them in ``vis``.

    Commands are read from stdin by ``run``. Surrounding whitespace and the case
    of the command word are ignored:
      * ``start``             -- continuously generate randomly chosen waste items.
      * ``stop``              -- stop continuous generation.
      * ``generate <prompt>`` -- generate one object for ``<prompt>``.
      * ``exit``              -- stop generation, ask the GUI to quit, end ``run``.
    """

    def __init__(self, output_dir, batch_size, step_size, guidance_scale, vis):
        """Build the shape generator and keep a reference to the viewer.

        Args:
            output_dir: Passed to ShapeGenerator (the CLI describes it as the
                directory where generated shapes are saved).
            batch_size: Forwarded unchanged to ShapeGenerator.
            step_size: Forwarded unchanged to ShapeGenerator.
            guidance_scale: Forwarded unchanged to ShapeGenerator.
            vis: Viewer exposing ``window``, ``get_scene()`` and
                ``update_label(text)`` -- e.g. VisApp.
        """
        self.output_dir = output_dir
        self.obj_gen = ShapeGenerator(self.output_dir, batch_size, step_size, guidance_scale)
        # True while the continuous-generation worker is active.
        self.running = False
        # Set to ask the worker thread to finish after its current item.
        self.stop_event = threading.Event()
        self.thread = None
        self.vis = vis
        # Prompts sampled at random by continuous generation.
        self.waste_items = [
            "Plastic bottle", "Aluminum can", "Glass bottle", "Food wrapper",
            "Cardboard box", "Paper bag", "Plastic bag", "Electronics",
            "Old smartphone", "Broken TV", "Computer parts", "Batteries",
            "Light bulbs", "Old furniture", "Styrofoam cup", "Food container",
            "Takeout box", "Cigarette butts", "Plastic utensils", "Straws",
            "Bottle caps", "Rubber tires", "Broken toys", "Old clothes",
            "Shoes", "Wooden pallets", "Paint cans", "Cleaning products",
            "Old appliances", "Wires", "Cables", "Extension cords",
            "Old magazines", "Newspapers", "Scrap metal", "Construction debris",
            "Yard waste", "Grass clippings", "Leaves", "Old mattresses",
            "Carpeting", "Food scraps", "Pet waste", "Diapers",
            "Sanitary products", "Receipts", "Plastic wrap", "Packing peanuts",
            "Ice cream containers", "Fast food containers", "Takeaway cups",
            "Clamshell packaging", "Plastic film", "Broken glass",
            "Old books", "VCR tapes", "CDs", "DVDs",
            "Game consoles", "Remote controls", "Ink cartridges",
            "Toner cartridges", "Old tools", "Gardening tools",
            "Bike parts", "Fishing gear", "Beach toys", "Pool floats",
            "Old bicycles", "Skateboards", "Surfboards", "Helmets",
            "Used batteries", "Old jewelry", "Keyboards", "Mice (computer)",
            "Speakers", "Old cameras", "Projectors", "Printers",
            "Scanners", "Shredded paper", "Bubble wrap", "Plastic sheeting",
            "Tarps", "Old car parts", "Motor oil containers",
            "Propane tanks", "Oil filters", "Windshield wipers",
            "Car batteries", "Antifreeze containers", "Used tires",
            "Old propane tanks", "Scrap wood", "Broken furniture",
            "Old carpets", "Leather scraps", "Textile waste",
            "Compostable waste"
        ]

    def start_generation(self):
        """Start the background worker that generates random waste items.

        Does nothing if a worker thread is still alive.
        """
        if self.thread is not None and self.thread.is_alive():
            return
        self.running = True
        self.stop_event.clear()
        self.thread = threading.Thread(target=self._generate_objects)
        self.thread.start()

    def stop_generation(self):
        """Signal the worker to stop and wait for it to finish.

        Blocks until an in-flight ``generate_object`` call (if any) returns.
        Safe to call when no worker was ever started.
        """
        self.stop_event.set()
        self.running = False
        if self.thread is not None:
            self.thread.join()
            self.thread = None

    def get_random_item_prompt(self):
        """Return one entry of ``waste_items`` chosen uniformly at random."""
        return random.choice(self.waste_items)

    def add_to_scene(self, mesh, prompt):
        """Replace the displayed geometry with ``mesh`` and label it with ``prompt``.

        Called from the console thread and the worker thread. The widget updates
        are therefore posted to the GUI thread instead of being applied from the
        calling thread.
        """
        def _update():
            scene = self.vis.get_scene()
            self.vis.update_label(prompt)
            scene.scene.clear_geometry()
            scene.scene.add_geometry("name", mesh, rendering.MaterialRecord())
            self.vis.window.post_redraw()

        gui.Application.instance.post_to_main_thread(self.vis.window, _update)

    def _generate_objects(self):
        """Worker loop: generate a random item, show it, pause, repeat until stopped."""
        try:
            while not self.stop_event.is_set():
                mesh, prompt = self.obj_gen.generate_object(self.get_random_item_prompt())
                self.add_to_scene(mesh, prompt)
                # wait() instead of sleep() so a stop request is honoured immediately.
                self.stop_event.wait(1)
        finally:
            # Clear the flag even if generation raised, so "start" works again.
            self.running = False

    def run(self):
        """Prompt for commands on stdin until ``exit`` or end of input."""
        # NOTE(review): ShapeGenerator.run() is opaque here; presumably it loads
        # the models before the first command -- confirm.
        self.obj_gen.run()
        while True:
            try:
                command = input("Enter a command, <start> <stop> <generate (prompt)>: ")
            except EOFError:
                # stdin was closed (e.g. Ctrl-D): behave like "exit" instead of
                # crashing this thread.
                command = "exit"
            if not self._handle_command(command):
                break

    def _handle_command(self, command):
        """Execute one console command; return False when ``run`` should stop."""
        text = command.strip()
        keyword = text.lower()
        if keyword == "exit":
            print("Exiting the program.")
            self.stop_generation()
            # Ending this loop alone would leave the GUI window open, so ask the
            # GUI thread to quit as well.
            gui.Application.instance.post_to_main_thread(
                self.vis.window, lambda: gui.Application.instance.quit())
            return False
        if keyword == "start":
            if not self.running:
                print("Starting continuous generation.")
                self.start_generation()
            else:
                print("Generation already running.")
        elif keyword == "stop":
            print("Stopping continuous generation.")
            self.stop_generation()
        elif keyword.startswith("generate "):
            # text has no trailing whitespace, so the prompt is never empty here;
            # the original capitalization of the prompt is preserved.
            shape = text[len("generate "):].strip()
            mesh, prompt = self.obj_gen.generate_object(shape)
            self.add_to_scene(mesh, prompt)
        else:
            print("Unknown command.")
        return True
class VisApp:
    """Open3D window holding a 3D scene widget plus a text label panel."""

    def __init__(self):
        """Open the window, then build the scene widget and the label panel."""
        self._id = 0  # NOTE(review): not read anywhere in the visible code
        self.window = gui.Application.instance.create_window(
            "garbage collector", 1024, 768)
        # Filled in by create_scene() / create_layout() below.
        self.scene = None
        self.layout = None
        self.obj_label = None
        self.create_scene()
        self.create_layout()

    def create_layout(self):
        """Build the vertical label panel and register the layout callback."""
        panel = gui.Vert(0, gui.Margins(10, 10, 10, 10))
        label = gui.Label("None")
        panel.add_child(label)
        self.layout = panel
        self.obj_label = label
        self.window.set_on_layout(self._on_layout)
        self.window.add_child(panel)

    def create_scene(self):
        """Create the scene widget: white background, sun light, framed camera."""
        widget = gui.SceneWidget()
        self.scene = widget
        widget.scene = rendering.Open3DScene(self.window.renderer)
        widget.scene.set_background([1, 1, 1, 1])
        low_level = widget.scene.scene
        sun_direction = [-1, -1, -1]
        sun_color = [1, 1, 1]
        sun_intensity = 100000
        low_level.set_sun_light(sun_direction, sun_color, sun_intensity)
        low_level.enable_sun_light(True)
        # Camera framed on a cube spanning -5..5 on every axis, rotating about the origin.
        half = 5
        bounds = o3d.geometry.AxisAlignedBoundingBox([-half] * 3, [half] * 3)
        widget.setup_camera(20, bounds, [0, 0, 0])
        self.window.add_child(widget)

    def _on_layout(self, layout_context):
        """Give the scene the whole content area and pin the panel to the top right.

        Only direct children are positioned here; the window lays out the
        grandchildren once this callback returns.
        """
        content = self.window.content_rect
        self.scene.frame = content
        panel_width = 17 * layout_context.theme.font_size
        preferred = self.layout.calc_preferred_size(
            layout_context, gui.Widget.Constraints())
        panel_height = min(content.height, preferred.height)
        left = content.get_right() - panel_width
        self.layout.frame = gui.Rect(left, content.y, panel_width, panel_height)

    def get_scene(self):
        """Return the SceneWidget; its ``.scene`` attribute is the Open3DScene."""
        return self.scene

    def update_label(self, name):
        """Show ``name`` in the panel label, prefixed with "Object: "."""
        prefix = "Object: "
        self.obj_label.text = prefix + name
def main(output_dir, batch_size, step_size, guidance_scale):
gui.Application.instance.initialize()
vis = VisApp()
app = GCApp(output_dir, batch_size, step_size, guidance_scale, vis)
threading.Thread(target=app.run, daemon=True).start()
gui.Application.instance.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.")
parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.")
parser.add_argument("--batch_size", type=int, default=4, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.")
parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).")
args = parser.parse_args()
main(args.output_dir, args.batch_size, args.step_size, args.guidance_scale)