import copy
import json
import os
import os.path as osp

import numpy as np
import torch
import torch.distributed as dist

from ..log_service import print_log


def singleton(class_):
    """Class decorator that makes ``class_`` a lazily-created singleton.

    The decorated name becomes a factory function: the first call builds the
    instance with the given arguments, every later call returns that same
    instance and ignores its arguments.
    """
    instances = {}

    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    return getinstance


@singleton
class get_evaluator(object):
    """Registry and factory for evaluator classes.

    Evaluator classes add themselves through the module-level ``register``
    decorator; calling the (singleton) instance with a config builds the
    requested evaluator(s).
    """

    def __init__(self):
        # Maps a registered name (e.g. 'psnr') to its evaluator class.
        self.evaluator = {}

    def register(self, evaf, name):
        """Store evaluator class ``evaf`` under ``name``."""
        self.evaluator[name] = evaf

    @staticmethod
    def _load_backend(t):
        """Import the sibling module that implements evaluator type ``t``.

        The import is done lazily and only for its side effect (the module
        is expected to register its evaluator class on import). Unknown
        types import nothing; the subsequent registry lookup then fails with
        ``KeyError`` unless the type was registered by other means.
        """
        if t == 'miou':
            from . import eva_miou  # noqa: F401
        elif t == 'psnr':
            from . import eva_psnr  # noqa: F401
        elif t == 'ssim':
            from . import eva_ssim  # noqa: F401
        elif t == 'lpips':
            from . import eva_lpips  # noqa: F401
        elif t == 'fid':
            from . import eva_fid  # noqa: F401

    def __call__(self, pipeline_cfg=None):
        """Build the evaluator(s) described by ``pipeline_cfg``.

        Args:
            pipeline_cfg: ``None``, a single config object exposing ``.type``
                and ``.args``, or a ``list`` of such objects.

        Returns:
            * the 'null' evaluator when ``pipeline_cfg`` is ``None``;
            * a single evaluator instance for a single config;
            * a ``compose`` wrapping all evaluators for a non-empty list;
            * ``None`` for an empty list.

        Raises:
            KeyError: if the requested type has not been registered.
        """
        if pipeline_cfg is None:
            from . import eva_null  # noqa: F401
            return self.evaluator['null']()

        if not isinstance(pipeline_cfg, list):
            t = pipeline_cfg.type
            self._load_backend(t)
            return self.evaluator[t](**pipeline_cfg.args)

        evaluators = []
        for ci in pipeline_cfg:
            t = ci.type
            self._load_backend(t)
            evaluators.append(self.evaluator[t](**ci.args))

        if not evaluators:
            return None
        return compose(evaluators)


def register(name):
    """Class decorator that registers the class in ``get_evaluator`` as ``name``."""
    def wrapper(class_):
        get_evaluator().register(class_, name)
        return class_
    return wrapper


class base_evaluator(object):
    """Base class for evaluators that gather data across distributed ranks."""

    def __init__(self, **args):
        """Record the distributed world size / rank and reset result state.

        Args:
            **args: accepted for interface compatibility with subclasses;
                not read here. ``sample_n`` is set later through
                ``set_sample_n``.

        Raises:
            ValueError: if ``torch.distributed`` is not available.
        """
        if not dist.is_available():
            raise ValueError
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.sample_n = None
        self.final = {}

    def sync(self, data):
        """Exchange ``data`` between all ranks.

        Args:
            data: ``None``, a ``str``, an ``np.ndarray``, or a (possibly
                nested) list/tuple of those.

        Returns:
            ``None`` if ``data`` is ``None``. For a leaf value, a list with
            one entry per rank. For a list/tuple, each element is synced
            separately and the result is transposed so that entry ``r`` is
            the list of elements as held by rank ``r``.
        """
        if data is None:
            return None
        if isinstance(data, tuple):
            data = list(data)
        if isinstance(data, list):
            per_item = [self.sync(datai) for datai in data]
            # per_item is item-major ([item][rank]); flip it to rank-major.
            return [[*i] for i in zip(*per_item)]
        return [self.sync_(data, ranki) for ranki in range(self.world_size)]

    def sync_(self, data, rank):
        """Broadcast one leaf value from ``rank`` and return that rank's value.

        Three broadcasts are issued in order: number of dimensions, shape,
        then the payload. Strings are sent as an int64 array of code points.
        Every rank derives the transfer dtype from its *own* ``data``, so all
        ranks must call this with the same type (and array dtype).

        Args:
            data: ``np.ndarray`` or ``str`` held by the local rank.
            rank: the source rank of this broadcast.

        Returns:
            On the source rank, ``data`` itself; on other ranks, the value
            received from ``rank`` (cast back to the local dtype, or decoded
            back into a ``str``).

        Raises:
            ValueError: for unsupported types or array dtypes.
        """
        t = type(data)
        is_broadcast = rank == self.rank

        if t is np.ndarray:
            dtrans = data
            dt = data.dtype
            # np.bool_ rather than the removed np.bool alias.
            if dt in (int, np.bool_, np.uint8, np.int8,
                      np.int16, np.int32, np.int64):
                dtt = torch.int64
            elif dt in (float, np.float16, np.float32, np.float64):
                dtt = torch.float64
            else:
                # Fail before any collective is issued instead of hitting an
                # unbound transfer dtype half-way through the exchange.
                raise ValueError(
                    'unsupported array dtype: {}'.format(dt))
        elif t is str:
            dtrans = np.array([ord(c) for c in data], dtype=np.int64)
            dt = np.int64
            dtt = torch.int64
        else:
            raise ValueError

        # NOTE(review): `.to(self.rank)` uses the global rank as the device
        # index -- presumably one process per GPU on a single node; confirm
        # before using in multi-node runs.
        if is_broadcast:
            ndim = torch.tensor(len(dtrans.shape)).long()
            ndim = ndim.to(self.rank)
            dist.broadcast(ndim, src=rank)

            shape = torch.tensor(list(dtrans.shape)).long()
            shape = shape.to(self.rank)
            dist.broadcast(shape, src=rank)

            payload = torch.tensor(dtrans, dtype=dtt)
            payload = payload.to(self.rank)
            dist.broadcast(payload, src=rank)
            return data

        # Receiving side: mirror the three broadcasts above, sizing each
        # buffer from the value received in the previous step.
        ndim = torch.tensor(0).long()
        ndim = ndim.to(self.rank)
        dist.broadcast(ndim, src=rank)
        ndim = ndim.item()

        shape = torch.zeros(ndim).long()
        shape = shape.to(self.rank)
        dist.broadcast(shape, src=rank)
        shape = shape.to('cpu').tolist()

        payload = torch.zeros(shape, dtype=dtt)
        payload = payload.to(self.rank)
        dist.broadcast(payload, src=rank)
        received = payload.to('cpu').numpy().astype(dt)

        if t is np.ndarray:
            return received
        # t is str here (other types were rejected above).
        return ''.join([chr(c) for c in received])

    def zipzap_arrange(self, data):
        """Interleave per-rank data back into a single sequence.

        Example:
            input  ``[[0, 2, 4, 6], [1, 3, 5, 7]]``
            output ``[0, 1, 2, 3, 4, 5, 6, 7]``

        Args:
            data: a list of lists, or a list of ``np.ndarray`` whose first
                axis is the sample axis. Entries may differ in length; the
                interleaving is only meaningful when the shorter entries are
                the trailing ones (each at most one sample shorter).

        Returns:
            A list (list input) or an ``np.ndarray`` (array input) holding
            ``sum(len(entry))`` samples in interleaved order.

        Raises:
            NotImplementedError: if ``data[0]`` is neither list nor ndarray.
        """
        if isinstance(data[0], list):
            data_new = []
            maxlen = max([len(i) for i in data])
            totlen = sum([len(i) for i in data])
            cnt = 0
            for idx in range(maxlen):
                for datai in data:
                    data_new += [datai[idx]]
                    cnt += 1
                    # Stop before indexing past the end of shorter lists.
                    if cnt >= totlen:
                        break
            return data_new

        elif isinstance(data[0], np.ndarray):
            maxlen = max([i.shape[0] for i in data])
            totlen = sum([i.shape[0] for i in data])
            datai_shape = data[0].shape[1:]

            # Even the arrays: zero-pad short ones along axis 0 so they can
            # be stacked. The padding lands at the tail after interleaving
            # and is cut off by the `[:totlen]` slice below.
            evened = []
            for datai in data:
                missing = maxlen - datai.shape[0]
                if missing > 0:
                    pad = np.zeros((missing, *datai_shape), dtype=datai.dtype)
                    datai = np.concatenate([datai, pad], axis=0)
                evened.append(datai)

            data = np.stack(evened, axis=1).reshape(-1, *datai_shape)
            data = data[:totlen]
            return data

        else:
            raise NotImplementedError

    def add_batch(self, **args):
        """Accumulate one batch of results; must be provided by subclasses."""
        raise NotImplementedError

    def set_sample_n(self, sample_n):
        """Store the total number of samples."""
        self.sample_n = sample_n

    def compute(self):
        """Compute the metric; must be provided by subclasses."""
        raise NotImplementedError

    # Function needed in training to judge which
    # evaluated number is better
    def isbetter(self, old, new):
        """Return True when ``new`` is greater than ``old``."""
        return new > old

    def one_line_summary(self):
        """Log a short summary line."""
        print_log('Evaluator display')

    def save(self, path):
        """Write ``self.final`` as JSON to ``<path>/result.json``.

        The directory is created if it does not exist.
        """
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(path, exist_ok=True)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        """Drop accumulated data; must be provided by subclasses."""
        raise NotImplementedError


class compose(object):
    """Run several evaluators as one, fanning every call out to each."""

    def __init__(self, pipeline):
        """
        Args:
            pipeline: list of evaluator instances; each is expected to expose
                ``symbol`` and ``final`` plus the evaluator methods used here.
        """
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        """Forward the batch to every evaluator."""
        for pi in self.pipeline:
            pi.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        """Store ``sample_n`` here and on every evaluator."""
        self.sample_n = sample_n
        for pi in self.pipeline:
            pi.set_sample_n(sample_n)

    def compute(self):
        """Run every evaluator.

        Returns:
            dict mapping each evaluator's ``symbol`` to its ``compute()``
            result. ``self.final`` is updated with each evaluator's ``final``
            under the same key.
        """
        rv = {}
        for pi in self.pipeline:
            rv[pi.symbol] = pi.compute()
            self.final[pi.symbol] = pi.final
        return rv

    def isbetter(self, old, new):
        """Return True when a strict majority of evaluators prefer ``new``.

        An empty pipeline yields False.
        """
        # NOTE(review): `old`/`new` are forwarded unchanged to every
        # evaluator, while `compute` returns a dict keyed by symbol --
        # confirm whether per-symbol values should be passed instead.
        votes = sum(1 for pi in self.pipeline if pi.isbetter(old, new))
        # Integer form of `votes / len(pipeline) > 0.5`; no division, so an
        # empty pipeline cannot raise ZeroDivisionError.
        return 2 * votes > len(self.pipeline)

    def one_line_summary(self):
        """Ask every evaluator to log its summary."""
        for pi in self.pipeline:
            pi.one_line_summary()

    def save(self, path):
        """Write ``self.final`` as JSON to ``<path>/result.json``.

        The directory is created if it does not exist.
        """
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(path, exist_ok=True)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        """Clear accumulated data on every evaluator."""
        for pi in self.pipeline:
            pi.clear_data()