You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
5.1 KiB
171 lines
5.1 KiB
2 years ago
|
import random
|
||
|
from typing import Any, List, Optional, Union
|
||
|
|
||
|
import blobfile as bf
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from PIL import Image
|
||
|
|
||
|
|
||
|
def center_crop(
    img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Crop the largest centered square out of an image.

    :param: img: a PIL image, or an array/tensor whose first two axes are
                 (height, width)
    :return: the square crop; arrays/tensors are sliced, PIL images are
             cropped with Image.crop
    """
    is_array = isinstance(img, (np.ndarray, torch.Tensor))
    if is_array:
        height, width = img.shape[:2]
    else:
        # PIL reports its size as (width, height).
        width, height = img.size

    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    x1, y1 = x0 + side, y0 + side

    if is_array:
        return img[y0:y1, x0:x1]
    return img.crop((x0, y0, x1, y1))
|
||
|
|
||
|
|
||
|
def resize(
    img: Union[Image.Image, torch.Tensor, np.ndarray],
    *,
    height: int,
    width: int,
    min_value: Optional[Any] = None,
    max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
    """
    Resize an image to (height, width) using area interpolation.

    :param: img: image in HWC order (a 2-D HW input is also accepted)
    :param: height: output height
    :param: width: output width
    :param: min_value: optional lower bound the resized values are clamped to
    :param: max_value: optional upper bound the resized values are clamped to
    :return: the resized image, converted back to the input's kind (PIL
             image, numpy array or torch tensor) and dtype. Interpolation
             uses mode="area"; currently written for downsampling.
    """
    orig = img
    if isinstance(img, Image.Image):
        img = np.array(img)
    # numpy dtype for PIL/ndarray input; only used to restore numpy output.
    dtype = img.dtype
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)
    ndim = img.ndim
    if ndim == 2:
        # Add a trailing channel axis so the HWC handling below applies.
        img = img.unsqueeze(-1)

    if min_value is None and max_value is None:
        # .clamp throws an error when both are None
        min_value = -np.inf

    # F.interpolate is given a batched, channels-first float tensor.
    chw = img.permute(2, 0, 1)
    tensor_dtype = chw.dtype
    resized = F.interpolate(chw[None].float(), size=(height, width), mode="area")[0]
    # NOTE(review): the cast back from float is a plain dtype conversion, so
    # for integer dtypes fractional averages are not rounded -- confirm that
    # this is the intended behavior.
    img = resized.clamp(min_value, max_value).to(tensor_dtype).permute(1, 2, 0)

    if ndim < img.ndim:
        # Drop the channel axis that was added for a 2-D input.
        img = img.squeeze(-1)
    if not isinstance(orig, torch.Tensor):
        img = img.numpy()
        img = img.astype(dtype)
    if isinstance(orig, Image.Image):
        img = Image.fromarray(img)

    return img
|
||
|
|
||
|
|
||
|
def get_alpha(img: Image.Image) -> Image.Image:
    """
    Extract the alpha channel of an image.

    :param: img: the image to read
    :return: the alpha channel separated out as a grayscale image. When the
             image does not have four channels, a fully opaque (all 255)
             image of the same height and width is returned.
    """
    img_arr = np.asarray(img)
    # Single-channel images convert to 2-D arrays with no channel axis, so
    # check the rank before reading the channel count.
    if img_arr.ndim == 3 and img_arr.shape[2] == 4:
        alpha = img_arr[:, :, 3]
    else:
        alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
    return Image.fromarray(alpha)
|
||
|
|
||
|
|
||
|
def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
    """
    Composite a four-channel image onto an opaque background.

    No op if the image doesn't have an alpha channel.

    :param: img: the image to process
    :param: mode: Defaults to "random", which pastes the image onto a randomly
                  chosen black, gray, checkerboard or noise background. Also
                  has an option to use a "black" or "white" background. Any
                  other value returns the image unchanged.
    :return: image with alpha removed
    """
    img_arr = np.asarray(img)
    # Single-channel images convert to 2-D arrays with no channel axis, so
    # check the rank before reading the channel count.
    if img_arr.ndim != 3 or img_arr.shape[2] != 4:
        return img

    # Add bg to get rid of alpha channel
    if mode == "random":
        height, width = img_arr.shape[:2]
        make_bg = random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])
        bg = Image.fromarray(make_bg(height, width))
        # The image itself is passed as the paste mask.
        bg.paste(img, mask=img)
        return bg

    if mode in ("black", "white"):
        rgba = img_arr.astype(float)
        # Keep alpha as an (H, W, 1) slice so it broadcasts over the channels.
        rgb, alpha = rgba[:, :, :3], rgba[:, :, -1:] / 255
        background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
        blended = rgb * alpha + background * (1 - alpha)
        return Image.fromarray(np.round(blended).astype(np.uint8))

    return img
|
||
|
|
||
|
|
||
|
def _black_bg(h: int, w: int) -> np.ndarray:
    """
    :return: an all-zero (black) uint8 array of shape (h, w, 3)
    """
    shape = (h, w, 3)
    return np.zeros(shape, dtype=np.uint8)
|
||
|
|
||
|
|
||
|
def _gray_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 array of shape (h, w, 3) filled with one gray level
             drawn at random from 0..255
    """
    level = np.random.randint(low=0, high=256)
    return np.full((h, w, 3), level, dtype=np.uint8)
|
||
|
|
||
|
|
||
|
def _checker_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 array of shape (h, w, 3) holding a two-tone gray
             checkerboard with a random square size, random offset and
             random gray levels
    """
    # Square size: exp(u * log(shorter side)) rounded up, with u in [0, 1).
    shorter = min(h, w)
    checker_size = np.ceil(np.exp(np.random.uniform() * np.log(shorter)))
    tone_a = np.random.randint(low=0, high=256)
    tone_b = np.random.randint(low=0, high=256)

    # Random shifts so the pattern does not always start on a square edge.
    x_shift = np.random.randint(low=0, high=checker_size + 1)
    y_shift = np.random.randint(low=0, high=checker_size + 1)
    cols = np.arange(w)[None, :, None] + x_shift
    rows = np.arange(h)[:, None, None] + y_shift

    col_even = (cols // checker_size) % 2 == 0
    row_even = (rows // checker_size) % 2 == 0
    fields = np.logical_xor(col_even, row_even)

    rgb_a = np.array([tone_a] * 3)
    rgb_b = np.array([tone_b] * 3)
    return np.where(fields, rgb_a, rgb_b).astype(np.uint8)
|
||
|
|
||
|
|
||
|
def _noise_bg(h: int, w: int) -> np.ndarray:
    """
    :return: a uint8 array of shape (h, w, 3) where every value is drawn at
             random from 0..255
    """
    samples = np.random.randint(low=0, high=256, size=(h, w, 3))
    return samples.astype(np.uint8)
|
||
|
|
||
|
|
||
|
def load_image(image_path: str) -> Image.Image:
    """
    Open an image through blobfile and load it.

    :param: image_path: path passed to bf.BlobFile
    :return: the loaded PIL image
    """
    with bf.BlobFile(image_path, "rb") as handle:
        image = Image.open(handle)
        # Call load() while the file is still open.
        image.load()
    return image
|
||
|
|
||
|
|
||
|
def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:
    """
    Arrange images on a grid with `columns` tiles per row. The tile size is
    taken from the first image's height, and the last row is padded with
    black tiles.

    to test, run
    >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))
    """
    tiles = [np.array(image) for image in images]
    size = tiles[0].shape[0]
    total = round_up(len(tiles), columns)
    rows = total // columns

    # Fill the incomplete last row with black tiles.
    blank = np.zeros((size, size, 3), dtype=np.uint8)
    tiles.extend([blank] * (total - len(tiles)))

    grid = np.array(tiles).reshape(rows, columns, size, size, 3)
    # Reorder to (row, y, column, x, channel) so the final reshape lays the
    # tiles side by side.
    grid = grid.transpose([0, 2, 1, 3, 4])
    grid = grid.reshape(rows * size, columns * size, 3)
    return Image.fromarray(grid)
|
||
|
|
||
|
|
||
|
def round_up(n: int, b: int) -> int:
    """
    :return: n rounded up to a multiple of b (for positive b, the smallest
             multiple of b that is >= n)
    """
    blocks = (n + b - 1) // b
    return blocks * b
|