import time

import cv2
import numpy as np
import onnx
import torch
import torchvision.transforms.functional as tf
from torchvision import transforms

from .engine import execute_onnx_model
from .model import PHNet
from .stylematte import StyleMatte
from .util import log  # note: inference_img is defined in this module and would shadow util's

class Inference:
    """Image harmonization wrapper around PHNet."""

    def __init__(self, **kwargs):
        self.rank = 0
        # Configuration (enc_sizes, skips, grid_counts, init_weights,
        # init_value, checkpoint, device, image_size) arrives via kwargs.
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value)
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()
    def harmonize(self, composite, mask):
        """Harmonize a composite against its mask, keeping the background intact."""
        # Ensure both tensors are batched as (B, C, H, W).
        while len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)  # ['harmonized']
        # Blend: harmonized pixels inside the mask, original pixels outside.
        result = harmonized * mask + composite * (1 - mask)
        return result
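
# Hedged usage sketch (added for illustration, not from the original file).
# The kwargs mirror the attributes Inference reads; the concrete values
# (sizes, checkpoint object, resolution) are placeholders.
#
#   harmonizer = Inference(enc_sizes=..., skips=..., grid_counts=...,
#                          init_weights=..., init_value=...,
#                          checkpoint=..., device="cpu", image_size=512)
#   composite = torch.rand(3, 512, 512)   # RGB composite in [0, 1]
#   mask = torch.rand(1, 512, 512)        # soft foreground mask
#   result = harmonizer.harmonize(composite, mask)  # (1, 3, 512, 512)
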
class Matting:
    """Foreground/alpha extraction wrapper around StyleMatte (PyTorch or ONNX)."""

    def __init__(self, **kwargs):
        self.rank = 0
        # Configuration (onnx flag, checkpoint paths, device) arrives via kwargs.
        self.__dict__.update(kwargs)
        if self.onnx:
            self.model = onnx.load(self.checkpoint.matting_onnx)
        else:
            self.model = StyleMatte().to(self.device)
            state = torch.load(self.checkpoint.matting,
                               map_location=self.device)
            self.model.load_state_dict(state, strict=True)
            self.model.eval()
    def extract(self, inp):
        """Predict an alpha mask for `inp` and return [mask, foreground]."""
        mask = inference_img(self.model, inp, self.device, self.onnx)
        inp_np = np.array(inp)
        # Broadcast the single-channel mask across the RGB channels.
        fg = mask[:, :, None] * inp_np
        return [mask, fg]
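
# Hedged usage sketch (added for illustration). `checkpoint` must expose
# `.matting` (PyTorch weights) or `.matting_onnx` (ONNX graph); the file
# names here are placeholders.
#
#   matting = Matting(onnx=False, checkpoint=..., device="cpu")
#   img = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB)
#   mask, fg = matting.extract(img)  # HxW alpha mask, masked RGB foreground
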
def inference_img(model, img, device='cpu', onnx=True):
    """Run the matting model on an HWC image; return an HxW alpha mask."""
    beg = time.time()
    h, w, _ = img.shape
    # Reflect-pad the top/left so both sides are divisible by 8
    # (the network downsamples by a factor of 8).
    if h % 8 != 0 or w % 8 != 0:
        img = cv2.copyMakeBorder(img, (8 - h % 8) % 8, 0,
                                 (8 - w % 8) % 8, 0, cv2.BORDER_REFLECT)
    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img / 255.0
    # Standard ImageNet normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    end_p = time.time()
    if onnx:
        out = execute_onnx_model(input_t, model)
    else:
        with torch.no_grad():
            out = model(input_t).cpu().numpy()
    end = time.time()
    log(f"Inference time: {end - beg}, preprocessing time: {end_p - beg}")
    # Crop the padding back off and drop the batch/channel dimensions.
    result = out[0][:, -h:, -w:]
    return result[0]
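
# Hedged sketch of calling inference_img directly with PyTorch weights
# (the file names are placeholders, not from the original repository):
#
#   net = StyleMatte()
#   net.load_state_dict(torch.load("stylematte.pth", map_location="cpu"))
#   net.eval()
#   img = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB)
#   alpha = inference_img(net, img, device="cpu", onnx=False)  # HxW float mask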