Spaces:
Running
Running
File size: 5,737 Bytes
a104d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import logging
import os
import time
from typing import List
import torch
from third_party.arcface import verification
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain numbers that callers read directly:
    ``val`` (last value passed to :meth:`update`), ``sum`` (weighted sum),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest observation with weight ``n``.

        Args:
            val: The observed value.
            n: Weight of the observation (for example a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. ``n=0`` right after ``reset``) has no
        # defined mean; keep the reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0
class CallBackVerification(object):
    """Callback that periodically runs verification benchmarks on rank 0.

    On construction, rank 0 loads every ``<rec_prefix>/<name>.bin`` that
    exists for the names in ``val_targets``. Calling the instance with the
    current update count runs the benchmarks every ``frequent`` updates.

    Args:
        frequent: Run verification every this many updates.
        rank: Rank of the current process; only rank 0 loads data and tests.
        val_targets: Names of the verification sets (file stem of the
            ``.bin`` file).
        rec_prefix: Directory that holds the ``.bin`` files.
        image_size: Passed through to ``verification.load_bin``.
        is_gray: Passed through to ``verification.test``.
    """

    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112),
                 is_gray=False):
        self.frequent: int = frequent
        self.rank: int = rank
        self.highest_acc: float = 0.0
        # One slot per requested target; only the first len(self.ver_list)
        # entries are used when some target files are missing.
        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        self.is_gray = is_gray
        # Compare by value: ``is 0`` tests object identity, which is not
        # guaranteed for integers and triggers a SyntaxWarning.
        if self.rank == 0:
            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)

    def ver_test(self, backbone: torch.nn.Module, global_step: int):
        """Evaluate ``backbone`` on every loaded set and print the results.

        Args:
            backbone: Model handed to ``verification.test``.
            global_step: Step number used only to label the printed lines.

        Returns:
            The flip accuracy reported for each loaded set, in the order of
            ``self.ver_list``. Empty when no set is loaded.
        """
        results = []
        for i, (name, data_set) in enumerate(zip(self.ver_name_list, self.ver_list)):
            # NOTE(review): the two ``10`` arguments are presumably the batch
            # size and the number of folds; confirm against verification.test.
            _, _, acc2, std2, xnorm, _ = verification.test(
                data_set, backbone, 10, 10,
                is_gray=self.is_gray)
            print('[%s][%d]XNorm: %f' % (name, global_step, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (name, global_step, acc2, std2))
            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            print(
                '[%s][%d]Accuracy-Highest: %1.5f' % (name, global_step, self.highest_acc_list[i]))
            results.append(acc2)
        return results

    def init_dataset(self, val_targets, data_dir, image_size):
        """Load each existing ``<data_dir>/<name>.bin``; skip missing files."""
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)

    def __call__(self, num_update, backbone: torch.nn.Module):
        """Run verification when ``num_update`` is a positive multiple of ``frequent``.

        Only rank 0 does any work. The backbone is put in eval mode for the
        test and switched back to train mode afterwards.
        """
        if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
            backbone.eval()
            try:
                self.ver_test(backbone, num_update)
            finally:
                # Restore training mode even if verification raises.
                backbone.train()
class CallBackLogging(object):
    """Callback that logs throughput, loss and remaining time on rank 0.

    Args:
        frequent: Log every this many global steps.
        rank: Rank of the current process; only rank 0 logs.
        total_step: Total number of steps, used to estimate remaining time.
        batch_size: Multiplied by ``frequent`` to get samples per interval.
        world_size: Multiplier applied to the measured speed.
        writer: Optional object with an ``add_scalar(tag, value, step)``
            method; skipped when ``None``.
    """

    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
        self.frequent: int = frequent
        self.rank: int = rank
        self.time_start = time.time()
        self.total_step: int = total_step
        self.batch_size: int = batch_size
        self.world_size: int = world_size
        self.writer = writer
        # False until the first eligible step has started the interval timer.
        self.init = False
        self.tic = 0

    def __call__(self, global_step, loss: "AverageMeter", epoch: int, fp16: bool,
                 grad_scaler: torch.cuda.amp.GradScaler):
        """Log progress when ``global_step`` is a positive multiple of ``frequent``.

        The first eligible call only starts the interval timer. Every later
        eligible call logs one line, writes scalars to ``writer`` if set,
        resets ``loss`` and restarts the timer.

        Args:
            global_step: Current global step.
            loss: Meter whose ``avg`` is reported and which is then reset.
            epoch: Current epoch, reported in the log line.
            fp16: When true, the log line also includes the grad scale.
            grad_scaler: Only used (via ``get_scale()``) when ``fp16`` is true.
        """
        # Compare by value: ``is 0`` tests object identity, which is not
        # guaranteed for integers and triggers a SyntaxWarning.
        if not (self.rank == 0 and global_step > 0 and global_step % self.frequent == 0):
            return

        if not self.init:
            # Nothing to measure yet; just start timing the first interval.
            self.init = True
            self.tic = time.time()
            return

        try:
            speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
            speed_total = speed * self.world_size
        except ZeroDivisionError:
            # The clock did not advance between two eligible calls.
            speed_total = float('inf')

        # Elapsed hours, extrapolated linearly to the full run.
        time_now = (time.time() - self.time_start) / 3600
        time_total = time_now / ((global_step + 1) / self.total_step)
        time_for_end = time_total - time_now

        if self.writer is not None:
            self.writer.add_scalar('time_for_end', time_for_end, global_step)
            self.writer.add_scalar('loss', loss.avg, global_step)

        if fp16:
            msg = (
                "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d "
                "Fp16 Grad Scale: %2.f Required: %1.f hours"
            ) % (
                speed_total, loss.avg, epoch, global_step, grad_scaler.get_scale(), time_for_end
            )
        else:
            msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % (
                speed_total, loss.avg, epoch, global_step, time_for_end
            )
        logging.info(msg)
        loss.reset()
        self.tic = time.time()
class CallBackModelCheckpoint(object):
    """Callback that writes checkpoints once training is past step 100.

    Args:
        rank: Rank of the current process; only rank 0 saves the backbone.
        output: Directory that receives ``backbone.pth`` and ``awloss.pth``.
    """

    def __init__(self, rank, output="./"):
        self.rank: int = rank
        self.output: str = output

    def __call__(self,
                 global_step,
                 backbone: torch.nn.Module,
                 partial_fc=None,
                 awloss=None,):
        """Save the current state when ``global_step`` is greater than 100.

        Args:
            global_step: Current global step; nothing is saved at 100 or below.
            backbone: Model to save. If it has a ``module`` attribute, that
                attribute's ``state_dict()`` is saved; otherwise the
                backbone's own.
            partial_fc: Optional object whose ``save_params()`` is called.
            awloss: Optional object whose ``state_dict()`` is saved.
        """
        print('CallBackModelCheckpoint...')
        if global_step <= 100:
            return

        # Compare by value: ``is 0`` tests object identity, which is not
        # guaranteed for integers and triggers a SyntaxWarning.
        if self.rank == 0:
            # Accept both a wrapper exposing ``.module`` and a bare model.
            model = getattr(backbone, "module", backbone)
            torch.save(model.state_dict(), os.path.join(self.output, "backbone.pth"))

        # NOTE(review): the two saves below are not gated on rank, so they run
        # in every process; confirm that this is intended for awloss.
        if partial_fc is not None:
            partial_fc.save_params()
        if awloss is not None:
            torch.save(awloss.state_dict(), os.path.join(self.output, "awloss.pth"))
|