import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed


class BaseModel:
    """Minimal base class for models: no-op hooks plus checkpoint helpers.

    Most methods are placeholders that do nothing (or return an empty
    result) in this class. Several helpers read attributes that this class
    never assigns (``self.input``, ``self.image_paths``, ``self.save_dir``);
    presumably subclasses set them before calling those helpers.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the model's name string."""
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Store the device configuration on the instance.

        Args:
            use_gpu: Flag stored as ``self.use_gpu``.
            gpu_ids: List of GPU ids stored as ``self.gpu_ids``. When
                omitted (``None``), a fresh ``[0]`` list is used.
        """
        # A new list is built per call so instances never share (and cannot
        # accidentally mutate) one default list object.
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        """Placeholder; does nothing in the base class."""
        pass

    def optimize_parameters(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_current_visuals(self):
        """Return ``self.input``.

        ``self.input`` is not assigned anywhere in this class, so it must be
        set beforehand or this raises ``AttributeError``.
        """
        return self.input

    def get_current_errors(self):
        """Return an empty dict in the base class."""
        return {}

    def save(self, label):
        """Placeholder; does nothing in the base class."""
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Write ``network.state_dict()`` to ``<path>/<epoch>_net_<label>.pth``.

        Args:
            network: Object exposing ``state_dict()``.
            path: Directory the checkpoint file is written into.
            network_label: Label used in the file name.
            epoch_label: Epoch tag used as the file name prefix.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<self.save_dir>/<epoch>_net_<label>.pth`` into ``network``.

        Reads ``self.save_dir``, which this class never assigns; it must be
        set before calling.

        Args:
            network: Object exposing ``load_state_dict()``.
            network_label: Label used in the file name.
            epoch_label: Epoch tag used as the file name prefix.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s' % save_path)
        # NOTE(review): torch.load deserializes the file (pickle-based);
        # only load checkpoints from trusted sources.
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_image_paths(self):
        """Return ``self.image_paths``.

        ``self.image_paths`` is not assigned anywhere in this class, so it
        must be set beforehand or this raises ``AttributeError``.
        """
        return self.image_paths

    def save_done(self, flag=False):
        """Write ``flag`` into ``self.save_dir`` in two formats.

        ``np.save`` appends the ``.npy`` extension, so the binary copy lands
        in ``done_flag.npy`` while ``np.savetxt`` writes the integer text
        form to ``done_flag``. Reads ``self.save_dir``, which this class
        never assigns.

        Args:
            flag: Value to record; formatted with ``%i`` in the text file.
        """
        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')