diff --git a/.gitignore b/.gitignore
index e4c0592..ee31a82 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
__pycache__/
.DS_Store
*.egg-info/
+shap_e/examples/shap_e_model_cache/
+shap_e/examples/gc/shap_e_model_cache/
diff --git a/shap_e/examples/.ipynb_checkpoints/sample_text_to_3d-checkpoint.ipynb b/shap_e/examples/.ipynb_checkpoints/sample_text_to_3d-checkpoint.ipynb
new file mode 100644
index 0000000..2f1a72d
--- /dev/null
+++ b/shap_e/examples/.ipynb_checkpoints/sample_text_to_3d-checkpoint.ipynb
@@ -0,0 +1,172 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "964ccced",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "from shap_e.diffusion.sample import sample_latents\n",
+ "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
+ "from shap_e.models.download import load_model, load_config\n",
+ "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "8eed3a76",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "4127249f-da93-4da9-a15e-47fc1d918758",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "NVIDIA GeForce RTX 3090\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(torch.cuda.get_device_name())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2d922637",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "xm = load_model('transmitter', device=device)\n",
+ "model = load_model('text300M', device=device)\n",
+ "diffusion = diffusion_from_config(load_config('diffusion'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "53d329d0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f76a8f93c93e4b77af91f03645eb5011",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/64 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "batch_size = 1\n",
+ "guidance_scale = 30.0\n",
+ "prompt = \"bin bag\"\n",
+ "\n",
+ "latents = sample_latents(\n",
+ " batch_size=batch_size,\n",
+ " model=model,\n",
+ " diffusion=diffusion,\n",
+ " guidance_scale=guidance_scale,\n",
+ " model_kwargs=dict(texts=[prompt] * batch_size),\n",
+ " progress=True,\n",
+ " clip_denoised=True,\n",
+ " use_fp16=True,\n",
+ " use_karras=True,\n",
+ " karras_steps=64,\n",
+ " sigma_min=1e-3,\n",
+ " sigma_max=160,\n",
+ " s_churn=0,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "633da2ec",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7c8b8946a49847dd9aa5376f9568775f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HTML(value=' : ")
+ if command.lower() == 'exit':
+ print("Exiting the program.")
+ self.stop_generation()
+ break
+ elif command.lower() == 'start':
+ if not self.running:
+ print("Starting continuous generation.")
+ self.start_generation()
+ else:
+ print("Generation already running.")
+ elif command.lower() == 'stop':
+ print("Stopping continuous generation.")
+ self.stop_generation()
+ else:
+ print("Unknown command.")
+
+def main(output_dir, batch_size, step_size, guidance_scale):
+ app = GCApp(output_dir, batch_size, step_size, guidance_scale)
+ app.run()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Generate shapes with the ShapeGenerator.")
+ parser.add_argument("--output_dir", type=str, required=True, help="The directory to save generated shapes.")
+ parser.add_argument("--batch_size", type=int, default=2, help="The number of batches for shap-e. the higher the batch size the longer it will take to process but will output a more refined mesh.")
+ parser.add_argument("--step_size", type=int, default=64, help="The number of steps/iterations for shap-e. the higher the step size the longer it will take to process but will output a more refined mesh.")
+ parser.add_argument("--guidance_scale", type=int, default=30, help="The guidance scale in context to the text prompt. The higher this value, the model will generate something closer to the text description (CLIP).")
+
+ args = parser.parse_args()
+
+ main(args.output_dir, args.batch_size, args.step_size, args.guidance_scale)
+
+
diff --git a/shap_e/examples/sample_text_to_3d.ipynb b/shap_e/examples/sample_text_to_3d.ipynb
index ef39d44..9ab7398 100644
--- a/shap_e/examples/sample_text_to_3d.ipynb
+++ b/shap_e/examples/sample_text_to_3d.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "964ccced",
"metadata": {},
"outputs": [],
@@ -17,7 +17,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "8eed3a76",
"metadata": {},
"outputs": [],
@@ -27,7 +27,25 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
+ "id": "4127249f-da93-4da9-a15e-47fc1d918758",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "NVIDIA GeForce RTX 3090\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(torch.cuda.get_device_name())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"id": "2d922637",
"metadata": {},
"outputs": [],
@@ -39,14 +57,29 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "53d329d0",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "299c91406ddc4a368d1c80ed81c20a84",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/64 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
- "batch_size = 4\n",
- "guidance_scale = 15.0\n",
- "prompt = \"a shark\"\n",
+ "batch_size = 1\n",
+ "guidance_scale = 30.0\n",
+ "prompt = \"road sign\"\n",
"\n",
"latents = sample_latents(\n",
" batch_size=batch_size,\n",
@@ -67,10 +100,25 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
"id": "633da2ec",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7c8b8946a49847dd9aa5376f9568775f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "HTML(value='