# UnderWater / ModelLoader.py
# Yarflam's picture
# Load Image
# a15cce2
# raw
# history blame
# 2.62 kB
from models import create_model
from util.get_transform import get_transform
from PIL import Image
import os
# Checkpoints are looked up in a 'checkpoints' folder located next to this module.
ckp_path = os.path.join(os.path.split(__file__)[0], 'checkpoints')
class Options(object):
    """Attribute bag built from mappings and keyword arguments.

    Every key of each positional mapping becomes an attribute of the
    instance, in the order the mappings are given; keyword arguments are
    applied last, so they override positional values with the same name.
    """

    def __init__(self, *initial_data, **kwargs):
        # Positional mappings first, keyword arguments last: later sources win.
        for source in (*initial_data, kwargs):
            for name in source:
                setattr(self, name, source[name])
class ModelLoader:
    """Wrap model creation and single-image inference behind a small API.

    The constructor prepares a fixed set of options and the matching image
    transform. The model itself is only instantiated on demand, either by an
    explicit call to :meth:`load` or by the first call to :meth:`inference`.
    """

    def __init__(self) -> None:
        """Build the option set and the image transform; defer model creation."""
        self.opt = Options({
            'isGradio': True,  # Custom
            'name': 'original',
            'checkpoints_dir': ckp_path,
            'gpu_ids': [],
            'init_gain': 0.02,
            'init_type': 'xavier',
            'input_nc': 3,
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64,
            'no_antialias_up': None,
            'no_antialias': None,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'scale_width',
            # NOTE(review): the original comment says "test code only supports
            # num_threads = 1" while the value is 0 - confirm which is intended.
            'num_threads': 0,
            'batch_size': 1,  # test code only supports batch_size = 1
            'serial_batches': True,  # disable data shuffling; comment this line if results on randomly chosen images are needed.
            'no_flip': True,  # no flip; comment this line if results on flipped images are needed.
            'display_id': -1,  # no visdom display; the test code saves the results to a HTML file.
            'direction': 'AtoB',  # inference
            'flip_equivariance': False,
            'load_size': 1680,
            'crop_size': 512,
        })
        self.transform = get_transform(self.opt, grayscale=False)
        # Created lazily by load(); None means "not loaded yet".
        self.model = None

    def load(self) -> None:
        """Instantiate the model from ``self.opt`` and load the 'latest' networks."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src: str = '') -> None:
        """Run a forward pass of the model on the image stored at ``src``.

        The model is loaded first if that has not happened yet. The shape of
        the transformed image and the model object are printed; nothing is
        returned.

        Args:
            src: Path of the image file to process.

        Raises:
            FileNotFoundError: If ``src`` is not an existing regular file.
        """
        # Identity check: "is None" cannot be affected by a custom __eq__.
        if self.model is None:
            self.load()
        if not os.path.isfile(src):
            # FileNotFoundError is still caught by "except Exception" callers.
            raise FileNotFoundError('The image %s is not found!' % src)
        # Loading
        print('Loading the image %s' % src)
        # The context manager closes the underlying file as soon as the RGB
        # copy has been produced, instead of leaving it to garbage collection.
        with Image.open(src) as handle:
            source = handle.convert('RGB')
        img = self.transform(source)
        print(img.shape)
        # Inference
        # NOTE(review): the transform output is forwarded as-is for both 'A'
        # and 'B'; confirm whether the model expects a batch dimension.
        self.model.set_input({'A': img, 'B': img, 'A_paths': src})
        self.model.forward()
        print(self.model)