File size: 4,871 Bytes
fa9e198 829ca2d e4a482e 829ca2d fa9e198 829ca2d fa9e198 829ca2d fa9e198 a0b203d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import gradio as gr
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import cv2
import numpy as np
from facenet_pytorch import MTCNN
from torchvision import transforms
import cv2
import torch
# Torch device handed to the MTCNN face detector below.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pick the ONNX Runtime execution providers.
# get_device() normally reports an upper-case string such as "CPU" or "GPU",
# so the comparison is case-insensitive, and anything that is not a GPU build
# falls back to the CPU provider. This guarantees `providers` is always bound;
# previously it stayed undefined whenever the string was not exactly 'cpu' or
# 'GPU', which made convert() fail with a NameError.
device_name = ort.get_device()
if device_name.upper() == 'GPU':
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
    providers = ['CPUExecutionProvider']

# load model
mtcnn = MTCNN(
    image_size=256,
    margin=0,
    min_face_size=128,
    thresholds=[0.7, 0.8, 0.9],
    device=device,
)
# MTCNN for face detection with landmarks
def detect(img):
    """Run the module-level MTCNN detector on ``img``.

    Returns:
        The ``(boxes, points)`` pair reported by ``mtcnn.detect`` with
        ``landmarks=True``; the per-face probabilities are discarded.
    """
    boxes, _probs, points = mtcnn.detect(img, landmarks=True)
    return boxes, points
# Expand the area around the detected face by a margin given as a ratio of its width
def margin_face(box, img_HW, margin=0.5):
    """Widen a face box by ``margin`` and reshape it to a square crop.

    Args:
        box: Four coordinates ``(x1, y1, x2, y2)`` of the detected face.
        img_HW: Image size as ``(height, width)``; used to clamp the result.
        margin: Horizontal padding per side, as a fraction of the box width.

    Returns:
        ``[x1, y1, x2, y2]`` of the adjusted box, truncated to ints.
    """
    left, top, right, bottom = box
    face_w = right - left
    face_h = bottom - top

    # Room actually available on each side once the padded box is clamped to
    # the image; the smaller amount is applied to both sides so the crop stays
    # horizontally centred on the face.
    room_left = left - max(0, left - margin * face_w)
    room_right = min(img_HW[1], right + margin * face_w) - right
    pad_x = min(room_left, room_right)
    out_left = left - pad_x
    out_right = right + pad_x

    # The target height equals the padded width (aspect ratio 1.0).
    target_h = 1.0 * (face_w + 2 * pad_x)

    # Grow the box vertically when the target is taller than the face,
    # otherwise shrink it; either way by half the difference per side.
    direction = 1 if target_h >= face_h else -1
    half_gap = abs(target_h - face_h) // 2
    out_top = max(0, top - direction * half_gap)
    out_bottom = min(img_HW[0], bottom + direction * half_gap)

    return [int(v) for v in (out_left, out_top, out_right, out_bottom)]
def process_image(img, x32=True):
    """Prepare an image array for the generator network.

    Args:
        img: Image array whose first two axes are height and width; it is
            passed through a BGR->RGB conversion, so BGR channel order is
            presumably expected.
        x32: When true, resize so each side is a multiple of 32 (sides
            shorter than 256 are set to 256).

    Returns:
        The float32 RGB image computed as ``pixels / 127.5 - 1.0``.
    """
    height, width = img.shape[:2]
    if x32:
        # Snap each side down to a multiple of 32, with 256 as the floor.
        target_w = 256 if width < 256 else width - width % 32
        target_h = 256 if height < 256 else height - height % 32
        img = cv2.resize(img, (target_w, target_h))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return rgb.astype(np.float32) / 127.5 - 1.0
def load_image(image_path, focus_face):
    """Read an image from disk and turn it into a network input batch.

    Args:
        image_path: Path of the image file to read.
        focus_face: ``"Yes"`` to crop to the first detected face (with a
            margin) before preprocessing; any other value keeps the full image.

    Returns:
        ``(batch, (height, width))`` where ``batch`` is the preprocessed image
        with a leading batch axis and the size is that of the (possibly
        cropped) image before preprocessing. Returns ``None`` when a face crop
        was requested but no face was detected.

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    raw = cv2.imread(image_path)
    if raw is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising, so fail here with a readable message.
        raise ValueError(f"Could not read image: {image_path}")
    img0 = raw.astype(np.float32)

    if focus_face == "Yes":
        batch_boxes, _ = detect(img0)
        if batch_boxes is None:
            print("No face detected !")
            return None
        # Only the first detected face is used.
        x1, y1, x2, y2 = margin_face(batch_boxes[0], img0.shape[:2])
        img0 = img0[y1:y2, x1:x2]

    img = process_image(img0)
    img = np.expand_dims(img, axis=0)
    return img, img0.shape[:2]
# ONNX Runtime sessions keyed by model name. Building a session loads the
# model file, so each one is created on first use and reused afterwards
# instead of being rebuilt for every request.
_SESSION_CACHE = {}


def _get_session(model):
    """Return the inference session for ``model``, creating it on first use."""
    session = _SESSION_CACHE.get(model)
    if session is None:
        session = ort.InferenceSession(MODEL_PATH[model], providers=providers)
        _SESSION_CACHE[model] = session
    return session


def convert(img, model, scale):
    """Stylise a preprocessed image batch with the selected model.

    Args:
        img: Network input batch, as produced by ``load_image``.
        model: Key into ``MODEL_PATH`` selecting the ONNX model.
        scale: ``(height, width)`` the result is resized to.

    Returns:
        The stylised image as a uint8 array after an RGB->BGR conversion
        (i.e. ready for ``cv2.imwrite``).

    Raises:
        KeyError: If ``model`` is not a key of ``MODEL_PATH``.
    """
    session = _get_session(model)
    input_name = session.get_inputs()[0].name
    fake_img = session.run(None, {input_name: img})[0]
    # Map the network output from [-1, 1] to [0, 255]; out-of-range values are clipped.
    images = (np.squeeze(fake_img) + 1.) / 2 * 255
    images = np.clip(images, 0, 255).astype(np.uint8)
    # cv2.resize takes the target size as (width, height).
    output_image = cv2.resize(images, (scale[1], scale[0]))
    return cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
# Directory that receives the stylised images written by the app.
os.makedirs('output', exist_ok=True)

# Supported style names. Each model is fetched as "<name>.onnx" from the
# Hugging Face Hub repository "vumichien/<name>".
_MODEL_NAMES = (
    "AnimeGANv2_Hayao",
    "AnimeGANv2_Shinkai",
    "AnimeGANv2_Paprika",
    "AnimeGANv3_PortraitSketch",
    "AnimeGANv3_JP_face",
)

# Style name -> local path returned by hf_hub_download for that model file.
MODEL_PATH = {
    name: hf_hub_download(f'vumichien/{name}', f'{name}.onnx')
    for name in _MODEL_NAMES
}
def inference(img_path, model, focus_face=None):
    """Gradio handler: stylise the uploaded image and save the result.

    Args:
        img_path: Path of the uploaded input image.
        model: Key into ``MODEL_PATH`` selecting the style.
        focus_face: ``"Yes"`` to crop to the detected face first.

    Returns:
        ``(image, save_path)``: the stylised image in RGB channel order for
        display, and the path of the file written under ``output/``.

    Raises:
        ValueError: If a face crop was requested but no face was detected.
    """
    print(img_path, model, focus_face)

    loaded = load_image(img_path, focus_face)
    if loaded is None:
        # load_image returns None when no face is found; unpacking it would
        # otherwise fail with an unrelated-looking TypeError.
        raise ValueError("No face detected !")
    mat, scale = loaded

    output = convert(mat, model, scale)

    # Reuse the input's extension; splitext only looks at the file name, so a
    # dot in a directory name cannot leak path separators into save_path.
    ext = os.path.splitext(img_path)[1] or ".png"
    save_path = f"output/out{ext}"
    cv2.imwrite(save_path, output)

    # convert() yields BGR (what cv2.imwrite expects), but the Gradio numpy
    # image output is interpreted as RGB, so swap channels for display.
    return cv2.cvtColor(output, cv2.COLOR_BGR2RGB), save_path
### Layout ###
# Texts shown on the demo page: heading, intro (markdown) and footer (HTML).
title = "AnimeGANv2: To produce your own animation 😶🌫️"
description = r"""### 🔥Demo AnimeGANv2: To produce your own animation. To use it, simply upload your image.<br>
"""
article = r"""
<center><img src='https://visitor-badge.glitch.me/badge?page_id=AnimeGAN_demo&left_color=green&right_color=blue' alt='visitor badge'></center>
<center><a href='https://github.com/TachibanaYoshino/AnimeGANv3' target='_blank'>Github Repo</a></center>
"""

# Build the UI around inference() and start the server. The three inputs map,
# in order, to inference's img_path, model and focus_face parameters; the two
# outputs map to its (image, save_path) return value.
gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input image"),
        gr.Dropdown(
            [
                'AnimeGANv2_Hayao',
                'AnimeGANv2_Shinkai',
                'AnimeGANv2_Paprika',
                'AnimeGANv3_PortraitSketch',
                'AnimeGANv3_JP_face',
            ],
            type="value",
            value='AnimeGANv3_PortraitSketch',
            label='AnimeGAN Style',
        ),
        gr.Radio(['Yes', 'No'], type="value", value='No', label='Extract face'),
    ],
    [
        gr.Image(type="numpy", label="Output (The whole image)"),
        gr.File(label="Download the output image"),
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
).launch(enable_queue=True)