a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
5.1 KiB

2 years ago
import random
from typing import Any, List, Optional, Union
import blobfile as bf
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def center_crop(
    img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Crop the largest centered square out of an image.

    :param: img: array/tensor whose first two axes are (height, width), or an
                 object exposing ``size`` as (width, height) and ``crop``
    :return: the cropped image, of the same kind as the input
    """
    is_array = isinstance(img, (np.ndarray, torch.Tensor))
    if is_array:
        height, width = img.shape[:2]
    else:
        # Non-array inputs report their size as (width, height).
        width, height = img.size

    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    x1 = x0 + side
    y1 = y0 + side

    if is_array:
        return img[y0:y1, x0:x1]
    return img.crop((x0, y0, x1, y1))
def resize(
    img: Union[Image.Image, torch.Tensor, np.ndarray],
    *,
    height: int,
    width: int,
    min_value: Optional[Any] = None,
    max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Resize an image to (height, width) using area interpolation.

    :param: img: image in HWC order (a 2-D HW image is also accepted)
    :param: height: output height
    :param: width: output width
    :param: min_value: optional lower bound the interpolated values are clamped to
    :param: max_value: optional upper bound the interpolated values are clamped to
    :return: the resized image, of the same kind and dtype as the input;
             currently written for downsampling
    """
    source = img
    if isinstance(img, Image.Image):
        img = np.array(img)
    dtype = img.dtype
    if isinstance(img, np.ndarray):
        # torch.from_numpy rejects arrays with negative strides (e.g. flipped
        # views); ascontiguousarray is a no-op for already-contiguous input.
        img = torch.from_numpy(np.ascontiguousarray(img))
    ndim = img.ndim
    if ndim == 2:
        # Add a channel axis so the HWC -> CHW permute below works.
        img = img.unsqueeze(-1)

    if min_value is None and max_value is None:
        # .clamp throws an error when both are None
        min_value = -np.inf

    chw = img.permute(2, 0, 1)
    # Interpolation runs in float on a batch of one.
    resized = F.interpolate(chw[None].float(), size=(height, width), mode="area")[0]
    # NOTE: the cast back to the input dtype happens without any rounding step.
    img = resized.clamp(min_value, max_value).to(chw.dtype).permute(1, 2, 0)

    if ndim < img.ndim:
        # Drop the channel axis that was added for 2-D input.
        img = img.squeeze(-1)
    if isinstance(source, torch.Tensor):
        return img

    img = img.numpy().astype(dtype)
    if isinstance(source, Image.Image):
        img = Image.fromarray(img)
    return img
def get_alpha(img: Image.Image) -> Image.Image:
    """
    Extract the alpha channel of an image.

    Inputs without a fourth channel, including 2-D grayscale images, are
    treated as fully opaque.

    :param: img: image to read; converted with np.asarray
    :return: the alpha channel separated out as a grayscale image
    """
    img_arr = np.asarray(img)
    # Check ndim first: a grayscale image converts to a 2-D array, for which
    # shape[2] does not exist.
    if img_arr.ndim == 3 and img_arr.shape[2] == 4:
        alpha = img_arr[:, :, 3]
    else:
        alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
    return Image.fromarray(alpha)
def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
    """
    No op if the image doesn't have an alpha channel, i.e. for anything other
    than a 4-channel image (2-D grayscale images included).

    :param: mode: Defaults to "random" but has an option to use a "black" or
                  "white" background. Any other value leaves the image unchanged.
    :return: image with alpha removed
    """
    img_arr = np.asarray(img)
    # Check ndim first: a grayscale image converts to a 2-D array, for which
    # shape[2] does not exist.
    if img_arr.ndim != 3 or img_arr.shape[2] != 4:
        return img

    # Add bg to get rid of alpha channel
    if mode == "random":
        height, width = img_arr.shape[:2]
        make_bg = random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])
        bg = Image.fromarray(make_bg(height, width))
        # The image is passed as its own paste mask.
        bg.paste(img, mask=img)
        return bg

    if mode == "black" or mode == "white":
        img_arr = img_arr.astype(float)
        rgb = img_arr[:, :, :3]
        # Keep a trailing axis so alpha broadcasts over the colour channels.
        alpha = img_arr[:, :, -1:] / 255
        background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
        rgb = rgb * alpha + background * (1 - alpha)
        return Image.fromarray(np.round(rgb).astype(np.uint8))

    return img
def _black_bg(h: int, w: int) -> np.ndarray:
    """
    :return: an all-zero uint8 background of shape (h, w, 3)
    """
    shape = (h, w, 3)
    return np.zeros(shape, dtype=np.uint8)
def _gray_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 background of shape (h, w, 3) filled with one random
             gray level drawn from np.random
    """
    level = np.random.randint(low=0, high=256)
    return np.full((h, w, 3), level, dtype=np.uint8)
def _checker_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 checkerboard of shape (h, w, 3) made of two random gray
             levels, with a random cell size and a random offset
    """
    # Cell size is exp(u * log(min(h, w))) rounded up, with u uniform in [0, 1).
    cell = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w))))
    level_a = np.random.randint(low=0, high=256)
    level_b = np.random.randint(low=0, high=256)
    # Offsets are drawn in x-then-y order so the pattern is shifted on each axis.
    shift_x = np.random.randint(low=0, high=cell + 1)
    shift_y = np.random.randint(low=0, high=cell + 1)

    cols = np.arange(w)[None, :, None] + shift_x
    rows = np.arange(h)[:, None, None] + shift_y
    col_is_even = (cols // cell) % 2 == 0
    row_is_even = (rows // cell) % 2 == 0
    # A pixel takes the first level when exactly one of its cell indices is even.
    use_a = col_is_even ^ row_is_even

    board = np.where(use_a, np.array([level_a] * 3), np.array([level_b] * 3))
    return board.astype(np.uint8)
def _noise_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 background of shape (h, w, 3) where every value is drawn
             independently from np.random in [0, 256)
    """
    samples = np.random.randint(low=0, high=256, size=(h, w, 3))
    return samples.astype(np.uint8)
def load_image(image_path: str) -> Image.Image:
    """
    Open an image through blobfile and read it fully into memory.

    :param: image_path: path handed to bf.BlobFile
    :return: the loaded image
    """
    with bf.BlobFile(image_path, "rb") as handle:
        image = Image.open(handle)
        # Load the pixel data while the file handle is still open.
        image.load()
    return image
def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:
    """
    Arrange equally-shaped images into a grid, filled row by row.

    When the number of images is not a multiple of `columns`, the last row is
    padded with zero-filled tiles.

    to test, run
    >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))

    :param: images: non-empty list of images that all share one shape and
                    dtype, in HWC or HW order; tiles need not be square
    :param: columns: number of tiles per grid row
    :return: a single image containing every tile
    """
    tiles = list(map(np.array, images))
    first = tiles[0]
    tile_h, tile_w = first.shape[:2]
    # Trailing dims: (channels,) for HWC tiles, () for 2-D grayscale tiles.
    tail = first.shape[2:]

    n = round_up(len(tiles), columns)
    n_blanks = n - len(tiles)
    # Blank tiles copy the first tile's shape and dtype so stacking stays uniform.
    tiles.extend([np.zeros(first.shape, dtype=first.dtype)] * n_blanks)

    rows = n // columns
    grid = (
        np.array(tiles)
        .reshape(rows, columns, tile_h, tile_w, *tail)
        # Put the pixel-row axis ahead of the grid-column axis so the final
        # reshape lays the tiles of one grid row side by side.
        .swapaxes(1, 2)
        .reshape(rows * tile_h, columns * tile_w, *tail)
    )
    return Image.fromarray(grid)
def round_up(n: int, b: int) -> int:
    """
    :return: for positive b, the smallest multiple of b that is >= n
    """
    blocks = (n + b - 1) // b
    return blocks * b