"""Inference wrapper that loads a pretrained image-to-image model from disk."""
import os

from PIL import Image

from models import create_model
from util.get_transform import get_transform
from util.util import tensor2im

# Checkpoints are looked up in a "checkpoints" folder located next to this file.
ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')


class Options(object):
    """Plain attribute bag populated from dictionaries and keyword arguments.

    Every key of each positional mapping, then every keyword argument, is
    copied onto the instance with ``setattr``; later values override earlier
    ones that share the same name.
    """

    def __init__(self, *initial_data, **kwargs):
        for dictionary in initial_data:
            for key in dictionary:
                setattr(self, key, dictionary[key])
        for key, value in kwargs.items():
            setattr(self, key, value)


class ModelLoader:
    """Build the option set, lazily create the model and run single-image inference."""

    def __init__(self, gpu_ids='', max_img_wh=512) -> None:
        """Prepare the options and the input transform; the model is not created yet.

        Args:
            gpu_ids: Comma-separated device ids; an empty string yields an
                empty list.
            max_img_wh: Stored as the ``yarflam_img_wh`` option (maximum
                width/height, with automatic scale-down according to the
                original inline note).
        """
        self.opt = Options({
            'isGradio': True,  # Custom
            'name': 'original',  # Checkpoints name
            'checkpoints_dir': ckp_path,  # Checkpoint folder
            # NOTE(review): ids remain strings after split(','); confirm that
            # create_model accepts string ids.
            'gpu_ids': gpu_ids.split(',') if gpu_ids else [],
            'init_gain': 0.02,  # Scaling factor
            'init_type': 'xavier',  # one of: 'normal', 'xavier', 'kaiming', 'orthogonal'
            'input_nc': 3,  # 3 -> RGB, 1 -> Grayscale
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,  # Number of discriminator filters in the first conv layer
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64,  # Number of generator filters in the last conv layer
            'no_antialias_up': False,
            'no_antialias': False,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'yarflam_auto',  # see more: util.get_transform
            'dataroot': 'placeholder',
            'num_threads': 1,  # test code only supports num_threads = 1
            'batch_size': 1,  # test code only supports batch_size = 1
            # NOTE(review): the inherited comment said "disable data shuffling",
            # yet the value is False -- confirm the intended setting.
            'serial_batches': False,
            'no_flip': True,  # no flip; change if results on flipped images are needed
            'display_id': -1,  # no visdom display
            'direction': 'AtoB',  # inference
            'flip_equivariance': False,
            'load_size': 1680,  # not used
            'crop_size': 512,  # not used
            'yarflam_img_wh': max_img_wh,  # max width|height + auto scale down
        })
        self.transform = get_transform(self.opt, grayscale=False)
        self.model = None

    def load(self) -> None:
        """Create the model from the stored options and load the 'latest' networks."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src='', image_pil=None) -> Image.Image:
        """Run the model on one image and return the result as a PIL image.

        Args:
            src: Path of the image to read; only used when ``image_pil`` is
                not a PIL image. It is also forwarded to the model as the
                ``A_paths`` / ``B_paths`` entries.
            image_pil: Optional already-opened PIL image; takes precedence
                over ``src``.

        Returns:
            The second entry of the model's current visuals, converted with
            ``tensor2im`` and wrapped in a PIL image.

        Raises:
            FileNotFoundError: If no PIL image is given and ``src`` is not an
                existing file.
        """
        # Identity check: the model object may define its own __eq__.
        if self.model is None:
            self.load()

        # Loading
        if isinstance(image_pil, Image.Image):
            source = image_pil.convert('RGB')
        else:
            if not os.path.isfile(src):
                raise FileNotFoundError('The image %s is not found!' % src)
            print('Loading the image %s' % src)
            # convert() returns a detached copy, so the file handle can be
            # released as soon as the block exits.
            with Image.open(src) as opened:
                source = opened.convert('RGB')
        img = self.transform(source).unsqueeze(0)
        print(img.shape)

        # Inference: the same tensor is supplied for both the A and B sides.
        self.model.set_input({
            'A': img,
            'A_paths': src,
            'B': img,
            'B_paths': src,
        })
        self.model.forward()
        visuals = list(self.model.get_current_visuals().items())
        _, out_data = visuals[1]  # second visual entry holds the output
        return Image.fromarray(tensor2im(out_data))