Spaces:
Runtime error
Runtime error
File size: 3,971 Bytes
0b756df 7f475d2 0b756df 7f475d2 0b756df 7f475d2 0b756df 7f475d2 0b756df 7f475d2 0b756df 7f475d2 0b756df 71c9afb 7f475d2 0b756df 7f475d2 0b756df 7f475d2 0b756df 7f475d2 71c9afb 7f475d2 71c9afb 7f475d2 0b756df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import cv2
import numpy as np
from .model import BiSeNet
# Region name -> integer label id of the face-parsing output. Each id is the
# name's position in the tuple below (0..18, 19 classes, matching n_classes in
# init_parser). NOTE(review): the names/order appear to follow the usual
# 19-class CelebAMask-HQ face-parsing scheme; confirm against the checkpoint.
mask_regions = {
    region: label_id
    for label_id, region in enumerate(
        (
            "Background",
            "Skin",
            "L-Eyebrow",
            "R-Eyebrow",
            "L-Eye",
            "R-Eye",
            "Eye-G",
            "L-Ear",
            "R-Ear",
            "Ear-R",
            "Nose",
            "Mouth",
            "U-Lip",
            "L-Lip",
            "Neck",
            "Neck-L",
            "Cloth",
            "Hair",
            "Hat",
        )
    )
}
# Borrowed from simswap
# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
class SoftErosion(nn.Module):
    """Soften a binary mask by convolving it with a cone kernel and re-thresholding.

    The kernel is ``kernel_size x kernel_size``; each weight is
    ``max_dist - dist`` (distance from the kernel centre), normalised so the
    weights sum to 1.

    Args:
        kernel_size: Side length of the square convolution kernel.
        threshold: Convolved values ``>= threshold`` are set to 1.0.
        iterations: Number of convolution passes. The first
            ``iterations - 1`` passes keep the element-wise minimum of the
            running value and its convolution; the final pass is a plain
            convolution.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super().__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold
        # Build the cone-shaped kernel. It is radially symmetric, so the
        # meshgrid axis ordering does not affect the result.
        y_indices, x_indices = torch.meshgrid(
            torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size)
        )
        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        # (1, 1, k, k): one output channel, one input channel per group.
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer("weight", kernel)

    def forward(self, x):
        """Smooth ``x`` and return ``(soft_mask, hard_mask)``.

        Args:
            x: Mask tensor, assumed (N, 1, H, W). The (1, 1, k, k) weight used
                with ``groups=x.shape[1]`` only supports one channel.

        Returns:
            soft_mask: Float tensor. Values at or above ``threshold`` become
                1.0. The remaining values are divided by their maximum, so for
                non-negative input they lie in [0, 1].
            hard_mask: Boolean tensor marking the values that met ``threshold``.
        """
        x = x.float()
        for _ in range(self.iterations - 1):
            x = torch.min(
                x,
                F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding),
            )
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
        mask = x >= self.threshold
        x[mask] = 1.0
        rest = ~mask
        # Rescale the below-threshold values by their peak. Guard against an
        # empty selection (every value met the threshold: .max() of an empty
        # tensor raises) and a zero peak (0 / 0 would fill the mask with NaN).
        if rest.any():
            peak = x[rest].max()
            if peak != 0:
                x[rest] /= peak
        return x, mask
# Torch device string that image_to_parsing and swap_regions move tensors to.
# Defaults to CPU; init_parser overwrites it (via `global device`) with its
# `mode` argument.
device: str = "cpu"
def init_parser(pth_path, mode="cpu"):
    """Build the 19-class BiSeNet face parser and load its weights.

    Also sets the module-level ``device`` to ``mode``. ``image_to_parsing``
    moves its input tensors to that device, so the network is placed on the
    same device to keep the two consistent.

    Args:
        pth_path: Path to a BiSeNet ``state_dict`` checkpoint.
        mode: Torch device string such as ``"cpu"`` or ``"cuda"``. Any string
            accepted by ``torch.device`` works, e.g. ``"cuda:1"`` or ``"mps"``.

    Returns:
        The BiSeNet model on ``mode``, in eval mode.
    """
    global device
    device = mode
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    # Previously only the literal "cuda" moved the net, so e.g. "cuda:0" left
    # it on CPU while inputs went to the GPU. Move to whatever device was asked.
    net.to(device)
    # map_location loads the tensors straight onto the target device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(pth_path, map_location=device))
    net.eval()
    return net
def image_to_parsing(img, net):
    """Run the face-parsing network on ``img`` and return per-pixel class ids.

    Steps:
    - Resize the image to 512x512.
    - Reverse its channel order (presumably OpenCV BGR -> RGB; confirm with
      callers).
    - Convert with ``ToTensor``.
    - Normalise with the hard-coded mean/std (the usual ImageNet statistics).
    - Run ``net`` on the module-level ``device`` without gradient tracking.

    Args:
        img: HxWxC image array.
        net: Model whose output's first element holds per-class scores.

    Returns:
        2-D numpy array holding the argmax over the class axis. It is
        presumably 512x512 if ``net`` preserves the input size; confirm.
    """
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    # Reversing the last axis yields a negative-stride view; copy it so the
    # tensor conversion gets a contiguous array.
    flipped = cv2.resize(img, (512, 512))[:, :, ::-1].copy()
    batch = preprocess(flipped).unsqueeze(0)
    with torch.no_grad():
        scores = net(batch.to(device))[0]
        labels = scores.squeeze(0).cpu().numpy().argmax(0)
    return labels
def get_mask(parsing, classes):
    """Return a boolean mask of the pixels whose label is one of ``classes``.

    Args:
        parsing: Integer label map (numpy array).
        classes: Iterable of label ids to select (list, tuple, set, ...).
            An empty iterable selects nothing and yields an all-False mask.
            (The previous loop raised IndexError in that case.)

    Returns:
        Boolean numpy array with the same shape as ``parsing``.
    """
    # np.isin does the membership test for all classes in one vectorised pass.
    # list() makes sets and generators behave like sequences; numpy does not
    # treat a set as a collection of values.
    return np.isin(parsing, list(classes))
def swap_regions(
    source,
    target,
    net,
    smooth_mask,
    includes=(1, 2, 3, 4, 5, 10, 11, 12, 13),
    blur=10,
):
    """Blend the selected parsed regions of ``source`` over ``target``.

    ``source`` is parsed with ``net``. Pixels whose label is in ``includes``
    come from ``source``; all other pixels come from ``target``. Both images
    are resized to 512x512, blended with the (optionally smoothed and
    blurred) mask, and the result is resized back to ``source``'s size.

    Args:
        source: HxWx3 image that is parsed and supplies the included regions.
        target: Image supplying everything outside the included regions.
        net: Face-parsing network passed to ``image_to_parsing``.
        smooth_mask: Optional callable (e.g. a ``SoftErosion`` instance)
            returning ``(soft_mask, _)``. ``None`` skips smoothing.
        includes: Label ids to take from ``source`` (see ``mask_regions``).
            If empty, ``source`` is returned as-is and ``net`` is not run.
        blur: Gaussian sigma applied to the mask. A value ``<= 0`` disables
            blurring.

    Returns:
        uint8 image with ``source``'s height and width.
    """
    # Check before running the network: with no regions there is nothing to
    # blend. Return a single image, matching the normal return type. The
    # previous `(source, zeros)` tuple broke callers expecting one array.
    if len(includes) == 0:
        return source

    parsing = image_to_parsing(source, net)
    include_mask = get_mask(parsing, includes)
    mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")

    if smooth_mask is not None:
        mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
        # NOTE(review): the three channels are identical copies of the same
        # mask, so this sum is 2x the binary mask (values 0 or 2). For
        # SoftErosion's linear convolution, that is equivalent to halving its
        # threshold. Kept as-is to preserve output; confirm it is intended.
        face_mask_tensor = mask_tensor[0] + mask_tensor[1]
        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
        soft_face_mask_tensor.squeeze_()
        mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)

    if blur > 0:
        mask = cv2.GaussianBlur(mask, (0, 0), blur)

    resized_source = cv2.resize(source.astype("float32"), (512, 512))
    resized_target = cv2.resize(target.astype("float32"), (512, 512))
    result = mask * resized_source + (1 - mask) * resized_target
    return cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))
def mask_regions_to_list(values):
    """Translate region names into label ids using ``mask_regions``.

    Names that are not keys of ``mask_regions`` are skipped. The order of
    ``values`` is kept, and repeated names produce repeated ids.
    """
    return [mask_regions[name] for name in values if name in mask_regions]
|