You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							191 lines
						
					
					
						
							4.9 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							191 lines
						
					
					
						
							4.9 KiB
						
					
					
				
								{
							 | 
						|
								 "cells": [
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 1,
							 | 
						|
								   "id": "964ccced",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [],
							 | 
						|
								   "source": [
							 | 
						|
								    "import torch\n",
							 | 
						|
								    "\n",
							 | 
						|
								    "from shap_e.diffusion.sample import sample_latents\n",
							 | 
						|
								    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
							 | 
						|
								    "from shap_e.models.download import load_model, load_config\n",
							 | 
						|
								    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 2,
							 | 
						|
								   "id": "8eed3a76",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [],
							 | 
						|
								   "source": [
							 | 
						|
								    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 3,
							 | 
						|
								   "id": "4127249f-da93-4da9-a15e-47fc1d918758",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [
							 | 
						|
								    {
							 | 
						|
								     "name": "stdout",
							 | 
						|
								     "output_type": "stream",
							 | 
						|
								     "text": [
							 | 
						|
								      "NVIDIA GeForce RTX 3090\n"
							 | 
						|
								     ]
							 | 
						|
								    }
							 | 
						|
								   ],
							 | 
						|
								   "source": [
							 | 
						|
								    "print(torch.cuda.get_device_name())"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 4,
							 | 
						|
								   "id": "2d922637",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [],
							 | 
						|
								   "source": [
							 | 
						|
								    "xm = load_model('transmitter', device=device)\n",
							 | 
						|
								    "model = load_model('text300M', device=device)\n",
							 | 
						|
								    "diffusion = diffusion_from_config(load_config('diffusion'))"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 5,
							 | 
						|
								   "id": "53d329d0",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [
							 | 
						|
								    {
							 | 
						|
								     "data": {
							 | 
						|
								      "application/vnd.jupyter.widget-view+json": {
							 | 
						|
								       "model_id": "299c91406ddc4a368d1c80ed81c20a84",
							 | 
						|
								       "version_major": 2,
							 | 
						|
								       "version_minor": 0
							 | 
						|
								      },
							 | 
						|
								      "text/plain": [
							 | 
						|
								       "  0%|          | 0/64 [00:00<?, ?it/s]"
							 | 
						|
								      ]
							 | 
						|
								     },
							 | 
						|
								     "metadata": {},
							 | 
						|
								     "output_type": "display_data"
							 | 
						|
								    }
							 | 
						|
								   ],
							 | 
						|
								   "source": [
							 | 
						|
								    "batch_size = 1\n",
							 | 
						|
								    "guidance_scale = 30.0\n",
							 | 
						|
								    "prompt = \"road sign\"\n",
							 | 
						|
								    "\n",
							 | 
						|
								    "latents = sample_latents(\n",
							 | 
						|
								    "    batch_size=batch_size,\n",
							 | 
						|
								    "    model=model,\n",
							 | 
						|
								    "    diffusion=diffusion,\n",
							 | 
						|
								    "    guidance_scale=guidance_scale,\n",
							 | 
						|
								    "    model_kwargs=dict(texts=[prompt] * batch_size),\n",
							 | 
						|
								    "    progress=True,\n",
							 | 
						|
								    "    clip_denoised=True,\n",
							 | 
						|
								    "    use_fp16=True,\n",
							 | 
						|
								    "    use_karras=True,\n",
							 | 
						|
								    "    karras_steps=64,\n",
							 | 
						|
								    "    sigma_min=1e-3,\n",
							 | 
						|
								    "    sigma_max=160,\n",
							 | 
						|
								    "    s_churn=0,\n",
							 | 
						|
								    ")"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 20,
							 | 
						|
								   "id": "633da2ec",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [
							 | 
						|
								    {
							 | 
						|
								     "data": {
							 | 
						|
								      "application/vnd.jupyter.widget-view+json": {
							 | 
						|
								       "model_id": "7c8b8946a49847dd9aa5376f9568775f",
							 | 
						|
								       "version_major": 2,
							 | 
						|
								       "version_minor": 0
							 | 
						|
								      },
							 | 
						|
								      "text/plain": [
							 | 
						|
								       "HTML(value='<img src=\"…"
							 | 
						|
								      ]
							 | 
						|
								     },
							 | 
						|
								     "metadata": {},
							 | 
						|
								     "output_type": "display_data"
							 | 
						|
								    }
							 | 
						|
								   ],
							 | 
						|
								   "source": [
							 | 
						|
								    "render_mode = 'nerf' # you can change this to 'stf'\n",
							 | 
						|
								    "size = 64 # this is the size of the renders; higher values take longer to render.\n",
							 | 
						|
								    "\n",
							 | 
						|
								    "cameras = create_pan_cameras(size, device)\n",
							 | 
						|
								    "for i, latent in enumerate(latents):\n",
							 | 
						|
								    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
							 | 
						|
								    "    display(gif_widget(images))"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": 6,
							 | 
						|
								   "id": "85a4dce4",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [
							 | 
						|
								    {
							 | 
						|
								     "name": "stderr",
							 | 
						|
								     "output_type": "stream",
							 | 
						|
								     "text": [
							 | 
						|
								      "/home/cailean/shap-e/shap_e/models/stf/renderer.py:286: UserWarning: exception rendering with PyTorch3D: No module named 'pytorch3d'\n",
							 | 
						|
								      "  warnings.warn(f\"exception rendering with PyTorch3D: {exc}\")\n",
							 | 
						|
								      "/home/cailean/shap-e/shap_e/models/stf/renderer.py:287: UserWarning: falling back on native PyTorch renderer, which does not support full gradients\n",
							 | 
						|
								      "  warnings.warn(\n"
							 | 
						|
								     ]
							 | 
						|
								    }
							 | 
						|
								   ],
							 | 
						|
								   "source": [
							 | 
						|
								    "# Example of saving the latents as meshes.\n",
							 | 
						|
								    "from shap_e.util.notebooks import decode_latent_mesh\n",
							 | 
						|
								    "\n",
							 | 
						|
								    "for i, latent in enumerate(latents):\n",
							 | 
						|
								    "    t = decode_latent_mesh(xm, latent).tri_mesh()\n",
							 | 
						|
								    "    with open(f'road_example_mesh_{i}.ply', 'wb') as f:\n",
							 | 
						|
								    "        t.write_ply(f)\n",
							 | 
						|
								    "    with open(f'road_example_mesh_{i}.obj', 'w') as f:\n",
							 | 
						|
								    "        t.write_obj(f)"
							 | 
						|
								   ]
							 | 
						|
								  },
							 | 
						|
								  {
							 | 
						|
								   "cell_type": "code",
							 | 
						|
								   "execution_count": null,
							 | 
						|
								   "id": "71b5ace4-b449-4a7e-b4e3-66ee6a5d03c3",
							 | 
						|
								   "metadata": {},
							 | 
						|
								   "outputs": [],
							 | 
						|
								   "source": []
							 | 
						|
								  }
							 | 
						|
								 ],
							 | 
						|
								 "metadata": {
							 | 
						|
								  "kernelspec": {
							 | 
						|
								   "display_name": "Python 3 (ipykernel)",
							 | 
						|
								   "language": "python",
							 | 
						|
								   "name": "python3"
							 | 
						|
								  },
							 | 
						|
								  "language_info": {
							 | 
						|
								   "codemirror_mode": {
							 | 
						|
								    "name": "ipython",
							 | 
						|
								    "version": 3
							 | 
						|
								   },
							 | 
						|
								   "file_extension": ".py",
							 | 
						|
								   "mimetype": "text/x-python",
							 | 
						|
								   "name": "python",
							 | 
						|
								   "nbconvert_exporter": "python",
							 | 
						|
								   "pygments_lexer": "ipython3",
							 | 
						|
								   "version": "3.10.12"
							 | 
						|
								  }
							 | 
						|
								 },
							 | 
						|
								 "nbformat": 4,
							 | 
						|
								 "nbformat_minor": 5
							 | 
						|
								}
							 | 
						|
								
							 |