import os
import importlib
import imageio
import torch
import numpy as np
import PIL.Image
from PIL import Image
from typing import Any, Tuple
from torchvision import transforms


def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config`` is expected to be a mapping with a ``"target"`` entry holding a
    dotted import path (e.g. ``"package.module.ClassName"``) and an optional
    ``"params"`` mapping that is forwarded to the target as keyword arguments.

    The sentinel strings ``"__is_first_stage__"`` and
    ``"__is_unconditional__"`` may be passed instead of a mapping; both yield
    ``None``.

    Raises:
        KeyError: if ``config`` has no ``"target"`` key and is not one of the
            sentinel strings.
    """
    if "target" not in config:
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    The text after the last ``"."`` is looked up as an attribute of the module
    named by everything before it. If ``reload`` is true, that module is
    re-imported with :func:`importlib.reload` before the lookup.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


@torch.inference_mode()
def remove_background(
    image: PIL.Image.Image,
    rembg: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """Replace the alpha channel of ``image`` with a mask predicted by ``rembg``.

    If ``image`` is already RGBA with at least one non-opaque pixel, it is
    returned unchanged unless ``force`` is true. Otherwise the image is:

    1. converted to RGB;
    2. resized to 1024x1024 and normalized;
    3. run through ``rembg``.

    The sigmoid of the last element of the model output is then resized back
    to the original image size and attached as the alpha channel.

    Args:
        image: Input image.
        rembg: Segmentation model. It must expose a ``device`` attribute and,
            when called on a (1, 3, 1024, 1024) float batch, return an
            indexable whose last element holds mask logits. Presumably a
            BiRefNet-style model; confirm against the caller.
        force: Run the model even when ``image`` already has transparency.
        **rembg_kwargs: Accepted for backward compatibility; currently
            ignored.

    Returns:
        An RGBA image. This is a new object when the model runs, and the
        input object otherwise.

    Raises:
        ValueError: if the mask must be computed but ``rembg`` is ``None``.
    """
    already_transparent = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if already_transparent and not force:
        return image
    if rembg is None:
        raise ValueError("remove_background requires a `rembg` model to compute the mask.")

    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        # Standard ImageNet mean / std normalization.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # convert() returns a copy, so putalpha() below never mutates the caller's image.
    image = image.convert("RGB")
    input_images = transform_image(image).unsqueeze(0).to(rembg.device)
    # The decorator already disables autograd, so no extra no_grad() is needed.
    # [-1] takes the model's last output, presumably its final prediction.
    preds = rembg(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # ToPILImage maps a 2-D float tensor in [0, 1] to an 8-bit "L" image.
    mask = transforms.ToPILImage()(pred).resize(image.size)
    image.putalpha(mask)
    return image


def _center_pad(array: np.ndarray, side: int) -> np.ndarray:
    """Zero-pad an (H, W, C) array to (side, side, C), keeping the content centered.

    Any odd leftover pixel of padding goes to the bottom / right edge.
    """
    height, width = array.shape[:2]
    top, left = (side - height) // 2, (side - width) // 2
    return np.pad(
        array,
        ((top, side - height - top), (left, side - width - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )


def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """Center the visible content of an RGBA image on a square transparent canvas.

    Processing steps:

    1. Crop the bounding box of all pixels with alpha > 0.
    2. Pad the crop with zeros (transparent black) to a square.
    3. Pad again so the foreground's longer side covers about ``ratio`` of the
       output side.

    Args:
        image: RGBA image.
        ratio: Fraction of the output side length covered by the foreground;
            must be in (0, 1].

    Returns:
        A square RGBA image.

    Raises:
        ValueError: if ``image`` is not 4-channel, ``ratio`` is outside
            (0, 1], or no pixel has alpha > 0.
    """
    pixels = np.array(image)
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError("resize_foreground expects an RGBA image.")
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}.")

    rows, cols = np.nonzero(pixels[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no visible (alpha > 0) pixels to crop.")
    # max() is an inclusive index; +1 keeps the last foreground row / column.
    fg = pixels[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    # Pad to a square first, then pad that square out to the ratio-derived size.
    size = max(fg.shape[0], fg.shape[1])
    square = _center_pad(fg, size)
    new_size = int(size / ratio)
    return Image.fromarray(_center_pad(square, new_size))


def rgba_to_white_background(image: PIL.Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
    """Composite an RGBA image over a white background.

    Pixel values are divided by 255, so for 8-bit input both outputs lie in
    [0, 1].

    Args:
        image: RGBA image; exactly 4 channels are required by the split below.

    Returns:
        ``(rgb, alpha)``: float32 tensors of shape (3, H, W) and (1, H, W).
        ``rgb`` is alpha-blended onto white.
    """
    array = np.asarray(image, dtype=np.float32) / 255.0
    # (H, W, C) -> (C, H, W)
    tensor = torch.from_numpy(array).movedim(2, 0)
    rgb, alpha = tensor.split([3, 1], dim=0)
    rgb = rgb * alpha + torch.ones_like(rgb) * (1 - alpha)
    return rgb, alpha


def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """Encode ``frames`` into an H.264 video at ``output_path``.

    Args:
        frames: (N, C, H, W) tensor (or iterable of (C, H, W) tensors) with
            values expected in [0, 1]; out-of-range values are clamped
            before the 8-bit conversion.
        output_path: Destination file path.
        fps: Frames per second of the output video.
    """
    # The context manager closes the writer even if encoding a frame fails.
    with imageio.get_writer(output_path, mode='I', fps=fps, codec='libx264') as writer:
        for frame in frames:
            # (C, H, W) -> (H, W, C); clamp so the uint8 cast cannot wrap around.
            array = frame.detach().permute(1, 2, 0).clamp(0, 1).cpu().numpy()
            writer.append_data((array * 255).astype(np.uint8))