import numpy as np
import torch
import torchvision.transforms.functional as tf

from .model import PHNet
from .stylematte import StyleMatte
from .util import inference_img, log


class Inference:
    """Image harmonization with PHNet."""

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value).to(self.device)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        state = torch.load(self.checkpoint.harmonizer, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        # Ensure 4D (N, C, H, W) inputs: the composite needs at most one
        # batch dimension added, while the mask may arrive as (H, W).
        if len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)

        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        composite = composite.to(self.device)
        mask = mask.to(self.device)
        log(f"composite: {composite.shape}, mask: {mask.shape}")

        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']

        # Blend: keep the harmonized foreground, restore the original background.
        result = harmonized * mask + composite * (1 - mask)
        log(f"result: {result.shape}")

        return result


class Matting:
    """Foreground extraction with StyleMatte."""

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = StyleMatte().to(self.device)
        log(f"checkpoint: {self.checkpoint.matting}")
        state = torch.load(self.checkpoint.matting, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def extract(self, inp):
        # Predict an alpha mask, then multiply it into the input image
        # to obtain the foreground.
        mask = inference_img(self.model, inp, self.device)
        inp_np = np.array(inp)
        fg = mask[:, :, None] * inp_np
        return [mask, fg]
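

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API): how the two stages
# might be chained, with Matting producing an alpha mask and Inference
# harmonizing the composite. Everything below is a hypothetical placeholder:
# the checkpoint paths, the image size, and the PHNet hyperparameters
# (enc_sizes, skips, grid_counts, init_weights, init_value) should come from
# the project's own config. It also assumes inference_img returns an (H, W)
# numpy alpha mask, consistent with how extract() uses it above.
#
# if __name__ == "__main__":
#     from types import SimpleNamespace
#     from PIL import Image
#
#     checkpoints = SimpleNamespace(harmonizer="checkpoints/harmonizer.pth",
#                                   matting="checkpoints/matting.pth")
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#
#     matting = Matting(checkpoint=checkpoints, device=device)
#     harmonizer = Inference(checkpoint=checkpoints, device=device,
#                            image_size=512,
#                            enc_sizes=..., skips=..., grid_counts=...,
#                            init_weights=..., init_value=...)
#
#     image = Image.open("composite.png").convert("RGB")
#     mask, fg = matting.extract(image)                   # alpha + foreground
#
#     composite_t = tf.to_tensor(image)                   # (C, H, W) in [0, 1]
#     mask_t = torch.from_numpy(mask).float()             # (H, W)
#     result = harmonizer.harmonize(composite_t, mask_t)  # (1, C, H, W)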