{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "964ccced", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from shap_e.diffusion.sample import sample_latents\n", "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n", "from shap_e.models.download import load_model, load_config\n", "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget" ] }, { "cell_type": "code", "execution_count": 2, "id": "8eed3a76", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 3, "id": "4127249f-da93-4da9-a15e-47fc1d918758", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NVIDIA GeForce RTX 3090\n" ] } ], "source": [ "print(torch.cuda.get_device_name())" ] }, { "cell_type": "code", "execution_count": 4, "id": "2d922637", "metadata": {}, "outputs": [], "source": [ "xm = load_model('transmitter', device=device)\n", "model = load_model('text300M', device=device)\n", "diffusion = diffusion_from_config(load_config('diffusion'))" ] }, { "cell_type": "code", "execution_count": 5, "id": "53d329d0", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "299c91406ddc4a368d1c80ed81c20a84", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/64 [00:00