"""Inference wrappers: image harmonization (PHNet) and matting (StyleMatte)."""

import time

import cv2
import numpy as np
import onnx
import torch
import torchvision.transforms.functional as tf
from torchvision import transforms

from .engine import execute_onnx_model
from .model import PHNet
from .stylematte import StyleMatte
# NOTE(review): the `inference_img` imported here is shadowed by the
# module-level definition below, so the local one is always used.
# Kept for backward compatibility — confirm which implementation is intended.
from .util import inference_img, log


class Inference:
    """Image-harmonization wrapper around a pretrained PHNet checkpoint."""

    def __init__(self, **kwargs):
        # Configuration (enc_sizes, skips, grid_counts, init_weights,
        # init_value, checkpoint, device, image_size, ...) is injected
        # straight into the instance via kwargs.
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = PHNet(
            enc_sizes=self.enc_sizes,
            skips=self.skips,
            grid_count=self.grid_counts,
            init_weights=self.init_weights,
            init_value=self.init_value,
        )
        state = torch.load(self.checkpoint.harmonizer, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        """Harmonize `composite` into its background within `mask`.

        Both tensors are promoted to 4-D (batched) and resized to
        `self.image_size`; pixels outside the mask are kept from the
        original composite.
        """
        if len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)  # ['harmonized']
        # Blend: model output inside the mask, untouched composite outside.
        result = harmonized * mask + composite * (1 - mask)
        return result


class Matting:
    """Foreground extraction via StyleMatte (PyTorch or ONNX backend)."""

    def __init__(self, **kwargs):
        # Configuration (onnx flag, checkpoint, device, ...) is injected
        # via kwargs, mirroring `Inference`.
        self.rank = 0
        self.__dict__.update(kwargs)
        if self.onnx:
            # ONNX path: the serialized graph is executed by
            # `execute_onnx_model`, so no weights are loaded here.
            self.model = onnx.load(self.checkpoint.matting_onnx)
        else:
            self.model = StyleMatte().to(self.device)
            state = torch.load(self.checkpoint.matting, map_location=self.device)
            self.model.load_state_dict(state, strict=True)
            self.model.eval()

    def extract(self, inp):
        """Return `[alpha_mask, foreground]` for an HWC image `inp`."""
        mask = inference_img(self.model, inp, self.device, self.onnx)
        inp_np = np.array(inp)
        # Broadcast the single-channel mask over the color channels.
        fg = mask[:, :, None] * inp_np
        return [mask, fg]


def inference_img(model, img, device='cpu', onnx=True):
    """Run the matting model on one HWC image and return its alpha mask.

    The image is reflect-padded (top/left) so both spatial dimensions are
    multiples of 8, normalized with ImageNet statistics, run through the
    model (ONNX or PyTorch), and the output is cropped back to the
    original size.

    Args:
        model: ONNX ModelProto or a PyTorch module, matching `onnx`.
        img: HWC numpy image; assumed uint8 in [0, 255] — values are
            divided by 255 before normalization.
        device: torch device for the PyTorch path.
        onnx: when True, dispatch through `execute_onnx_model`.

    Returns:
        2-D numpy array of shape (h, w) — the predicted mask.
    """
    beg = time.time()
    h, w, _ = img.shape
    if h % 8 != 0 or w % 8 != 0:
        # Pad only the missing remainder. The original `8 - dim % 8`
        # added a full extra 8 pixels when that dimension was already
        # divisible by 8 (e.g. h % 8 == 0 while w % 8 != 0).
        pad_h = (8 - h % 8) % 8
        pad_w = (8 - w % 8) % 8
        img = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, cv2.BORDER_REFLECT)
    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    end_p = time.time()
    if onnx:
        out = execute_onnx_model(input_t, model)
    else:
        with torch.no_grad():
            out = model(input_t).cpu().numpy()
    end = time.time()
    log(f"Inference time: {end-beg}, processing time: {end_p-beg}")
    # Crop the bottom-right h x w region — padding was added at top/left.
    result = out[0][:, -h:, -w:]
    return result[0]