MoE_ECGFormer / train.py
splendor1811's picture
Upload 14 files
bcf646b verified
from models.MoE_ECGFormer import MoE_ECGFormer
from data.dataloader import ECGDataloader
from configs.data_configs import get_dataset_class
from configs.hparams import get_hparams_class
from utils import AverageMeter, to_device, _save_metrics, copy_files
from utils import fix_randomness, starting_logs, save_checkpoint, _calc_metrics
import torch
import torch.nn.functional as F
import datetime
import os
import collections
import numpy as np
import warnings
import sklearn.exceptions
warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
class Trainer(object):
def __init__(self, args):
# dataset parameters
self.dataset = args.dataset
self.seed_id = args.seed_id
self.device = torch.device(args.device)
# Exp Description
self.run_description = f"{args.run_description}_{datetime.datetime.now().strftime('%H_%M')}"
self.experiment_description = args.experiment_description
# paths
self.home_path = os.getcwd()
self.save_dir = os.path.join(os.getcwd(), "experiments_logs")
self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, self.run_description)
os.makedirs(self.exp_log_dir, exist_ok=True)
self.data_path = args.data_path
# Specify runs
self.num_runs = args.num_runs
# get dataset and base model configs
self.dataset_configs, self.hparams_class = self.get_configs()
# Specify hparams
self.hparams = self.hparams_class.train_params
def get_configs(self):
dataset_class = get_dataset_class(self.dataset)
hparams_class = get_hparams_class("Supervised")
return dataset_class(), hparams_class()
def load_data(self, data_type):
self.train_dl, self.cw_dict = ECGDataloader(self.data_path, data_type, self.hparams).train_dataloader()
self.test_dl = ECGDataloader(self.data_path, data_type, self.hparams).test_dataloader()
self.valid_dl = ECGDataloader(self.data_path, data_type, self.hparams).valid_dataloader()
def calc_results_per_run(self):
acc, f1 = _calc_metrics(self.pred_labels, self.true_labels, self.dataset_configs.class_names)
return acc, f1
def train(self):
copy_files(self.exp_log_dir) # save a copy of training files
self.metrics = {'accuracy': [], 'f1_score': []}
# fixing random seed
fix_randomness(int(self.seed_id))
# Logging
self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.exp_log_dir, self.seed_id)
self.logger.debug(self.hparams)
# Load data
self.load_data(self.dataset)
model = MoE_ECGFormer(configs=self.dataset_configs, hparams=self.hparams)
model.to(self.device)
# Average meters
loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
self.optimizer = torch.optim.Adam(
model.parameters(),
lr=self.hparams["learning_rate"],
weight_decay=self.hparams["weight_decay"],
betas=(0.9, 0.99)
)
weights = [float(value) for value in self.cw_dict.values()]
# Now convert the list of floats to a numpy array, then to a PyTorch tensor
weights_array = np.array(weights).astype(np.float32) # Ensuring the correct dtype
weights_tensor = torch.tensor(weights_array).to(self.device)
self.cross_entropy = torch.nn.CrossEntropyLoss(weight=weights_tensor)
best_acc = 0
best_f1 = 0
# training..
ts_acc = 0
ts_f1 = 0
for epoch in range(1, self.hparams["num_epochs"] + 1):
model.train()
for step, batches in enumerate(self.train_dl):
batches = to_device(batches, self.device)
data = batches['samples'].float()
labels = batches['labels'].long()
# ====== Source =====================
self.optimizer.zero_grad()
# Src original features
logits = model(data)
# Cross-Entropy loss
x_ent_loss = self.cross_entropy(logits, labels)
x_ent_loss.backward()
self.optimizer.step()
losses = {'Total_loss': x_ent_loss.item()}
for key, val in losses.items():
loss_avg_meters[key].update(val, self.hparams["batch_size"])
self.evaluate(model, self.valid_dl)
tr_acc, tr_f1 = self.calc_results_per_run()
# logging
self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
for key, val in loss_avg_meters.items():
self.logger.debug(f'{key}\t: {val.avg:2.4f}')
self.logger.debug(f'TRAIN: Acc:{tr_acc:2.4f} \t F1:{tr_f1:2.4f}')
# VALIDATION part
self.evaluate(model, self.valid_dl)
ts_acc, ts_f1 = self.calc_results_per_run()
if ts_f1 > best_f1: # save best model based on best f1.
best_f1 = ts_f1
best_acc = ts_acc
save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "best")
_save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_best")
# logging
self.logger.debug(f'VAL : Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f} (best: {best_f1:2.4f})')
self.logger.debug(f'-------------------------------------')
# LAST EPOCH
_save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_last")
self.logger.debug("LAST EPOCH PERFORMANCE on validation set...")
self.logger.debug(f'Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f}')
self.logger.debug(":::::::::::::")
# BEST EPOCH
self.logger.debug("BEST EPOCH PERFORMANCE on validation set ...")
self.logger.debug(f'Acc:{best_acc:2.4f} \t F1:{best_f1:2.4f}')
save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "last")
# TESTING
print(" === Evaluating on TEST set ===")
self.evaluate(model, self.test_dl)
test_acc, test_f1 = self.calc_results_per_run()
_save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "test_last")
self.logger.debug(f'Acc:{test_acc:2.4f} \t F1:{test_f1:2.4f}')
def evaluate(self, model, dataset):
model.to(self.device).eval()
total_loss_ = []
self.pred_labels = np.array([])
self.true_labels = np.array([])
with torch.no_grad():
for batches in dataset:
batches = to_device(batches, self.device)
data = batches['samples'].float()
labels = batches['labels'].long()
# forward pass
predictions = model(data)
# compute loss
loss = F.cross_entropy(predictions, labels)
total_loss_.append(loss.item())
pred = predictions.detach().argmax(dim=1) # get the index of the max log-probability
self.pred_labels = np.append(self.pred_labels, pred.cpu().numpy())
self.true_labels = np.append(self.true_labels, labels.data.cpu().numpy())
self.trg_loss = torch.tensor(total_loss_).mean() # average loss