import argparse
import glob
import inspect
import io
import json
import os
import random
from typing import Dict, Optional, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
from torchvision import transforms
from torchvision.utils import save_image
from einops import rearrange
import onnxruntime as rt
from huggingface_hub.file_download import hf_hub_download
from rm_anime_bg.cli import get_mask, SCALE

from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel
from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
from canonicalize.pipeline_canonicalize import CanonicalizationPipeline

check_min_version("0.24.0")

weight_dtype = torch.float16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
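# Note: the autocast context below uses float16 on CPU as well; CPU autocast
# has historically supported only bfloat16, so on CPU-only machines this may
# require a recent PyTorch release (or a fallback to float32).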
class BkgRemover:
    """Anime-style background removal using the skytnt/anime-seg ONNX model."""

    def __init__(self, force_cpu: Optional[bool] = True):
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]
        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def remove_background(
        self,
        img: Image.Image,
        alpha_min: float,
        alpha_max: float,
    ) -> Image.Image:
        img = np.array(img)
        mask = get_mask(self.session_infer, img)
        # Clamp the soft mask: fully transparent below alpha_min,
        # fully opaque above alpha_max.
        mask[mask < alpha_min] = 0.0
        mask[mask > alpha_max] = 1.0
        img_after = (mask * img).astype(np.uint8)
        mask = (mask * SCALE).astype(np.uint8)
        # Append the mask as the alpha channel -> RGBA.
        img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
        return Image.fromarray(img_after)
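# Example usage (illustrative; "input.png" is a hypothetical path):
#   remover = BkgRemover(force_cpu=False)  # uses the CUDA EP if available
#   rgba = remover.remove_background(Image.open("input.png").convert("RGB"), 0.1, 0.9)
#   rgba.save("input_rgba.png")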
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
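# Note: this seeds the Python, NumPy, and Torch RNGs but does not force
# deterministic CUDA kernels; full reproducibility would additionally need
# torch.use_deterministic_algorithms(True).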
def process_image(image, totensor, width, height):
    assert image.mode == "RGBA"

    # Crop to the bounding box of non-transparent pixels. PIL's crop box is
    # exclusive on the right/bottom edge, hence the +1.
    non_transparent = np.nonzero(np.array(image)[..., 3])
    min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
    min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
    image = image.crop((min_x, min_y, max_x + 1, max_y + 1))

    # Paste the crop onto a centered canvas with ~20% margin, matching the
    # target aspect ratio, then resize to the requested resolution.
    max_dim = max(image.width, image.height)
    max_height = int(max_dim * 1.2)
    max_width = int(max_dim / (height / width) * 1.2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))
    image = new_image.resize((width, height), resample=Image.BICUBIC)

    # Composite onto a white background using the alpha channel.
    image = np.array(image).astype(np.float32) / 255.
    assert image.shape[-1] == 4  # RGBA
    alpha = image[..., 3:4]
    bg_color = np.array([1., 1., 1.], dtype=np.float32)
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    return totensor(image)
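# process_image returns a (3, height, width) float tensor in [0, 1]: the
# subject cropped, centered with margin, and composited onto white.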
def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
              text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
              use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
    set_seed(seed)
    totensor = transforms.ToTensor()

    prompts = "high quality, best quality"
    prompt_ids = tokenizer(
        prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
        return_tensors="pt"
    ).input_ids[0]

    # (B*Nv, 3, H, W)
    B = 1
    if input_image.mode != "RGBA":
        # No alpha channel yet: segment the subject from the background first.
        input_image = bkg_remover.remove_background(input_image, 0.1, 0.9)
    imgs_in = process_image(input_image, totensor, val_width, val_height)
    imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

    with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype):
        imgs_in = imgs_in.to(device=device)
        # B*Nv images
        out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
                                  num_inference_steps=timestep, prompt_ids=prompt_ids,
                                  height=val_height, width=val_width, unet_condition_type=unet_condition_type,
                                  use_noise=use_noise, **validation,)

    out = rearrange(out, "B C f H W -> (B f) C H W", f=1)

    img_buf = io.BytesIO()
    save_image(out[0], img_buf, format='PNG')
    img_buf.seek(0)
    img = Image.open(img_buf)

    torch.cuda.empty_cache()
    return img
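# inference() returns the canonicalized view as a PIL image; routing the
# output tensor through an in-memory PNG buffer round-trips it to 8-bit RGB.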
def main(
    input_dir: str,
    output_dir: str,
    pretrained_model_path: str,
    validation: Dict,
    local_crossattn: bool = True,
    unet_from_pretrained_kwargs=None,
    unet_condition_type=None,
    use_noise=True,
    noise_d=256,
    seed: int = 42,
    timestep: int = 40,
    width_input: int = 640,
    height_input: int = 1024,
):
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Load all sub-modules of the canonicalization pipeline.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
    feature_extractor = CLIPImageProcessor()
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)

    text_encoder.to(device, dtype=weight_dtype)
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    ref_unet.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    # Inference only: freeze all weights.
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    ref_unet.requires_grad_(False)

    # Set up the pipeline.
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
    validation_pipeline = CanonicalizationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,
        feature_extractor=feature_extractor, image_encoder=image_encoder, scheduler=noise_scheduler,
    )
    validation_pipeline.set_progress_bar_config(disable=True)

    bkg_remover = BkgRemover()

    def canonicalize(image, width, height, seed, timestep):
        generator = torch.Generator(device=device).manual_seed(seed)
        return inference(
            validation_pipeline, bkg_remover, image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder,
            pretrained_model_path, generator, validation, width, height, unet_condition_type,
            use_noise=use_noise, noise_d=noise_d, crop=True, seed=seed, timestep=timestep
        )

    img_paths = sorted(glob.glob(os.path.join(input_dir, "*.png")))
    os.makedirs(output_dir, exist_ok=True)

    for path in tqdm(img_paths):
        img_input = Image.open(path)
        # Guard on the mode first: RGB PNGs have no alpha channel to index.
        if img_input.mode == "RGBA" and np.array(img_input)[..., 3].min() == 255:
            # Alpha is fully opaque, so it carries no segmentation; convert to
            # RGB so that inference() runs the background remover.
            img_input = img_input.convert("RGB")
        img_output = canonicalize(img_input, width_input, height_input, seed, timestep)
        img_output.save(os.path.join(output_dir, f"{os.path.basename(path).split('.')[0]}.png"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/canonicalization-infer.yaml")
    parser.add_argument("--input_dir", type=str, default="./input_cases")
    parser.add_argument("--output_dir", type=str, default="./result/apose")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    main(**OmegaConf.load(args.config), seed=args.seed, input_dir=args.input_dir, output_dir=args.output_dir)
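# Example invocation (the script filename is illustrative):
#   python run_canonicalization.py --config ./configs/canonicalization-infer.yaml \
#       --input_dir ./input_cases --output_dir ./result/apose --seed 42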