|
import copy |
|
import datetime |
|
import json |
|
import math |
|
import os |
|
import random |
|
import signal |
|
import subprocess |
|
import sys |
|
import time |
|
import warnings |
|
from collections import defaultdict |
|
from shutil import copy2 |
|
from typing import Dict |
|
|
|
import numpy as np |
|
import prettytable as pt |
|
import torch |
|
import torch.nn as nn |
|
from termcolor import cprint |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
class Logger(object):
    """Tee-like stream that mirrors every write to a stream and to a log file.

    Exposes ``write`` and ``flush`` so an instance can stand in for
    ``sys.stdout`` / ``sys.stderr``.

    Args:
        filename: Path of the log file; it is opened in append mode.
        stream: Stream that keeps receiving the output (defaults to the
            ``sys.stdout`` object that exists when the class is defined).
    """

    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the wrapped stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both sinks.

        Without this, text sitting in the log file's buffer is only written
        when the buffer fills up or the file object is closed.
        """
        self.terminal.flush()
        self.log.flush()
|
|
|
|
|
class AverageMeter(object):
    """Keeps the latest value together with a running weighted average.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted sum of all values), ``count`` (sum of the weights) and
    ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.sum = self.avg = self.val = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        # Space flag: positive numbers get a leading blank, negatives a "-".
        return format(self.avg, " .5f")
|
|
|
|
|
def get_sha():
    """Collect a best-effort summary of the git state around this file.

    Git is invoked with the directory containing this file as working
    directory.

    Returns:
        dict with the keys ``"sha"`` (first 8 characters of ``HEAD``),
        ``"status"`` (``"clean"`` or ``"has uncommited changes"``),
        ``"branch"`` and ``"prev_commit"`` (subject line of the ``HEAD``
        commit). A field that could not be determined keeps its default:
        ``"N/A"``, or ``"clean"`` for the status.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        raw = subprocess.check_output(command, cwd=cwd)
        # Decode leniently: commit subjects and branch names are not
        # guaranteed to be ASCII, and a decode error would otherwise be
        # swallowed below and leave the field at "N/A".
        return raw.decode("utf-8", errors="replace").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    message = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])[:8]
        # Output is discarded. NOTE(review): presumably run so that the
        # following diff-index sees a refreshed index - confirm.
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
        # The command is passed as a list (no shell), so the format string
        # needs no quoting; this also keeps apostrophes in the subject intact.
        message = _run(["git", "log", "--pretty=format:%s", sha, "-1"])
    except Exception:
        # Deliberately best-effort: a missing git binary or a non-repository
        # directory must never break the caller. Fields gathered before the
        # failure are kept, the rest stay at their defaults.
        pass

    return {"sha": sha, "status": diff, "branch": branch, "prev_commit": message}
|
|
|
|
|
def setup_env(opt):
    """Prepare the process for a run described by ``opt``.

    In eval or debug mode (``opt.eval`` or ``opt.debug`` truthy) this only
    sets ``opt.device``, turns on autograd anomaly detection and returns
    ``None``; nothing is written to disk.

    Otherwise it:
      * creates ``<save_root_path>/<dir_name>/checkpoint``,
      * seeds torch, numpy and ``random`` with ``opt.seed``,
      * replaces ``sys.stdout`` / ``sys.stderr`` with ``Logger`` objects
        writing to ``log.log`` / ``error.log`` inside the run directory,
      * dumps the options (without ``device``) to ``params.json``,
      * prints device and git information.

    Args:
        opt: Options object. Reads ``eval``, ``debug``, ``dir_name``,
            ``save_root_path``, ``seed`` and, outside eval/debug mode,
            ``device`` (which must already be set by the caller).

    Returns:
        ``None`` in eval/debug mode, otherwise a ``SummaryWriter`` logging
        into the run directory.
    """
    if opt.eval or opt.debug:
        opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.autograd.set_detect_anomaly(True)
        return None

    dir_name = opt.dir_name
    save_root_path = opt.save_root_path
    run_dir = os.path.join(save_root_path, dir_name)

    # makedirs(exist_ok=True) creates every missing level (a nested save root
    # included) and also adds the checkpoint folder when the run directory
    # is already there from a previous launch.
    os.makedirs(os.path.join(run_dir, "checkpoint"), exist_ok=True)

    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): PyTorch's reproducibility notes recommend
    # cudnn.benchmark = False for reproducible runs; the flag is left as-is
    # here to keep current speed/behaviour - confirm which one is intended.
    torch.backends.cudnn.benchmark = True

    # From here on everything printed is mirrored into the run directory.
    sys.stdout = Logger(os.path.join(run_dir, "log.log"), sys.stdout)
    sys.stderr = Logger(os.path.join(run_dir, "error.log"), sys.stderr)

    params = copy.deepcopy(vars(opt))
    # "device" is dropped before the JSON dump; the default avoids a KeyError
    # here should the attribute be absent.
    params.pop("device", None)
    with open(os.path.join(run_dir, "params.json"), "w") as f:
        json.dump(params, f)

    print(
        "Running on {}, PyTorch version {}, files will be saved at {}".format(
            opt.device, torch.__version__, run_dir
        )
    )
    print("Devices:")
    for i in range(torch.cuda.device_count()):
        print(" {}:".format(i), torch.cuda.get_device_name(i))
    print(f"Git: {get_sha()}.")

    return SummaryWriter("{}/{}/".format(opt.save_root_path, opt.dir_name))
|
|
|
|
|
class MetricLogger(object):
    """Named collection of ``AverageMeter`` objects with logging helpers.

    Meters are created on first use and can also be read as attributes
    (``logger.loss`` returns ``logger.meters["loss"]``).

    Args:
        delimiter: String placed between the fields of a log line.
        writer: Optional object with an ``add_scalar(tag, value, step)``
            method; used by ``write_tensorboard``.
        suffix: Stored on the instance; not used by any method of this class.
    """

    def __init__(self, delimiter=" ", writer=None, suffix=None):
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter
        self.writer = writer
        self.suffix = suffix

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name.

        Tensors are converted with ``.item()``; any other value has to be a
        ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int)), f"Unsupport type {type(v)}."
            self.meters[k].update(v)

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def get_meters(self):
        """Return a plain dict mapping each meter name to its average."""
        return {name: meter.avg for name, meter in self.meters.items()}

    def prepend_subprefix(self, subprefix: str):
        """Insert ``subprefix`` right after every ``"/"`` of each meter name.

        Names without a ``"/"`` are kept unchanged.
        """
        # Compute all new names first, then refill the same mapping. Deleting
        # the old keys afterwards (the previous approach) also removed every
        # meter whose name did not change, and could overwrite a meter whose
        # old name equalled another meter's new name.
        renamed = [
            (name.replace("/", f"/{subprefix}"), meter)
            for name, meter in self.meters.items()
        ]
        self.meters.clear()
        self.meters.update(renamed)

    def log_every(self, iterable, print_freq=10, header=""):
        """Yield ``(index, item)`` pairs from ``iterable`` and print progress.

        A progress line is printed after every ``print_freq``-th item and
        after the last one; a total-time summary follows at the end.

        Args:
            iterable: Sized iterable (``len()`` must work on it).
            print_freq: Number of items between two progress lines.
            header: Text put in front of every printed line.
        """
        total = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = AverageMeter()
        space_fmt = ":" + str(len(str(total))) + "d"
        log_msg = self.delimiter.join(
            [
                header,
                "[{0" + space_fmt + "}/{1}]",
                "eta: {eta}",
                "{meters}",
                "iter time: {time}s",
            ]
        )
        for i, obj in enumerate(iterable):
            yield i, obj
            # Measured after the consumer resumes the generator, so it covers
            # the consumer's work on this item.
            iter_time.update(time.time() - end)
            if (i + 1) % print_freq == 0 or i == total - 1:
                # i + 1 items are finished at this point.
                eta_seconds = iter_time.avg * (total - (i + 1))
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                # NOTE(review): as written this replace() swaps a single space
                # for a single space, i.e. it changes nothing - confirm whether
                # collapsing double spaces was intended.
                print(
                    log_msg.format(
                        i + 1,
                        total,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                    ).replace(" ", " ")
                )
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f}s / it)".format(
                # max(..., 1) keeps an empty iterable from dividing by zero.
                header, total_time_str, total_time / max(total, 1)
            )
        )

    def write_tensorboard(self, step):
        """Send every meter average to ``self.writer`` at ``step`` (if set)."""
        if self.writer is not None:
            for k, v in self.meters.items():
                self.writer.add_scalar(k, v.avg, step)

    def stat_table(self):
        """Return the meters rendered as a two-column PrettyTable string."""
        tb = pt.PrettyTable(field_names=["Metrics", "Values"])
        for name, meter in self.meters.items():
            tb.add_row([name, str(meter)])
        return tb.get_string()

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed.
        # "meters" is fetched through __dict__ so that an instance without it
        # (e.g. one created via __new__) raises AttributeError instead of
        # recursing back into this method forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = [
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        ]
        # NOTE(review): this replace() is a no-op as written - confirm intent.
        return self.delimiter.join(loss_str).replace(" ", " ")
|
|
|
|
|
def save_model(path, model: nn.Module, epoch, opt, performance=None):
    """Write a checkpoint for ``model`` to ``path`` unless ``opt.debug`` is set.

    The checkpoint is a dict with the keys ``"model"`` (the state dict),
    ``"epoch"``, ``"opt"`` and ``"performance"``. A failure while building
    or saving it is reported through ``cprint`` and not re-raised.
    """
    if opt.debug:
        return

    try:
        payload = {
            "model": model.state_dict(),
            "epoch": epoch,
            "opt": opt,
            "performance": performance,
        }
        torch.save(payload, path)
    except Exception as err:
        cprint("Failed to save {} because {}".format(path, str(err)))
|
|
|
|
|
def resume_from(model: nn.Module, resume_path: str):
    """Load the weights stored at ``resume_path`` into ``model``.

    The checkpoint is read onto the CPU and must contain the keys
    ``"model"`` and ``"performance"``. If a strict load fails, the weights
    are loaded with ``strict=False``, the error is shown in red and the
    function pauses for three seconds. The stored performance entries, if
    any, are printed at the end.
    """
    ckpt = torch.load(resume_path, map_location="cpu")
    weights = ckpt["model"]
    metrics = ckpt["performance"]

    try:
        model.load_state_dict(weights)
    except Exception as err:
        # Fall back to a partial load and give the user time to notice.
        model.load_state_dict(weights, strict=False)
        cprint("Failed to load full model because {}".format(str(err)), "red")
        time.sleep(3)

    print(f"{resume_path} model loaded. It performance is")
    if metrics is None:
        return
    for name, value in metrics.items():
        print(f"{name}: {value}")
|
|
|
|
|
def update_record(result: Dict, epoch: int, opt, file_name: str = "latest_record"):
    """Persist ``result`` as a text table and as JSON; no-op when ``opt.debug``.

    Both files go to ``<opt.save_root_path>/<opt.dir_name>/`` and are named
    ``<file_name>.txt`` and ``<file_name>.json``. Values are formatted with
    seven decimals in the table. ``result`` receives an ``"epoch"`` entry
    before the JSON dump, so the caller's dict is modified.
    """
    if opt.debug:
        return

    record_dir = os.path.join(opt.save_root_path, opt.dir_name)

    table = pt.PrettyTable(field_names=["Metrics", "Values"])
    with open(os.path.join(record_dir, f"{file_name}.txt"), "w") as txt_file:
        txt_file.write(f"Performance at {epoch}-th epoch:\n\n")
        for metric, value in result.items():
            table.add_row([metric, "{:.7f}".format(value)])
        txt_file.write(table.get_string())

    # The epoch is added only after the table was built, so it never shows
    # up as a table row.
    result["epoch"] = epoch
    with open(os.path.join(record_dir, f"{file_name}.json"), "w") as json_file:
        json.dump(result, json_file)
|
|
|
|
|
def pixel_acc(pred, label):
    """Compute pixel-level prediction accuracy.

    ``pred`` is thresholded at 0.5 and compared element-wise with ``label``
    (non-zero label values count as positive). If the last two dimensions
    of ``pred`` differ from those of ``label``, ``pred`` is first resized
    with bilinear interpolation. ``pred`` itself is never modified.

    Args:
        pred: Tensor of scores. NOTE(review): the bilinear resize needs a
            4-D (N, C, H, W) tensor - confirm callers pass that layout.
        label: Ground-truth tensor whose last two dimensions give the
            target size.

    Returns:
        Tensor holding the fraction of elements where the thresholded
        prediction equals the label.
    """
    label_size = label.shape[-2:]
    # Compare the two trailing dimensions. The former `pred.shape[-2]` was a
    # single int, which never equals a torch.Size, so the resize always ran.
    if pred.shape[-2:] != label_size:
        pred = torch.nn.functional.interpolate(
            pred, size=label_size, mode="bilinear", align_corners=False
        )

    # Build a new boolean mask instead of overwriting the caller's tensor.
    binary_pred = pred > 0.5
    # A pixel is correct when prediction and label agree. The previous
    # `(pred + label) == 1.0` test counted the disagreements instead.
    correct = torch.sum(binary_pred == label.bool())
    total = torch.numel(pred)
    return correct / (total + 1e-8)
|
|
|
|
|
def calculate_pixel_f1(pd, gt, prefix="", suffix=""):
    """Compute pixel-level F1, precision and recall of a binary mask.

    Args:
        pd: Predicted mask (array-like; non-zero means positive).
        gt: Ground-truth mask of the same shape.
        prefix: Text put in front of every key of the result.
        suffix: Text appended to every key of the result.

    Returns:
        dict with the keys ``"{prefix}pixel_f1{suffix}"``,
        ``"{prefix}pixel_prec{suffix}"`` and ``"{prefix}pixel_recall{suffix}"``.
        When both masks have a maximum of 0, F1 is reported as 1.0 and
        precision/recall as 0.0.
    """
    if np.max(pd) == 0 and np.max(gt) == 0:
        # Same values as the former tuple return (1.0, 0.0, 0.0), but packed
        # into the dict shape that every other call of this function yields.
        f1, precision, recall = 1.0, 0.0, 0.0
    else:
        seg_inv, gt_inv = np.logical_not(pd), np.logical_not(gt)
        true_pos = float(np.logical_and(pd, gt).sum())
        false_pos = np.logical_and(pd, gt_inv).sum()
        false_neg = np.logical_and(seg_inv, gt).sum()
        # The 1e-6 terms keep the denominators away from zero.
        f1 = 2 * true_pos / (2 * true_pos + false_pos + false_neg + 1e-6)
        precision = true_pos / (true_pos + false_pos + 1e-6)
        recall = true_pos / (true_pos + false_neg + 1e-6)

    return {
        f"{prefix}pixel_f1{suffix}": f1,
        f"{prefix}pixel_prec{suffix}": precision,
        f"{prefix}pixel_recall{suffix}": recall,
    }
|
|
|
|
|
def calculate_img_score(pd, gt, prefix="", suffix="", eta=1e-6):
    """Compute image-level classification statistics from binary decisions.

    Args:
        pd: Predicted labels (array-like; non-zero means positive).
        gt: Ground-truth labels of the same shape.
        prefix: Text put in front of every key of the result.
        suffix: Text appended to every key of the result.
        eta: Small constant added to each denominator.

    Returns:
        dict with accuracy, sensitivity, specificity, F1, the four
        confusion-matrix counts, precision and recall, keyed as
        ``"{prefix}image_<name>{suffix}"``. F1 here is the harmonic mean of
        sensitivity and specificity; it is ``-inf`` when both are zero.
    """
    seg_inv, gt_inv = np.logical_not(pd), np.logical_not(gt)
    true_pos = float(np.logical_and(pd, gt).sum())
    false_pos = float(np.logical_and(pd, gt_inv).sum())
    false_neg = float(np.logical_and(seg_inv, gt).sum())
    true_neg = float(np.logical_and(seg_inv, gt_inv).sum())
    acc = (true_pos + true_neg) / (true_pos + true_neg + false_neg + false_pos + eta)
    sen = true_pos / (true_pos + false_neg + eta)
    spe = true_neg / (true_neg + false_pos + eta)
    precision = true_pos / (true_pos + false_pos + eta)
    recall = true_pos / (true_pos + false_neg + eta)
    try:
        f1 = 2 * sen * spe / (sen + spe)
    except ZeroDivisionError:
        # sen + spe == 0. Only this error is expected here; a bare except
        # would also trap KeyboardInterrupt and hide unrelated bugs.
        f1 = -math.inf

    return {
        f"{prefix}image_acc{suffix}": acc,
        f"{prefix}image_sen{suffix}": sen,
        f"{prefix}image_spe{suffix}": spe,
        f"{prefix}image_f1{suffix}": f1,
        f"{prefix}image_true_pos{suffix}": true_pos,
        f"{prefix}image_true_neg{suffix}": true_neg,
        f"{prefix}image_false_pos{suffix}": false_pos,
        f"{prefix}image_false_neg{suffix}": false_neg,
        f"{prefix}image_prec{suffix}": precision,
        f"{prefix}image_recall{suffix}": recall,
    }
|
|
|
|
|
class timeout:
    """Context manager that raises ``TimeoutError`` when its body runs too long.

    Entering installs ``handle_timeout`` as the ``SIGALRM`` handler and arms
    ``signal.alarm``; leaving cancels the alarm and puts the previous handler
    back. ``SIGALRM`` exists on Unix only, and Python only lets the main
    thread install signal handlers.

    Args:
        seconds: Whole seconds until the alarm fires.
        error_message: Message of the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: abort the guarded block with ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal() returns the handler that was installed before;
        # remember it so __exit__ can put it back.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            signal.alarm(0)
        finally:
            previous = self._previous_handler
            if previous is None:
                # None means the old handler was not installed from Python
                # and cannot be re-installed; fall back to the default.
                previous = signal.SIG_DFL
            signal.signal(signal.SIGALRM, previous)
|
|