from __future__ import annotations

import os
from collections import defaultdict
from collections.abc import Callable

import numpy as np
import pandas as pd
import torch
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from utmosv2.utils import calc_metrics, print_metrics


def _train_1epoch(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> dict[str, float]:
    """Train for one epoch and return the mean of each tracked loss."""
    model.train()
    train_loss = defaultdict(float)
    scaler = GradScaler()
    print(f"    (lr: {scheduler.get_last_lr()[0]:.6f})")
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for i, t in enumerate(pbar):
        # Each batch is (*inputs, target); move everything to the device.
        x, y = t[:-1], t[-1]
        x = [t.to(device, non_blocking=True) for t in x]
        y = y.to(device, non_blocking=True)
        if cfg.run.mixup:
            # Sample a mixup coefficient and a permuted batch to mix with.
            lmd = np.random.beta(cfg.run.mixup_alpha, cfg.run.mixup_alpha)
            perm = torch.randperm(x[0].shape[0]).to(device)
            x2 = [t[perm, :] for t in x]
            y2 = y[perm]
        optimizer.zero_grad()
        with autocast():
            if cfg.run.mixup:
                output = model(
                    *[lmd * t + (1 - lmd) * t2 for t, t2 in zip(x, x2)]
                ).squeeze(1)
                if isinstance(cfg.loss, list):
                    # Mix each (weight, loss) component against both targets.
                    loss = [
                        (w1, lmd * l1 + (1 - lmd) * l2)
                        for (w1, l1), (_, l2) in zip(
                            criterion(output, y), criterion(output, y2)
                        )
                    ]
                else:
                    loss = lmd * criterion(output, y) + (1 - lmd) * criterion(
                        output, y2
                    )
            else:
                output = model(*x).squeeze(1)
                loss = criterion(output, y)
        # A composite criterion returns a list of (weight, loss) pairs.
        if isinstance(loss, list):
            loss_total = sum(w * ls for w, ls in loss)
        else:
            loss_total = loss
        scaler.scale(loss_total).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss["loss"] += loss_total.detach().float().cpu().item()
        if isinstance(loss, list):
            # Track each loss component separately for logging.
            for (cl, _), (_, ls) in zip(cfg.loss, loss):
                train_loss[cl.name] += ls.detach().float().cpu().item()
        pbar.set_description(
            f'    loss: {train_loss["loss"] / (i + 1):.4f}'
            + (
                f' ({", ".join([f"{cl.name}: {train_loss[cl.name] / (i + 1):.4f}" for cl, _ in cfg.loss])})'
                if isinstance(loss, list)
                else ""
            )
        )
    return {name: v / len(train_dataloader) for name, v in train_loss.items()}


def _validate_1epoch(
    cfg,
    model: torch.nn.Module,
    valid_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    device: torch.device,
) -> tuple[dict[str, float], dict[str, float], np.ndarray]:
    """Evaluate for one epoch; return mean losses, mean metrics, and predictions."""
    model.eval()
    valid_loss = defaultdict(float)
    valid_metrics = {name: 0.0 for name in metrics}
    valid_preds = []
    pbar = tqdm(valid_dataloader, total=len(valid_dataloader))
    with torch.no_grad():
        for i, t in enumerate(pbar):
            x, y = t[:-1], t[-1]
            x = [t.to(device, non_blocking=True) for t in x]
            # Keep a CPU copy of the targets for the numpy-based metrics.
            y_cpu = y
            y = y.to(device, non_blocking=True)
            with autocast():
                output = model(*x).squeeze(1)
                loss = criterion(output, y)
            if isinstance(loss, list):
                loss_total = sum(w * ls for w, ls in loss)
            else:
                loss_total = loss
            valid_loss["loss"] += loss_total.detach().float().cpu().item()
            if isinstance(loss, list):
                for (cl, _), (_, ls) in zip(cfg.loss, loss):
                    valid_loss[cl.name] += ls.detach().float().cpu().item()
            output = output.cpu().numpy()
            for name, metric in metrics.items():
                valid_metrics[name] += metric(output, y_cpu.numpy())
            pbar.set_description(
                f'    val_loss: {valid_loss["loss"] / (i + 1):.4f} '
                + (
                    f'({", ".join([f"{cl.name}: {valid_loss[cl.name] / (i + 1):.4f}" for cl, _ in cfg.loss])}) '
                    if isinstance(loss, list)
                    else ""
                )
                + " - ".join(
                    [
                        f"val_{name}: {v / (i + 1):.4f}"
                        for name, v in valid_metrics.items()
                    ]
                )
            )
            valid_preds.append(output)
    valid_loss = {name: v / len(valid_dataloader) for name, v in valid_loss.items()}
    valid_metrics = {
        name: v / len(valid_dataloader) for name, v in valid_metrics.items()
    }
    valid_preds = np.concatenate(valid_preds)
    return valid_loss, valid_metrics, valid_preds
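
# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original utmosv2 code):
# the loops above accept a criterion that, when `cfg.loss` is a list, returns
# a list of (weight, loss) pairs, so callers can sum a weighted total while
# logging each component under its `cl.name`. A minimal criterion honouring
# that contract might look like the class below; the class name and the
# (weight, module) constructor format are hypothetical.
# ---------------------------------------------------------------------------
class _WeightedLossListSketch(torch.nn.Module):
    """Example criterion returning [(weight, loss), ...] pairs."""

    def __init__(self, weighted_losses: list[tuple[float, torch.nn.Module]]):
        super().__init__()
        self.weights = [w for w, _ in weighted_losses]
        # Register the loss modules so .to(device) and .train()/.eval() work.
        self.losses = torch.nn.ModuleList([loss for _, loss in weighted_losses])

    def forward(
        self, output: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        # One (weight, loss) pair per component, in construction order.
        return [
            (w, loss_fn(output, target))
            for w, loss_fn in zip(self.weights, self.losses)
        ]
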
def run_train(
    cfg,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    valid_data: pd.DataFrame,
    oof_preds: np.ndarray,
    now_fold: int,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    device: torch.device,
) -> None:
    """Run the full training loop for one fold, checkpointing the best model."""
    best_metric = 0.0
    os.makedirs(cfg.save_path, exist_ok=True)
    for epoch in range(cfg.run.num_epochs):
        print(f"[Epoch {epoch + 1}/{cfg.run.num_epochs}]")
        train_loss = _train_1epoch(
            cfg, model, train_dataloader, criterion, optimizer, scheduler, device
        )
        valid_loss, _, valid_preds = _validate_1epoch(
            cfg, model, valid_dataloader, criterion, metrics, device
        )
        print(f"Validation dataset: {cfg.validation_dataset}")
        if cfg.validation_dataset == "each":
            # Compute metrics per dataset, then average across datasets.
            dataset = valid_data["dataset"].unique()
            val_metrics = [
                calc_metrics(
                    valid_data[valid_data["dataset"] == ds],
                    valid_preds[valid_data["dataset"] == ds],
                )
                for ds in dataset
            ]
            val_metrics = {
                name: sum([m[name] for m in val_metrics]) / len(val_metrics)
                for name in val_metrics[0].keys()
            }
        elif cfg.validation_dataset == "all":
            print("Validation dataset: ALL")
            val_metrics = calc_metrics(valid_data, valid_preds)
        else:
            # Validate on a single named dataset.
            val_metrics = calc_metrics(
                valid_data[valid_data["dataset"] == cfg.validation_dataset],
                valid_preds[valid_data["dataset"] == cfg.validation_dataset],
            )
        print_metrics(val_metrics)
        if val_metrics[cfg.main_metric] > best_metric:
            new_metric = val_metrics[cfg.main_metric]
            print(f"(Found best metric: {best_metric:.4f} -> {new_metric:.4f})")
            best_metric = new_metric
            save_path = (
                cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_best_model.pth"
            )
            torch.save(model.state_dict(), save_path)
            print(f"Save best model: {save_path}")
            # Record out-of-fold predictions from the best epoch.
            oof_preds[valid_data.index] = valid_preds
        save_path = cfg.save_path / f"fold{now_fold}_s{cfg.split.seed}_last_model.pth"
        torch.save(model.state_dict(), save_path)
        print()
        val_metrics["train_loss"] = train_loss["loss"]
        val_metrics["val_loss"] = valid_loss["loss"]
        # Per-component losses only exist when cfg.loss is a list of
        # weighted loss configs (matching the isinstance checks above).
        if isinstance(cfg.loss, list):
            for cl, _ in cfg.loss:
                val_metrics[f"train_loss_{cl.name}"] = train_loss[cl.name]
                val_metrics[f"val_loss_{cl.name}"] = valid_loss[cl.name]
        if cfg.wandb:
            wandb.log(val_metrics)
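

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original module):
# `run_train` expects `metrics` to map names to callables over
# (predictions, targets) numpy arrays. The concrete choices below (MSE and
# Spearman rank correlation, a common MOS-prediction metric) are hypothetical
# examples of that contract, not the metrics utmosv2 actually configures.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import scipy.stats

    example_metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]] = {
        "mse": lambda pred, true: float(np.mean((pred - true) ** 2)),
        "srcc": lambda pred, true: float(scipy.stats.spearmanr(pred, true)[0]),
    }
    pred = np.array([3.1, 4.0, 2.5, 4.4])
    true = np.array([3.0, 4.2, 2.7, 4.1])
    for name, fn in example_metrics.items():
        print(f"{name}: {fn(pred, true):.4f}")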