# Spaces: Paused (page status text captured along with this source; not code)
from models import create_model | |
from util.get_transform import get_transform | |
from PIL import Image | |
import os | |
# Absolute path of the 'checkpoints' directory that sits next to this file;
# handed to the model options as ``checkpoints_dir``. ``abspath`` keeps the
# path valid even if the working directory changes after import, since
# ``__file__`` can be relative on some Python versions.
ckp_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'checkpoints')
class Options:
    """Simple attribute bag built from mappings and keyword arguments.

    Every positional argument is a mapping whose keys become attributes of
    the instance. Mappings are applied left to right, and keyword arguments
    are applied last. A later source therefore overrides an earlier one that
    sets the same name.
    """

    def __init__(self, *initial_data, **kwargs):
        # Treat the keyword arguments as one final mapping so that a single
        # pass applies everything in the same order as before.
        for mapping in (*initial_data, kwargs):
            for name in mapping:
                setattr(self, name, mapping[name])
class ModelLoader:
    """Build the inference options eagerly and the model lazily, then run single images.

    The option set is passed to the project's ``create_model`` and
    ``get_transform`` helpers. Which keys those helpers actually read is not
    visible from this file.
    """

    def __init__(self, **overrides) -> None:
        """Create the option object and the input transform.

        Args:
            **overrides: Optional option values that replace the defaults
                below (for example ``name='other'`` or ``gpu_ids=[0]``).
                ``Options`` applies keyword arguments after the dict, so these
                win. With no arguments the behaviour is unchanged.
        """
        self.opt = Options({
            'isGradio': True,  # Custom
            'name': 'original',
            'checkpoints_dir': ckp_path,
            'gpu_ids': [],
            'init_gain': 0.02,
            'init_type': 'xavier',
            'input_nc': 3,
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64,
            'no_antialias_up': None,
            'no_antialias': None,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'scale_width',
            # Test-time values; this class feeds one image at a time and does
            # not use a data loader itself.
            'num_threads': 0,
            'batch_size': 1,
            'serial_batches': True,  # disable data shuffling
            'no_flip': True,  # no flip
            'display_id': -1,  # no visdom display
            'direction': 'AtoB',  # inference
            'flip_equivariance': False,
            'load_size': 1680,
            'crop_size': 512,
        }, **overrides)
        self.transform = get_transform(self.opt, grayscale=False)
        # Created on first use; see load() / inference().
        self.model = None

    def load(self, epoch: str = 'latest') -> None:
        """Create the model from ``self.opt`` and load its saved networks.

        Args:
            epoch: Checkpoint label passed to ``load_networks``. It defaults to
                ``'latest'``, which is what was always used before.
        """
        self.model = create_model(self.opt)
        self.model.load_networks(epoch)

    def inference(self, src: str = ''):
        """Run the model on the image file at ``src``.

        The path is validated before the model is loaded, so a bad path does
        not trigger a model load. The model is then loaded on first use.

        Args:
            src: Path to an image file that PIL can open.

        Raises:
            FileNotFoundError: If ``src`` is not an existing file. This is a
                subclass of ``Exception``, so existing handlers still catch it.
        """
        if not os.path.isfile(src):
            raise FileNotFoundError('The image %s is not found!' % src)
        if self.model is None:
            self.load()
        # Loading
        print('Loading the image %s' % src)
        # The context manager closes the file handle opened by Image.open.
        # convert() returns an independent copy, so it stays usable afterwards.
        with Image.open(src) as image:
            source = image.convert('RGB')
        img = self.transform(source)
        print(img.shape)
        # Inference
        # NOTE(review): the same tensor is given as both 'A' and 'B' without a
        # batch dimension. Whether set_input expects e.g. img.unsqueeze(0)
        # depends on model code not visible here, so confirm.
        self.model.set_input({'A': img, 'B': img, 'A_paths': src})
        self.model.forward()
        print(self.model)