import argparse |
import json |
from concurrent.futures import ThreadPoolExecutor |
from datetime import datetime |
from pathlib import Path |
import numpy as np |
import torch |
import torch.nn as nn |
import torch.optim as optim |
import transformers |
from sklearn.metrics import ( |
accuracy_score, |
classification_report, |
f1_score, |
precision_score, |
recall_score, |
) |
from torch.optim.lr_scheduler import ( |
CosineAnnealingLR, |
CosineAnnealingWarmRestarts, |
ExponentialLR, |
) |
from torch.utils.data import DataLoader, Dataset |
from torch.utils.tensorboard import SummaryWriter |
from tqdm import tqdm |
from models import AudioClassifier, extract_features |
from losses import AsymmetricLoss, ASLSingleLabel |
torch.manual_seed(42)

# Class names in id order: a name's position in this tuple is its class id.
_LABEL_NAMES = ("usual", "aegi", "chupa")
label2id = {name: index for index, name in enumerate(_LABEL_NAMES)}
# Inverse mapping, class id -> class name (same insertion order, 0..N-1).
id2label = dict(enumerate(_LABEL_NAMES))
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir",
    type=str,
    default="data",
    help="Directory containing 'train' and 'val' subdirectories of *.npy features.",
)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    required=True,
    help="Root directory; a timestamped run subdirectory is created inside it.",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Torch device string; CUDA requests fall back to CPU if unavailable.",
)
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs.")
parser.add_argument(
    "--save_every", type=int, default=100, help="Save a checkpoint every N epochs."
)
args = parser.parse_args()

device = args.device
# Fall back only when a CUDA device was requested but none is present. An
# explicit non-CUDA choice (e.g. "cpu") is respected instead of being
# silently overridden.
if device.startswith("cuda") and not torch.cuda.is_available():
    print("No GPU detected. Using CPU.")
    device = "cpu"
print(f"Using {device} for training.")
class AudioDataset(Dataset):
    """Map-style dataset over pre-loaded ``(feature, label)`` pairs.

    ``file_paths`` determines the dataset length and is kept for bookkeeping;
    items are served by index from ``features`` and ``labels``.
    """

    def __init__(self, file_paths, labels, features):
        self.file_paths, self.labels, self.features = file_paths, labels, features

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]
        return feature, label
def prepare_dataset(directory, max_workers=10):
    """Load every ``*.npy`` feature file under *directory* into tensors.

    A file's class is the name of its parent directory, which must be a key
    of ``label2id`` (e.g. ``<directory>/aegi/clip.npy``). Labels and features
    are created directly on the module-level ``device``.

    Args:
        directory: Root directory searched recursively for ``*.npy`` files.
        max_workers: Number of threads used to load files concurrently.

    Returns:
        A ``(file_paths, labels, features)`` triple of equal-length lists,
        or three empty lists when no ``*.npy`` file is found. Entries are
        ordered by file path, so dataset order is reproducible across runs.

    Raises:
        KeyError: if a file's parent directory name is not a known label.
    """
    # rglob's order depends on the filesystem; sort so that the seeded
    # shuffle in the DataLoader yields the same batches on every machine.
    file_paths = sorted(Path(directory).rglob("*.npy"))
    if not file_paths:
        return [], [], []

    def load_one(file_path: Path):
        """Return ``(path, label_tensor, feature_tensor)`` for one file."""
        label_name = file_path.parent.name
        try:
            label_id = label2id[label_name]
        except KeyError:
            raise KeyError(
                f"Unknown label directory {label_name!r} for {file_path}; "
                f"expected one of {sorted(label2id)}"
            ) from None
        features = np.load(file_path)
        return (
            file_path,
            torch.tensor(label_id, dtype=torch.long, device=device),
            torch.tensor(features, dtype=torch.float32, device=device),
        )

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(load_one, file_paths), total=len(file_paths)))
    loaded_paths, labels, features = zip(*results)
    # Return lists so the non-empty and empty paths have the same types.
    return list(loaded_paths), list(labels), list(features)
print("Preparing dataset...")
exp_dir = Path(args.exp_dir)
train_file_paths, train_labels, train_feats = prepare_dataset(exp_dir / "train")
val_file_paths, val_labels, val_feats = prepare_dataset(exp_dir / "val")
print(f"Train: {len(train_file_paths)}, Val: {len(val_file_paths)}")
# Fail fast on a missing or empty training split. Otherwise the run would only
# crash after the first epoch, when the loss is averaged over
# len(train_loader) == 0.
if not train_file_paths:
    raise SystemExit(f"No training features (*.npy) found under {exp_dir / 'train'}")

# Batch size shared by the train and validation loaders.
batch_size = 128

train_dataset = AudioDataset(train_file_paths, train_labels, train_feats)
print("Train dataset prepared.")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print("Train loader prepared.")
# Validation is optional; the training loop checks `val_loader is None`.
if not val_file_paths:
    val_dataset = None
    val_loader = None
    print("No validation dataset found.")
else:
    val_dataset = AudioDataset(val_file_paths, val_labels, val_feats)
    print("Val dataset prepared.")
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    print("Val loader prepared.")
# Hyperparameters; this dict is also written to <run dir>/config.json.
config = {
    "model": {
        "label2id": label2id,
        "num_hidden_layers": 2,
        "hidden_dim": 128,
    },
    "lr": 1e-3,
    "lr_decay": 0.996,
}
# Pass the resolved runtime device instead of a hard-coded "cuda", so the
# CPU fallback chosen at startup also reaches the model.
# NOTE(review): how AudioClassifier uses `device` is not visible here — confirm.
model = AudioClassifier(device=device, **config["model"]).to(device)
criterion = ASLSingleLabel(gamma_pos=1, gamma_neg=4)
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-2)
# The training loop calls scheduler.step() once per epoch, so the LR is
# multiplied by lr_decay after every epoch.
scheduler = ExponentialLR(optimizer, gamma=config["lr_decay"])
num_epochs = args.epochs
print("Start training...")
# Every run writes into its own timestamped subdirectory of --ckpt_dir.
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
ckpt_dir = Path(args.ckpt_dir).joinpath(current_time)
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Store the hyperparameters next to the checkpoints.
(ckpt_dir / "config.json").write_text(json.dumps(config, indent=4), encoding="utf-8")
save_every = args.save_every
# Epoch intervals for validation and for training-metric collection.
val_interval = eval_interval = 1
writer = SummaryWriter(ckpt_dir.joinpath("logs"))
# Pin the report to every known class, in id order.
_REPORT_LABELS = sorted(id2label)
_REPORT_TARGET_NAMES = [id2label[i] for i in _REPORT_LABELS]


def _flatten_logits(outputs):
    """Reshape model outputs to ``(batch, -1)``, never touching the batch dim.

    Unlike ``squeeze()``, this keeps the batch dimension when a batch holds a
    single sample (e.g. a final partial batch). The class-dimension indexing
    of ``torch.max(..., 1)`` relies on that dimension being present.
    NOTE(review): assumes the model returns (batch, num_classes), possibly
    with extra singleton dims — confirm against AudioClassifier.forward.
    """
    return outputs.reshape(outputs.size(0), -1)


def _split_metrics(y_true, y_pred):
    """Return ``(accuracy, precision, recall, f1, report)`` for one split.

    Precision, recall and F1 are macro-averaged. ``zero_division=0`` gives the
    same values as sklearn's default but suppresses the per-epoch
    UndefinedMetricWarning. The report gets an explicit ``labels=``; without
    it, sklearn infers the classes from the data and raises ValueError when a
    class is absent from both ``y_true`` and ``y_pred``, because the inferred
    count no longer matches ``target_names``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="macro", zero_division=0)
    recall = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    report = classification_report(
        y_true,
        y_pred,
        labels=_REPORT_LABELS,
        target_names=_REPORT_TARGET_NAMES,
        zero_division=0,
    )
    return accuracy, precision, recall, f1, report


def _log_split_scalars(split, accuracy, precision, recall, f1, epoch):
    """Write the four classification scalars under ``<split>/...`` tags."""
    for tag, value in (
        ("Accuracy", accuracy),
        ("Precision", precision),
        ("Recall", recall),
        ("F1", f1),
    ):
        writer.add_scalar(f"{split}/{tag}", value, epoch)


for epoch in tqdm(range(1, num_epochs + 1)):
    model.train()
    train_loss = 0.0
    train_labels = []
    train_preds = []
    collect_train_metrics = epoch % eval_interval == 0
    # LR actually used this epoch; read it before scheduler.step() decays it.
    epoch_lr = optimizer.param_groups[0]["lr"]
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = _flatten_logits(model(inputs))
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if collect_train_metrics:
            with torch.no_grad():
                _, predictions = torch.max(logits, 1)
            train_labels.extend(labels.cpu().numpy())
            train_preds.extend(predictions.cpu().numpy())
    scheduler.step()
    avg_train_loss = train_loss / len(train_loader)

    train_report = None
    if collect_train_metrics:
        accuracy, precision, recall, f1, train_report = _split_metrics(
            train_labels, train_preds
        )
        _log_split_scalars("train", accuracy, precision, recall, f1, epoch)
    writer.add_scalar("Loss/train", avg_train_loss, epoch)
    writer.add_scalar("Learning Rate", epoch_lr, epoch)
    if epoch % save_every == 0:
        torch.save(model.state_dict(), ckpt_dir / f"model_{epoch}.pth")

    if epoch % val_interval != 0 or val_loader is None:
        # Train-only summary. A report exists only on metric-collecting epochs.
        summary = f"loss: {avg_train_loss:.4f}"
        if train_report is not None:
            summary += f"\n{train_report}"
        tqdm.write(summary)
        continue

    model.eval()
    val_labels = []
    val_preds = []
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = _flatten_logits(model(inputs))
            _, predictions = torch.max(logits, 1)
            val_labels.extend(labels.cpu().numpy())
            val_preds.extend(predictions.cpu().numpy())
            val_loss += criterion(logits, labels).item()
    avg_val_loss = val_loss / len(val_loader)
    accuracy, precision, recall, f1, report = _split_metrics(val_labels, val_preds)
    writer.add_scalar("Loss/val", avg_val_loss, epoch)
    _log_split_scalars("val", accuracy, precision, recall, f1, epoch)
    tqdm.write(
        f"loss: {avg_train_loss:.4f}, val loss: {avg_val_loss:.4f}, "
        f"acc: {accuracy:.4f}, f1: {f1:.4f}, prec: {precision:.4f}, recall: {recall:.4f}\n{report}"
    )
torch.save(model.state_dict(), ckpt_dir / "model_final.pth")
# Flush any pending TensorBoard events and release the event-file handle.
# Without an explicit close, the last epochs' scalars may never be written.
writer.close()