Spaces:
Running
Running
import torch | |
from collections import defaultdict | |
import torch.optim as optim | |
# import tensorflow as tf | |
from torch.utils.tensorboard import SummaryWriter | |
from collections import OrderedDict | |
from utils.utils import * | |
from os.path import join as pjoin | |
from utils.eval_t2m import evaluation_mask_transformer, evaluation_res_transformer | |
from models.mask_transformer.tools import * | |
from einops import rearrange, repeat | |
def def_value(): | |
return 0.0 | |
class MaskTransformerTrainer: | |
def __init__(self, args, t2m_transformer, vq_model): | |
self.opt = args | |
self.t2m_transformer = t2m_transformer | |
self.vq_model = vq_model | |
self.device = args.device | |
self.vq_model.eval() | |
if args.is_train: | |
self.logger = SummaryWriter(args.log_dir) | |
def update_lr_warm_up(self, nb_iter, warm_up_iter, lr): | |
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) | |
for param_group in self.opt_t2m_transformer.param_groups: | |
param_group["lr"] = current_lr | |
return current_lr | |
def forward(self, batch_data): | |
conds, motion, m_lens = batch_data | |
motion = motion.detach().float().to(self.device) | |
m_lens = m_lens.detach().long().to(self.device) | |
# (b, n, q) | |
code_idx, _ = self.vq_model.encode(motion) | |
m_lens = m_lens // 4 | |
conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds | |
# loss_dict = {} | |
# self.pred_ids = [] | |
# self.acc = [] | |
_loss, _pred_ids, _acc = self.t2m_transformer(code_idx[..., 0], conds, m_lens) | |
return _loss, _acc | |
def update(self, batch_data): | |
loss, acc = self.forward(batch_data) | |
self.opt_t2m_transformer.zero_grad() | |
loss.backward() | |
self.opt_t2m_transformer.step() | |
self.scheduler.step() | |
return loss.item(), acc | |
def save(self, file_name, ep, total_it): | |
t2m_trans_state_dict = self.t2m_transformer.state_dict() | |
clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_model.')] | |
for e in clip_weights: | |
del t2m_trans_state_dict[e] | |
state = { | |
't2m_transformer': t2m_trans_state_dict, | |
'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(), | |
'scheduler':self.scheduler.state_dict(), | |
'ep': ep, | |
'total_it': total_it, | |
} | |
torch.save(state, file_name) | |
def resume(self, model_dir): | |
checkpoint = torch.load(model_dir, map_location=self.device) | |
missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(checkpoint['t2m_transformer'], strict=False) | |
assert len(unexpected_keys) == 0 | |
assert all([k.startswith('clip_model.') for k in missing_keys]) | |
try: | |
self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer']) # Optimizer | |
self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler | |
except: | |
print('Resume wo optimizer') | |
return checkpoint['ep'], checkpoint['total_it'] | |
def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval): | |
self.t2m_transformer.to(self.device) | |
self.vq_model.to(self.device) | |
self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5) | |
self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer, | |
milestones=self.opt.milestones, | |
gamma=self.opt.gamma) | |
epoch = 0 | |
it = 0 | |
if self.opt.is_continue: | |
model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO | |
epoch, it = self.resume(model_dir) | |
print("Load model epoch:%d iterations:%d"%(epoch, it)) | |
start_time = time.time() | |
total_iters = self.opt.max_epoch * len(train_loader) | |
print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}') | |
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader))) | |
logs = defaultdict(def_value, OrderedDict()) | |
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer( | |
self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch, | |
best_fid=100, best_div=100, | |
best_top1=0, best_top2=0, best_top3=0, | |
best_matching=100, eval_wrapper=eval_wrapper, | |
plot_func=plot_eval, save_ckpt=False, save_anim=False | |
) | |
best_acc = 0. | |
while epoch < self.opt.max_epoch: | |
self.t2m_transformer.train() | |
self.vq_model.eval() | |
for i, batch in enumerate(train_loader): | |
it += 1 | |
if it < self.opt.warm_up_iter: | |
self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr) | |
loss, acc = self.update(batch_data=batch) | |
logs['loss'] += loss | |
logs['acc'] += acc | |
logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr'] | |
if it % self.opt.log_every == 0: | |
mean_loss = OrderedDict() | |
# self.logger.add_scalar('val_loss', val_loss, it) | |
# self.l | |
for tag, value in logs.items(): | |
self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it) | |
mean_loss[tag] = value / self.opt.log_every | |
logs = defaultdict(def_value, OrderedDict()) | |
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i) | |
if it % self.opt.save_latest == 0: | |
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) | |
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) | |
epoch += 1 | |
print('Validation time:') | |
self.vq_model.eval() | |
self.t2m_transformer.eval() | |
val_loss = [] | |
val_acc = [] | |
with torch.no_grad(): | |
for i, batch_data in enumerate(val_loader): | |
loss, acc = self.forward(batch_data) | |
val_loss.append(loss.item()) | |
val_acc.append(acc) | |
print(f"Validation loss:{np.mean(val_loss):.3f}, accuracy:{np.mean(val_acc):.3f}") | |
self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch) | |
self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch) | |
if np.mean(val_acc) > best_acc: | |
print(f"Improved accuracy from {best_acc:.02f} to {np.mean(val_acc)}!!!") | |
self.save(pjoin(self.opt.model_dir, 'net_best_acc.tar'), epoch, it) | |
best_acc = np.mean(val_acc) | |
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer( | |
self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid, | |
best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3, | |
best_matching=best_matching, eval_wrapper=eval_wrapper, | |
plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0) | |
) | |
class ResidualTransformerTrainer: | |
def __init__(self, args, res_transformer, vq_model): | |
self.opt = args | |
self.res_transformer = res_transformer | |
self.vq_model = vq_model | |
self.device = args.device | |
self.vq_model.eval() | |
if args.is_train: | |
self.logger = SummaryWriter(args.log_dir) | |
# self.l1_criterion = torch.nn.SmoothL1Loss() | |
def update_lr_warm_up(self, nb_iter, warm_up_iter, lr): | |
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) | |
for param_group in self.opt_res_transformer.param_groups: | |
param_group["lr"] = current_lr | |
return current_lr | |
def forward(self, batch_data): | |
conds, motion, m_lens = batch_data | |
motion = motion.detach().float().to(self.device) | |
m_lens = m_lens.detach().long().to(self.device) | |
# (b, n, q), (q, b, n ,d) | |
code_idx, all_codes = self.vq_model.encode(motion) | |
m_lens = m_lens // 4 | |
conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds | |
ce_loss, pred_ids, acc = self.res_transformer(code_idx, conds, m_lens) | |
return ce_loss, acc | |
def update(self, batch_data): | |
loss, acc = self.forward(batch_data) | |
self.opt_res_transformer.zero_grad() | |
loss.backward() | |
self.opt_res_transformer.step() | |
self.scheduler.step() | |
return loss.item(), acc | |
def save(self, file_name, ep, total_it): | |
res_trans_state_dict = self.res_transformer.state_dict() | |
clip_weights = [e for e in res_trans_state_dict.keys() if e.startswith('clip_model.')] | |
for e in clip_weights: | |
del res_trans_state_dict[e] | |
state = { | |
'res_transformer': res_trans_state_dict, | |
'opt_res_transformer': self.opt_res_transformer.state_dict(), | |
'scheduler':self.scheduler.state_dict(), | |
'ep': ep, | |
'total_it': total_it, | |
} | |
torch.save(state, file_name) | |
def resume(self, model_dir): | |
checkpoint = torch.load(model_dir, map_location=self.device) | |
missing_keys, unexpected_keys = self.res_transformer.load_state_dict(checkpoint['res_transformer'], strict=False) | |
assert len(unexpected_keys) == 0 | |
assert all([k.startswith('clip_model.') for k in missing_keys]) | |
try: | |
self.opt_res_transformer.load_state_dict(checkpoint['opt_res_transformer']) # Optimizer | |
self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler | |
except: | |
print('Resume wo optimizer') | |
return checkpoint['ep'], checkpoint['total_it'] | |
def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval): | |
self.res_transformer.to(self.device) | |
self.vq_model.to(self.device) | |
self.opt_res_transformer = optim.AdamW(self.res_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5) | |
self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_res_transformer, | |
milestones=self.opt.milestones, | |
gamma=self.opt.gamma) | |
epoch = 0 | |
it = 0 | |
if self.opt.is_continue: | |
model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO | |
epoch, it = self.resume(model_dir) | |
print("Load model epoch:%d iterations:%d"%(epoch, it)) | |
start_time = time.time() | |
total_iters = self.opt.max_epoch * len(train_loader) | |
print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}') | |
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader))) | |
logs = defaultdict(def_value, OrderedDict()) | |
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer( | |
self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch, | |
best_fid=100, best_div=100, | |
best_top1=0, best_top2=0, best_top3=0, | |
best_matching=100, eval_wrapper=eval_wrapper, | |
plot_func=plot_eval, save_ckpt=False, save_anim=False | |
) | |
best_loss = 100 | |
best_acc = 0 | |
while epoch < self.opt.max_epoch: | |
self.res_transformer.train() | |
self.vq_model.eval() | |
for i, batch in enumerate(train_loader): | |
it += 1 | |
if it < self.opt.warm_up_iter: | |
self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr) | |
loss, acc = self.update(batch_data=batch) | |
logs['loss'] += loss | |
logs["acc"] += acc | |
logs['lr'] += self.opt_res_transformer.param_groups[0]['lr'] | |
if it % self.opt.log_every == 0: | |
mean_loss = OrderedDict() | |
# self.logger.add_scalar('val_loss', val_loss, it) | |
# self.l | |
for tag, value in logs.items(): | |
self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it) | |
mean_loss[tag] = value / self.opt.log_every | |
logs = defaultdict(def_value, OrderedDict()) | |
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i) | |
if it % self.opt.save_latest == 0: | |
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) | |
epoch += 1 | |
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) | |
print('Validation time:') | |
self.vq_model.eval() | |
self.res_transformer.eval() | |
val_loss = [] | |
val_acc = [] | |
with torch.no_grad(): | |
for i, batch_data in enumerate(val_loader): | |
loss, acc = self.forward(batch_data) | |
val_loss.append(loss.item()) | |
val_acc.append(acc) | |
print(f"Validation loss:{np.mean(val_loss):.3f}, Accuracy:{np.mean(val_acc):.3f}") | |
self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch) | |
self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch) | |
if np.mean(val_loss) < best_loss: | |
print(f"Improved loss from {best_loss:.02f} to {np.mean(val_loss)}!!!") | |
self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it) | |
best_loss = np.mean(val_loss) | |
if np.mean(val_acc) > best_acc: | |
print(f"Improved acc from {best_acc:.02f} to {np.mean(val_acc)}!!!") | |
# self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it) | |
best_acc = np.mean(val_acc) | |
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer( | |
self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid, | |
best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3, | |
best_matching=best_matching, eval_wrapper=eval_wrapper, | |
plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0) | |
) |