Spaces:
Running
Running
import torch | |
from .model import PHNet | |
import torchvision.transforms.functional as tf | |
from .util import inference_img, log | |
from .stylematte import StyleMatte | |
import numpy as np | |
class Inference:
    """Build a PHNet harmonizer, load its checkpoint and run harmonization.

    Configuration is passed as keyword arguments and copied verbatim onto the
    instance. The code in this class reads these attributes: ``enc_sizes``,
    ``skips``, ``grid_counts``, ``init_weights``, ``init_value``, ``device``,
    ``image_size`` and ``checkpoint`` (an object exposing ``.harmonizer``).
    """

    def __init__(self, **kwargs):
        """Create the network and restore its weights.

        Args:
            **kwargs: Configuration values stored as instance attributes.
                ``rank`` defaults to 0 but is overridden if present here,
                because the defaults are assigned before the update.
        """
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        # NOTE(review): the network itself is never moved to self.device here
        # (only the loaded tensors are mapped there) -- confirm this is
        # intended for non-CPU devices.
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        # strict=True: fail loudly on any missing or unexpected key.
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        """Harmonize the masked region of ``composite``.

        Args:
            composite: Tensor with the composite image. If it has fewer than
                four dimensions a single leading axis is added (presumably
                turning (C, H, W) into a batch of one -- confirm with caller).
            mask: Tensor with the region to harmonize. Leading axes are added
                until it has four dimensions.

        Returns:
            ``harmonized * mask + composite * (1 - mask)`` computed on the
            inputs after both were resized to
            ``image_size`` x ``image_size``.
        """
        if composite.dim() < 4:
            composite = composite.unsqueeze(0)
        while mask.dim() < 4:
            mask = mask.unsqueeze(0)

        target_size = [self.image_size, self.image_size]
        composite = tf.resize(composite, target_size)
        mask = tf.resize(mask, target_size)
        log(composite.shape, mask.shape)

        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']

        # Keep the network output inside the mask, the input outside it.
        result = harmonized * mask + composite * (1 - mask)
        # Report through the shared log helper instead of a stray print.
        log(result.shape)
        return result
class Matting:
    """Load a StyleMatte network and use it to cut a foreground out of an image.

    Keyword arguments given to the constructor become instance attributes.
    This class reads ``device`` and ``checkpoint`` (an object exposing
    ``.matting``) from them.
    """

    def __init__(self, **kwargs):
        """Instantiate StyleMatte on ``device`` and restore its checkpoint.

        Args:
            **kwargs: Configuration values stored as instance attributes.
                ``rank`` is preset to 0 and replaced if supplied here.
        """
        self.rank = 0
        vars(self).update(kwargs)

        self.model = StyleMatte().to(self.device)

        checkpoint_path = self.checkpoint.matting
        log(f"checkpoint: {checkpoint_path}")
        weights = torch.load(checkpoint_path, map_location=self.device)

        # Strict loading: every key in the checkpoint must match the model.
        self.model.load_state_dict(weights, strict=True)
        self.model.eval()

    def extract(self, inp):
        """Predict a matte for ``inp`` and apply it to the image.

        Args:
            inp: Input image; it is passed to ``inference_img`` and converted
                with ``np.array`` (presumably an H x W x C image -- confirm
                with caller).

        Returns:
            A two-element list ``[mask, fg]`` where ``mask`` is the value
            returned by ``inference_img`` and ``fg`` is the image multiplied
            by the mask with a trailing axis inserted at position 2.
        """
        alpha = inference_img(self.model, inp, self.device)
        pixels = np.array(inp)
        # Insert an axis at position 2 so the mask broadcasts over the image.
        foreground = alpha[:, :, np.newaxis] * pixels
        return [alpha, foreground]