import copy
import os
import time

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from tqdm import tqdm

from FakeVD.code_test.utils.metrics import *


class Trainer3():
    def __init__(self,
                 model,
                 device,
                 lr,
                 dropout,
                 dataloaders,
                 weight_decay,
                 save_param_path,
                 writer,
                 epoch_stop,
                 epoches,
                 mode,
                 model_name,
                 event_num,
                 save_threshold=0.0,
                 start_epoch=0,
                 ):
        self.model = model
        self.device = device
        self.mode = mode
        self.model_name = model_name
        self.event_num = event_num

        self.dataloaders = dataloaders
        self.start_epoch = start_epoch
        self.num_epochs = epoches
        self.epoch_stop = epoch_stop
        self.save_threshold = save_threshold
        self.writer = writer

        # Create the checkpoint directory if it does not exist yet
        # (os.makedirs returns None, so never assign its result).
        os.makedirs(save_param_path, exist_ok=True)
        self.save_param_path = save_param_path

        self.lr = lr
        self.weight_decay = weight_decay
        self.dropout = dropout

        self.criterion = nn.CrossEntropyLoss()

    def train(self):
        since = time.time()

        self.model.to(self.device)

        best_model_wts_val = copy.deepcopy(self.model.state_dict())
        best_acc_val = 0.0
        best_epoch_val = 0
        is_earlystop = False

        if self.mode == "eann":
            # NOTE: these event-task "best" trackers are initialized but never
            # updated below; they only feed the final summary print.
            best_acc_val_event = 0.0
            best_epoch_val_event = 0

        for epoch in range(self.start_epoch, self.start_epoch + self.num_epochs):
            if is_earlystop:
                break
            print('-' * 50)
            print('Epoch {}/{}'.format(epoch + 1, self.start_epoch + self.num_epochs))
            print('-' * 50)

            # EANN-style learning-rate annealing: rebuild the optimizer each
            # epoch with the decayed learning rate.
            p = float(epoch) / 100
            lr = self.lr / (1. + 10 * p) ** 0.75
            self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                              lr=lr,
                                              weight_decay=self.weight_decay)

            for phase in ['train', 'val', 'test']:
                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()
                print('-' * 10)
                print(phase.upper())
                print('-' * 10)

                running_loss_fnd = 0.0
                running_loss = 0.0
                tpred = []
                tlabel = []
                if self.mode == "eann":
                    running_loss_event = 0.0
                    tpred_event = []
                    tlabel_event = []

                for batch in tqdm(self.dataloaders[phase]):
                    batch_data = batch
                    for k, v in batch_data.items():
                        batch_data[k] = v.to(self.device)
                    label = batch_data['label']
                    if self.mode == "eann":
                        label_event = batch_data['label_event']

                    self.optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        if self.mode == "eann":
                            # The EANN head returns the fake-news logits plus the
                            # auxiliary event-classification logits.
                            outputs, outputs_event, fea = self.model(**batch_data)
                            loss_fnd = self.criterion(outputs, label)
                            loss_event = self.criterion(outputs_event, label_event)
                            loss = loss_fnd + loss_event
                            _, preds = torch.max(outputs, 1)
                            _, preds_event = torch.max(outputs_event, 1)
                        else:
                            outputs, fea = self.model(**batch_data)
                            _, preds = torch.max(outputs, 1)
                            loss = self.criterion(outputs, label)

                        if phase == 'train':
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                            self.optimizer.step()
                            self.optimizer.zero_grad()

                    tlabel.extend(label.detach().cpu().numpy().tolist())
                    tpred.extend(preds.detach().cpu().numpy().tolist())
                    running_loss += loss.item() * label.size(0)
                    if self.mode == "eann":
                        tlabel_event.extend(label_event.detach().cpu().numpy().tolist())
                        tpred_event.extend(preds_event.detach().cpu().numpy().tolist())
                        running_loss_event += loss_event.item() * label_event.size(0)
                        running_loss_fnd += loss_fnd.item() * label.size(0)

                epoch_loss = running_loss / len(self.dataloaders[phase].dataset)
                print('Loss: {:.4f} '.format(epoch_loss))
                results = metrics(tlabel, tpred)
                print(results)
                self.writer.add_scalar('Loss/' + phase, epoch_loss, epoch + 1)
                self.writer.add_scalar('Acc/' + phase, results['acc'], epoch + 1)
                self.writer.add_scalar('F1/' + phase, results['f1'], epoch + 1)

                if self.mode == "eann":
                    epoch_loss_fnd = running_loss_fnd / len(self.dataloaders[phase].dataset)
                    print('Loss_fnd: {:.4f} '.format(epoch_loss_fnd))
                    epoch_loss_event = running_loss_event / len(self.dataloaders[phase].dataset)
                    print('Loss_event: {:.4f} '.format(epoch_loss_event))
                    self.writer.add_scalar('Loss_fnd/' + phase, epoch_loss_fnd, epoch + 1)
                    self.writer.add_scalar('Loss_event/' + phase, epoch_loss_event, epoch + 1)

                # Track the best checkpoint and early stopping on the val phase
                # only (the original fell through to this branch on train/test
                # phases as well).
                if phase == 'val':
                    if results['acc'] > best_acc_val:
                        best_acc_val = results['acc']
                        best_model_wts_val = copy.deepcopy(self.model.state_dict())
                        best_epoch_val = epoch + 1
                        if best_acc_val > self.save_threshold:
                            save_path = (self.save_param_path + "_val_epoch"
                                         + str(best_epoch_val)
                                         + "_{0:.4f}".format(best_acc_val))
                            torch.save(self.model.state_dict(), save_path)
                            print("saved " + save_path)
                    elif epoch - best_epoch_val >= self.epoch_stop - 1:
                        # No val improvement for `epoch_stop` epochs: stop early.
                        is_earlystop = True
                        print("early stopping...")

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print("Best model on val: epoch" + str(best_epoch_val) + "_" + str(best_acc_val))
        if self.mode == "eann":
            print("Event: Best model on val: epoch" + str(best_epoch_val_event)
                  + "_" + str(best_acc_val_event))

        # Evaluate the best validation checkpoint on the test set.
        self.model.load_state_dict(best_model_wts_val)
        print("test result when using best model on val")
        return self.test()

    def test(self):
        self.model.to(self.device)
        self.model.eval()

        pred = []
        label = []
        if self.mode == "eann":
            pred_event = []
            label_event = []

        for batch in tqdm(self.dataloaders['test']):
            with torch.no_grad():
                batch_data = batch
                for k, v in batch_data.items():
                    batch_data[k] = v.to(self.device)
                batch_label = batch_data['label']

                if self.mode == "eann":
                    batch_label_event = batch_data['label_event']
                    batch_outputs, batch_outputs_event, fea = self.model(**batch_data)
                    _, batch_preds_event = torch.max(batch_outputs_event, 1)
                    label_event.extend(batch_label_event.detach().cpu().numpy().tolist())
                    pred_event.extend(batch_preds_event.detach().cpu().numpy().tolist())
                else:
                    batch_outputs, fea = self.model(**batch_data)

                _, batch_preds = torch.max(batch_outputs, 1)
                label.extend(batch_label.detach().cpu().numpy().tolist())
                pred.extend(batch_preds.detach().cpu().numpy().tolist())

        print(get_confusionmatrix_fnd(np.array(pred), np.array(label)))
        print(metrics(label, pred))

        if self.mode == "eann" and self.model_name != "FANVM":
            print("event:")
            print(accuracy_score(np.array(label_event), np.array(pred_event)))

        return metrics(label, pred)
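

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original trainer). It wires
# Trainer3 to a toy two-class linear model over random features so the whole
# train/val/test loop can be exercised on CPU or GPU. Everything named below
# (ToyModel, ToyDataset, the 'fea' batch key, the ./params_toy and ./runs_toy
# paths) is an illustrative assumption, not the project's real model or data
# pipeline; it also assumes the project's `metrics` helper returns a dict with
# 'acc' and 'f1' keys, as the trainer itself expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.tensorboard import SummaryWriter  # needs `tensorboard` installed

    class ToyDataset(Dataset):
        """Random features with binary labels, yielding the dict batches Trainer3 expects."""
        def __init__(self, n=32, dim=8):
            self.x = torch.randn(n, dim)
            self.y = torch.randint(0, 2, (n,))

        def __len__(self):
            return len(self.x)

        def __getitem__(self, i):
            return {'fea': self.x[i], 'label': self.y[i]}

    class ToyModel(nn.Module):
        """Linear classifier whose forward matches the non-eann `model(**batch_data)` call."""
        def __init__(self, dim=8, num_classes=2):
            super().__init__()
            self.fc = nn.Linear(dim, num_classes)

        def forward(self, fea, label):
            # Return (logits, features) to match the `outputs, fea = ...` unpacking;
            # `label` is accepted because the trainer passes every batch key as a kwarg.
            return self.fc(fea), fea

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loaders = {p: DataLoader(ToyDataset(), batch_size=8) for p in ['train', 'val', 'test']}
    trainer = Trainer3(model=ToyModel(), device=device, lr=1e-3, dropout=0.1,
                       dataloaders=loaders, weight_decay=5e-5,
                       save_param_path='./params_toy/ckpt',  # hypothetical checkpoint prefix
                       writer=SummaryWriter('./runs_toy'),   # hypothetical log dir
                       epoch_stop=2, epoches=2, mode='plain',
                       model_name='toy', event_num=0)
    trainer.train()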