from typing import Iterable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from shap_e.models.download import default_cache_dir

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


class ImageCLIP(nn.Module):
    """
    A wrapper around a pre-trained CLIP model that automatically handles
    batches of texts, images, and embeddings.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: Optional[torch.dtype] = torch.float32,
        ensure_used_params: bool = True,
        clip_name: str = "ViT-L/14",
        cache_dir: Optional[str] = None,
    ):
        super().__init__()

        assert clip_name in ["ViT-L/14", "ViT-B/32"]

        self.device = device
        self.ensure_used_params = ensure_used_params

        # Lazy import because of torchvision.
        import clip

        self.clip_model, self.preprocess = clip.load(
            clip_name, device=device, download_root=cache_dir or default_cache_dir()
        )
        self.clip_name = clip_name

        if dtype is not None:
            self.clip_model.to(dtype)
        self._tokenize = clip.tokenize
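
    # The three properties below report sizes that depend on the CLIP backbone
    # selected via `clip_name`:
    #   ViT-L/14: feature_dim=768, grid_size=16 (16 * 16 = 256 patch tokens), grid_feature_dim=1024
    #   ViT-B/32: feature_dim=512, grid_size=7  (7 * 7 = 49 patch tokens),    grid_feature_dim=768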
    @property
    def feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 768
        else:
            return 512

    @property
    def grid_size(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 16
        else:
            return 7

    @property
    def grid_feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 1024
        else:
            return 768
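
    # Illustrative sketch of how forward() below can be called with a mixed batch.
    # `pil_image` is a hypothetical PIL.Image, and a CUDA device is assumed to be
    # available; any ImageType and device work the same way:
    #
    #   model = ImageCLIP(device=torch.device("cuda"))
    #   feats = model(
    #       batch_size=3,
    #       images=[pil_image, None, None],
    #       texts=[None, "a red chair", None],
    #   )
    #   # feats has shape [3, model.feature_dim]; element 2, which specifies no
    #   # modality, receives a zero embedding.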
    def forward(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Generate a batch of embeddings from a mixture of images, texts,
        precomputed embeddings, and possibly empty values.

        For each batch element, at most one of images, texts, and embeddings
        should have a non-None value. Embeddings from multiple modalities
        cannot be mixed for a single batch element. If no modality is provided,
        a zero embedding will be used for the batch element.
        """
        image_seq = [None] * batch_size if images is None else list(images)
        text_seq = [None] * batch_size if texts is None else list(texts)
        embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
        assert len(image_seq) == batch_size, "number of images should match batch size"
        assert len(text_seq) == batch_size, "number of texts should match batch size"
        assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"

        if self.ensure_used_params:
            return self._static_multimodal_embed(
                images=image_seq, texts=text_seq, embeddings=embedding_seq
            )

        result = torch.zeros((batch_size, self.feature_dim), device=self.device)
        index_images = []
        index_texts = []
        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
            assert (
                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
            ), "only one modality may be non-None per batch element"
            if image is not None:
                index_images.append((i, image))
            elif text is not None:
                index_texts.append((i, text))
            elif emb is not None:
                result[i] = emb.to(result)

        if len(index_images):
            embs = self.embed_images((img for _, img in index_images))
            for (i, _), emb in zip(index_images, embs):
                result[i] = emb.to(result)
        if len(index_texts):
            embs = self.embed_text((text for _, text in index_texts))
            for (i, _), emb in zip(index_texts, embs):
                result[i] = emb.to(result)

        return result
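
    # _static_multimodal_embed() below runs every encoder on every call, then
    # multiplies each encoder's output by a 0/1 flag so each row keeps only its
    # requested modality (rows with no modality become zeros), and adds
    # logit_scale * 0 so that parameter is also touched. This keeps the forward
    # graph identical across ranks, with every parameter participating.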
    def _static_multimodal_embed(
        self,
        images: List[Optional[ImageType]] = None,
        texts: List[Optional[str]] = None,
        embeddings: List[Optional[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Like forward(), but always runs all encoders to ensure that
        the forward graph looks the same on every rank.
        """
        image_emb = self.embed_images(images)
        text_emb = self.embed_text(t if t else "" for t in texts)
        joined_embs = torch.stack(
            [
                emb.to(device=self.device, dtype=torch.float32)
                if emb is not None
                else torch.zeros(self.feature_dim, device=self.device)
                for emb in embeddings
            ],
            dim=0,
        )

        image_flag = torch.tensor([x is not None for x in images], device=self.device)[
            :, None
        ].expand_as(image_emb)
        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
            :, None
        ].expand_as(image_emb)
        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
            :, None
        ].expand_as(image_emb)

        return (
            image_flag.float() * image_emb
            + text_flag.float() * text_emb
            + emb_flag.float() * joined_embs
            + self.clip_model.logit_scale * 0  # avoid unused parameters
        )
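
    # Both embed_images() and embed_text() below return L2-normalized rows, so
    # cosine similarities reduce to a matrix product. Illustrative sketch,
    # continuing the example above (`model` is an ImageCLIP instance and
    # `pil_image` is again a hypothetical PIL.Image):
    #
    #   img = model.embed_images([pil_image])           # [1, feature_dim]
    #   txt = model.embed_text(["a chair", "a lamp"])   # [2, feature_dim]
    #   sims = img @ txt.T                              # [1, 2] cosine similarities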
    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        :param xs: N images, stored as numpy arrays, tensors, or PIL images.
        :return: an [N x D] tensor of features.
        """
        clip_inputs = self.images_to_tensor(xs)
        results = self.clip_model.encode_image(clip_inputs).float()
        return results / torch.linalg.norm(results, dim=-1, keepdim=True)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """
        Embed text prompts as an [N x D] tensor.
        """
        enc = self.clip_model.encode_text(
            self._tokenize(list(prompts), truncate=True).to(self.device)
        ).float()
        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)
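
    # embed_images_grid() below taps the vision transformer directly: it returns
    # the per-patch activations of shape [N, grid_feature_dim, grid_size ** 2],
    # dropping the class token and skipping CLIP's final layer norm and projection.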
    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        Embed images into latent grids.

        :param xs: an iterable of images to embed.
        :return: a tensor of shape [N x C x L], where L = self.grid_size**2.
        """
        if self.ensure_used_params:
            extra_value = 0.0
            for p in self.parameters():
                extra_value = extra_value + p.mean() * 0.0
        else:
            extra_value = 0.0

        x = self.images_to_tensor(xs).to(self.clip_model.dtype)

        # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
        vt = self.clip_model.visual
        x = vt.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                vt.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + vt.positional_embedding.to(x.dtype)
        x = vt.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vt.transformer(x)
        x = x.permute(1, 2, 0)  # LND -> NDL

        return x[..., 1:].contiguous().float() + extra_value

    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        # Convert each input to PIL, apply CLIP's preprocessing transform, and
        # stack the results into a single [N, 3, H, W] batch on the target device.
        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)
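

# FrozenImageCLIP below wraps ImageCLIP with gradients disabled. Illustrative
# sketch of typical use (assumes a CUDA device is available):
#
#   clip_model = FrozenImageCLIP(device=torch.device("cuda"))
#   text_feats = clip_model.embed_text(["a corgi wearing a party hat"])
#   # text_feats has shape [1, clip_model.feature_dim] and carries no gradients.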
class FrozenImageCLIP:
    def __init__(self, device: torch.device, **kwargs):
        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

    @property
    def feature_dim(self) -> int:
        return self.model.feature_dim

    @property
    def grid_size(self) -> int:
        return self.model.grid_size

    @property
    def grid_feature_dim(self) -> int:
        return self.model.grid_feature_dim

    def __call__(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # We deliberately skip no_grad() here so that gradients can still flow
        # to the input embeddings argument. This behavior is currently not used,
        # but it could be.
        return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images(xs)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_text(prompts)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images_grid(xs)


def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
    # Normalize any supported image type to a PIL image. None becomes a black
    # 64x64 RGB placeholder; arrays and tensors are cast to uint8, so they are
    # expected to already hold pixel values in the 0-255 range.
    if obj is None:
        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
    if isinstance(obj, np.ndarray):
        return Image.fromarray(obj.astype(np.uint8))
    elif isinstance(obj, torch.Tensor):
        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
    else:
        return obj