# PortraitTransfer/tools/inference.py
import torch
from .model import PHNet
import torchvision.transforms.functional as tf
from .util import inference_img, log
from .stylematte import StyleMatte
import numpy as np


class Inference:
    """Harmonizes a composite image with PHNet."""

    def __init__(self, **kwargs):
        self.rank = 0
        # Pull all settings (enc_sizes, skips, grid_counts, init_weights,
        # init_value, image_size, device, checkpoint, ...) straight from
        # the passed-in config.
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value).to(self.device)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        # Promote inputs to 4D NCHW batches; composites may arrive as CHW,
        # masks as HxW or 1xHxW.
        while len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        # Keep the inputs on the same device as the model.
        composite = composite.to(self.device)
        mask = mask.to(self.device)
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']
        # Blend: harmonized foreground under the mask, untouched background
        # everywhere else.
        result = harmonized * mask + composite * (1 - mask)
        log(result.shape)
        return result
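
# A minimal usage sketch for Inference. The keyword arguments mirror the
# attributes read in __init__ above; the concrete values (image_size=512,
# device="cpu") are illustrative assumptions, not defaults shipped with
# this module:
#
#     harmonizer = Inference(enc_sizes=..., skips=..., grid_counts=...,
#                            init_weights=..., init_value=...,
#                            image_size=512, device="cpu",
#                            checkpoint=cfg.checkpoint)
#     result = harmonizer.harmonize(composite, mask)  # CHW image, HxW mask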


class Matting:
    """Extracts a foreground alpha matte with StyleMatte."""

    def __init__(self, **kwargs):
        self.rank = 0
        # Pull settings (checkpoint, device, ...) straight from the
        # passed-in config.
        self.__dict__.update(kwargs)
        self.model = StyleMatte().to(self.device)
        log(f"checkpoint: {self.checkpoint.matting}")
        state = torch.load(self.checkpoint.matting, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def extract(self, inp):
        # Predict an alpha matte for the input image, then multiply it in
        # channel-wise to cut out the foreground.
        mask = inference_img(self.model, inp, self.device)
        inp_np = np.array(inp)
        fg = mask[:, :, None] * inp_np
        return [mask, fg]
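

# A minimal end-to-end sketch combining both classes. It assumes the config
# is an OmegaConf-style object carrying the fields read above
# (checkpoint.matting, checkpoint.harmonizer, the PHNet sizes, image_size,
# device), and that inference_img returns an HxW numpy alpha matte. The
# config path and image path are hypothetical.
if __name__ == "__main__":
    from PIL import Image
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")        # hypothetical config file
    matting = Matting(**cfg)
    harmonizer = Inference(**cfg)

    image = Image.open("composite.png").convert("RGB")  # hypothetical input
    alpha, _fg = matting.extract(image)

    composite = tf.to_tensor(image)            # 3xHxW float tensor in [0, 1]
    mask = torch.from_numpy(alpha).float()     # assumes an HxW numpy matte
    result = harmonizer.harmonize(composite, mask)      # 1x3xSxS tensor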