Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import cv2 | |
import os | |
from loss import batch_episym | |
from tqdm import tqdm | |
import sys | |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.insert(0, ROOT_DIR) | |
from utils import evaluation_utils, train_utils | |
def valid(valid_loader, model, match_loss, config, model_config):
    """Run one validation pass and return distributed-reduced metrics.

    The model is put in eval mode for the pass and is always switched back to
    train mode afterwards, even if the pass raises.

    Args:
        valid_loader: iterable of validation batches. Each batch is moved to
            the GPU with ``train_utils.tocuda`` and fed to ``model``. Must
            yield at least one batch (the averages divide by the batch count).
        model: network under evaluation, called as ``model(batch)``.
        match_loss: object whose ``run(batch, res)`` returns a dict with
            ``"acc_corr"``, ``"acc_incorr"`` and ``"total_loss"``. When
            ``config.model_name == "SGM"`` the dict must also contain
            ``"mid_acc_corr"``, ``"pre_seed_conf"`` and ``"recall_seed_conf"``.
        config: run configuration; reads ``local_rank`` and ``model_name``.
        model_config: model configuration; reads ``layer_num`` and
            ``seedlayer``.

    Returns:
        Tuple ``(total_loss, total_acc_corr, total_acc_incorr,
        total_precision, total_recall, total_acc_mid)``.
        ``total_loss`` is summed over batches and reduced across processes
        with ``"sum"``. The remaining entries are averaged over batches and
        reduced with ``"mean"``. ``total_precision`` and ``total_recall`` have
        ``model_config.layer_num`` entries. ``total_acc_mid`` has
        ``len(model_config.seedlayer) - 1`` entries. These three stay zero
        unless ``config.model_name == "SGM"``.
    """
    model.eval()
    num_pair = 0  # number of batches seen
    total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
    total_precision = torch.zeros(model_config.layer_num, device="cuda")
    total_recall = torch.zeros(model_config.layer_num, device="cuda")
    total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")
    try:
        with torch.no_grad():
            loader_iter = valid_loader
            if config.local_rank == 0:
                # Wrap the loader itself rather than a bare iterator so tqdm
                # can read its length and display a total / ETA.
                loader_iter = tqdm(valid_loader)
                print("validating...")
            for test_data in loader_iter:
                num_pair += 1
                test_data = train_utils.tocuda(test_data)
                res = model(test_data)
                loss_res = match_loss.run(test_data, res)
                total_acc_corr += loss_res["acc_corr"]
                total_acc_incorr += loss_res["acc_incorr"]
                total_loss += loss_res["total_loss"]
                # Seed-level metrics are only produced by the SGM model.
                if config.model_name == "SGM":
                    total_acc_mid += loss_res["mid_acc_corr"]
                    total_precision = total_precision + loss_res["pre_seed_conf"]
                    total_recall = total_recall + loss_res["recall_seed_conf"]
            # Average per-batch metrics locally. The loss is kept as a sum.
            total_acc_corr /= num_pair
            total_acc_incorr /= num_pair
            total_precision /= num_pair
            total_recall /= num_pair
            total_acc_mid /= num_pair
            # Combine the per-process results.
            total_loss = train_utils.reduce_tensor(total_loss, "sum")
            total_acc_corr = train_utils.reduce_tensor(total_acc_corr, "mean")
            total_acc_incorr = train_utils.reduce_tensor(total_acc_incorr, "mean")
            total_precision = train_utils.reduce_tensor(total_precision, "mean")
            total_recall = train_utils.reduce_tensor(total_recall, "mean")
            total_acc_mid = train_utils.reduce_tensor(total_acc_mid, "mean")
    finally:
        # Restore training mode even if validation raised.
        model.train()
    return (
        total_loss,
        total_acc_corr,
        total_acc_incorr,
        total_precision,
        total_recall,
        total_acc_mid,
    )
def dump_train_vis(res, data, step, config, score_th=0.2):
    """Draw the predicted matches of a training batch and save them as PNGs.

    For each keypoint in image 1, the row-wise argmax of the assignment matrix
    ``res["p"]`` (last row and column dropped) is kept as a match. A match is
    drawn only if its score exceeds ``score_th`` and it is a mutual argmax.
    Each drawn match is passed to ``evaluation_utils.draw_match`` together
    with an inlier flag: epipolar distance under ``data["e_gt"]`` below
    ``config.inlier_th``.

    Args:
        res: model output dict; reads ``"p"``.
        data: batch dict; reads ``"x1"``, ``"x2"``, ``"kpt1"``, ``"kpt2"``,
            ``"e_gt"``, ``"img_path1"`` and ``"img_path2"``.
        step: training step, used as the output sub-directory name.
        config: reads ``inlier_th``, ``train_vis_folder`` and ``log_base``.
        score_th: minimum matching score for a match to be drawn. The default
            of 0.2 matches the previously hard-coded value.

    Images are written to
    ``<train_vis_folder>/train_vis/<log_base>/<step>/<img1>_<img2>_<seq>.png``.
    The per-step directory is created if needed.
    """
    with torch.no_grad():
        # Drop the last row/column of the assignment matrix (presumably the
        # dustbin slots -- confirm against the model).
        p = res["p"][:, :-1, :-1]
        score, index1 = torch.max(p, dim=-1)
        _, index2 = torch.max(p, dim=-2)
        mask_th = score > score_th
        # Mutual check: the best row for column index1[b, i] must be i.
        # Build arange on p's device instead of the "current" CUDA device.
        row_ids = torch.arange(p.shape[1], device=p.device)[None]
        mask_mc = index2.gather(index=index1, dim=1) == row_ids
        mask_p = mask_th & mask_mc  # B*N
        gather_idx = index1[:, :, None].expand(-1, -1, 2)
        corr1, corr2 = data["x1"], data["x2"].gather(index=gather_idx, dim=1)
        corr1_kpt = data["kpt1"]
        corr2_kpt = data["kpt2"].gather(index=gather_idx, dim=1)
        epi_dis = batch_episym(corr1, corr2, data["e_gt"])
        mask_inlier = epi_dis < config.inlier_th  # B*N

    # cv2.imwrite returns False instead of raising when the folder is missing,
    # so make sure the per-step output folder exists.
    save_dir = os.path.join(
        config.train_vis_folder, "train_vis", config.log_base, str(step)
    )
    os.makedirs(save_dir, exist_ok=True)

    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
    ):
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            # Filter the inlier flags with the same mask as the
            # correspondences so flag k belongs to drawn match k.
            inlier=cur_mask_inlier[cur_mask_p].cpu().numpy(),
        )
        base_name_seq = "_".join(
            [
                img_path1.split("/")[-1],
                img_path2.split("/")[-1],
                img_path1.split("/")[-2],
            ]
        )
        save_path = os.path.join(save_dir, base_name_seq + ".png")
        cv2.imwrite(save_path, dis_play)