a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

170 lines
5.1 KiB

import random
from typing import Any, List, Optional, Union
import blobfile as bf
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def center_crop(
    img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Crop the largest centered square out of an image.

    Arrays and tensors are sliced as ``img[rows, cols, ...]`` (height first);
    PIL images are cropped with ``Image.crop``. The result has the same type
    as the input.
    """
    is_array = isinstance(img, (np.ndarray, torch.Tensor))
    if is_array:
        h, w = img.shape[:2]
    else:
        # PIL reports size as (width, height).
        w, h = img.size
    side = min(w, h)
    x0 = (w - side) // 2
    y0 = (h - side) // 2
    if is_array:
        return img[y0 : y0 + side, x0 : x0 + side]
    return img.crop((x0, y0, x0 + side, y0 + side))
def resize(
    img: Union[Image.Image, torch.Tensor, np.ndarray],
    *,
    height: int,
    width: int,
    min_value: Optional[Any] = None,
    max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Resize an image with area interpolation, returning the same type it was given.

    :param img: image in HWC order, or a 2-D (HW) single-channel image; may be
        a PIL image, a NumPy array or a torch tensor.
    :param height: output height in pixels.
    :param width: output width in pixels.
    :param min_value: optional lower bound the interpolated values are clamped to.
    :param max_value: optional upper bound the interpolated values are clamped to.
    :return: the resized image with the same type, dtype and channel layout as
        ``img``. Area interpolation makes this intended for downsampling.
    """
    orig = img
    if isinstance(img, Image.Image):
        img = np.array(img)
    dtype = img.dtype
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)

    ndim = img.ndim
    if ndim == 2:
        # Give single-channel images an explicit channel axis so the
        # HWC -> CHW permute below works uniformly.
        img = img.unsqueeze(-1)

    if min_value is None and max_value is None:
        # .clamp throws an error when both are None
        min_value = -np.inf

    chw = img.permute(2, 0, 1)
    img = (
        # F.interpolate expects a float NCHW batch.
        F.interpolate(chw[None].float(), size=(height, width), mode="area")[0]
        .clamp(min_value, max_value)
        # Back to the input dtype; for integer dtypes the fractional part is
        # discarded rather than rounded.
        .to(chw.dtype)
        .permute(1, 2, 0)
    )

    if ndim < img.ndim:
        img = img.squeeze(-1)
    if not isinstance(orig, torch.Tensor):
        # torch.Tensor has no .astype, so the NumPy dtype is restored only on
        # the NumPy/PIL paths; tensors already got their dtype back via .to().
        img = img.numpy().astype(dtype)
    if isinstance(orig, Image.Image):
        img = Image.fromarray(img)
    return img
def get_alpha(img: Image.Image) -> Image.Image:
    """
    Extract an image's alpha channel as a grayscale image.

    :param img: a PIL image.
    :return: for an image whose array form is H x W x 4 (e.g. RGBA), the 4th
        channel as a single-channel image. Any other image (RGB, or
        single-channel modes such as "L" whose array is 2-D) is treated as
        having no alpha, and a uint8 image filled with 255 (fully opaque) of
        the same height and width is returned.
    """
    img_arr = np.asarray(img)
    # Check ndim first: single-channel images convert to a 2-D array, and
    # reading shape[2] on those raised IndexError.
    if img_arr.ndim == 3 and img_arr.shape[2] == 4:
        alpha = img_arr[:, :, 3]
    else:
        alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
    return Image.fromarray(alpha)
def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
    """
    Composite an image with an alpha channel onto an opaque background.

    No op if the image doesn't have an alpha channel, i.e. its array form is
    not H x W x 4. This includes single-channel images such as mode "L".

    :param img: a PIL image.
    :param mode: background to composite onto.
        - "random" (default): one of black, a uniform random gray, a random
          checkerboard or per-pixel noise, chosen at random.
        - "black" / "white": a solid background.
        - Any other value leaves the image unchanged.
    :return: image with alpha removed
    """
    img_arr = np.asarray(img)
    # Guard on ndim: 2-D (single-channel) arrays have no shape[2].
    if img_arr.ndim != 3 or img_arr.shape[2] != 4:
        return img

    if mode == "random":
        height, width = img_arr.shape[:2]
        make_bg = random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])
        bg = Image.fromarray(make_bg(height, width))
        # An RGBA image passed as the paste mask contributes its alpha band.
        bg.paste(img, mask=img)
        return bg

    if mode in ("black", "white"):
        img_arr = img_arr.astype(float)
        rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255
        background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
        # Alpha-blend against a constant colour.
        rgb = rgb * alpha + background * (1 - alpha)
        return Image.fromarray(np.round(rgb).astype(np.uint8))

    return img
def _black_bg(h: int, w: int) -> np.ndarray:
    """Return an all-black ``(h, w, 3)`` uint8 RGB background."""
    shape = (h, w, 3)
    return np.zeros(shape, np.uint8)
def _gray_bg(h: int, w: int) -> np.ndarray:
    """Return an ``(h, w, 3)`` uint8 background filled with a single random gray level."""
    level = np.random.randint(low=0, high=256)
    return np.full((h, w, 3), level, dtype=np.uint8)
def _checker_bg(h: int, w: int) -> np.ndarray:
    """
    Return an ``(h, w, 3)`` uint8 checkerboard of two random gray levels.

    The square size is ``ceil(exp(u * log(min(h, w))))`` for ``u`` uniform in
    [0, 1), i.e. log-uniform between 1 and ``min(h, w)``. The pattern is then
    shifted by an independent random offset along each axis.
    """
    log_extent = np.log(min(h, w))
    square = np.ceil(np.exp(np.random.uniform() * log_extent))
    color_a = np.random.randint(low=0, high=256)
    color_b = np.random.randint(low=0, high=256)
    offset_x = np.random.randint(low=0, high=square + 1)
    offset_y = np.random.randint(low=0, high=square + 1)

    cols = np.arange(w)[None, :, None] + offset_x
    rows = np.arange(h)[:, None, None] + offset_y
    even_col = (cols // square) % 2 == 0
    even_row = (rows // square) % 2 == 0
    use_a = np.logical_xor(even_col, even_row)

    shade_a = np.array([color_a] * 3)
    shade_b = np.array([color_b] * 3)
    return np.where(use_a, shade_a, shade_b).astype(np.uint8)
def _noise_bg(h: int, w: int) -> np.ndarray:
    """Return an ``(h, w, 3)`` uint8 background of independent uniform per-pixel noise."""
    shape = (h, w, 3)
    noise = np.random.randint(low=0, high=256, size=shape)
    return noise.astype(np.uint8)
def load_image(image_path: str) -> Image.Image:
    """
    Read an image from ``image_path``, opened through ``bf.BlobFile``.

    Pixel data is decoded before the file handle is closed, so the returned
    image does not depend on the underlying stream staying open.
    """
    with bf.BlobFile(image_path, "rb") as stream:
        loaded = Image.open(stream)
        # PIL decodes lazily; force decoding while the stream is still open.
        loaded.load()
    return loaded
def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:
    """
    Arrange equally sized images into a grid, row-major, ``columns`` tiles wide.

    All images must share one shape. That shape may be non-square and either
    H x W x C or 2-D (H x W). The last row is padded with all-zero tiles of
    the same shape and dtype as the first image.

    :param images: non-empty list of PIL images or arrays.
    :param columns: number of tiles per row.
    :return: the grid as a single PIL image.
    :raises ValueError: if ``images`` is empty.

    to test, run (in a notebook)
    >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))
    """
    arrays = [np.array(image) for image in images]
    if not arrays:
        raise ValueError("make_tile needs at least one image")

    # Height and width are read separately so non-square tiles work; any
    # trailing (channel) axes are carried through unchanged.
    height, width = arrays[0].shape[:2]
    channels = arrays[0].shape[2:]

    n = round_up(len(arrays), columns)
    n_blanks = n - len(arrays)
    arrays.extend([np.zeros_like(arrays[0])] * n_blanks)

    rows = n // columns
    grid = (
        np.array(arrays)
        .reshape(rows, columns, height, width, *channels)
        # (rows, columns, H, W, ...) -> (rows, H, columns, W, ...) so tiles in
        # the same row sit side by side after the final reshape.
        .swapaxes(1, 2)
        .reshape(rows * height, columns * width, *channels)
    )
    return Image.fromarray(grid)
def round_up(n: int, b: int) -> int:
    """Round ``n`` up to the nearest multiple of ``b`` (``b`` is expected to be positive)."""
    whole_blocks = (n + b - 1) // b
    return whole_blocks * b