# music2emo / train.py — music emotion recognition training script
# (Hugging Face repo header: English / music / emotion; uploaded by kjysmu,
#  "Upload 5 files", commit 2dfd92b verified)
import os
import logging
import torch
torch.set_float32_matmul_precision("medium")
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from data_loader import DataModule
from trainer import MusicClassifier
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from hydra.core.hydra_config import HydraConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from pytorch_lightning.strategies import DDPStrategy
#from utilities.custom_early_stopping import MultiMetricEarlyStopping
def get_latest_version(log_dir):
    """Return the name of the highest-numbered ``version_N`` subdirectory.

    Args:
        log_dir: Path to a TensorBoard-style log directory containing
            subdirectories named ``version_0``, ``version_1``, ...

    Returns:
        The directory name (not full path) of the latest version, e.g.
        ``'version_10'``, or ``None`` if ``log_dir`` does not exist or
        contains no version subdirectories.
    """
    # BUG FIX: the original raised FileNotFoundError on a fresh run where
    # the log directory has not been created yet.
    if not os.path.isdir(log_dir):
        return None
    version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
    if not version_dirs:
        return None
    # Compare by numeric suffix so 'version_10' beats 'version_9'
    # (lexicographic order would get this wrong).
    return max(version_dirs, key=lambda d: int(d.split('_')[-1]))
log = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="config", config_name="train_config")
def main(config: DictConfig):
    """Train a music emotion classifier with PyTorch Lightning.

    Builds one dataloader per dataset in ``config.datasets``, merges them with
    ``CombinedLoader``, and fits a ``MusicClassifier``. Multi-task models
    (classifier name containing "mt") get separate mood and valence/arousal
    checkpoint callbacks. On rank 0, the best checkpoint path(s) are written
    to ``best_checkpoint.txt`` inside this run's TensorBoard log directory.
    """
    log_base_dir = 'tb_logs/train_audio_classification'
    # log_base_dir = to_absolute_path('tb_logs/train_audio_classification')

    # Multi-task ("mt") classifiers use two checkpoint callbacks (mood + VA).
    is_mt = "mt" in config.model.classifier

    logger = TensorBoardLogger("tb_logs", name="train_audio_classification")
    logger.log_hyperparams(config)
    train_log_dir = logger.log_dir
    print(f"Logging to {train_log_dir}")
    log.info("Training starts")

    data_module = DataModule(config)
    data_module.setup()

    # One loader per dataset, keyed by dataset name so batches arrive as a
    # dict. NOTE(review): with mode="max_size" exhausted loaders yield None
    # for the remaining steps — confirm the model/trainer handles that.
    train_loaders = dict(zip(config.datasets, data_module.train_dataloader()))
    val_loaders = dict(zip(config.datasets, data_module.val_dataloader()))
    combined_train_loader = CombinedLoader(train_loaders, mode="max_size")
    combined_val_loader = CombinedLoader(val_loaders, mode="max_size")

    # Locate the tb_logs version directory for the per-epoch validation file.
    # BUG FIX: on a fresh run there is no previous version directory, so the
    # original joined None into the path (TypeError); fall back to the
    # upcoming version's directory name instead.
    latest_version = get_latest_version(log_base_dir)
    next_num = int(latest_version.split('_')[-1]) + 1 if latest_version else 0
    version_dir = latest_version if latest_version is not None else f"version_{next_num}"
    val_epoch_file = os.path.join(log_base_dir, version_dir, 'val_epoch.txt')

    model = MusicClassifier(config, output_file=val_epoch_file)

    # BUG FIX: early stopping was referenced on the non-multitask path while
    # its construction there was commented out (NameError). Build it once for
    # both branches.
    early_stop_callback = EarlyStopping(**config.earlystopping)

    if is_mt:
        checkpoint_callback_mood = ModelCheckpoint(**config.checkpoint_mood)
        checkpoint_callback_va = ModelCheckpoint(**config.checkpoint_va)
        callbacks = [checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback]
        # Knowledge distillation (kd) leaves some submodules without
        # gradients, so DDP must be told to expect unused parameters.
        trainer = pl.Trainer(
            **config.trainer,
            strategy=DDPStrategy(find_unused_parameters=bool(config.model.kd)),
            callbacks=callbacks,
            logger=logger,
            num_sanity_val_steps=0,
        )
    else:
        checkpoint_callback = ModelCheckpoint(**config.checkpoint)
        trainer = pl.Trainer(
            **config.trainer,
            callbacks=[checkpoint_callback, early_stop_callback],
            logger=logger,
            num_sanity_val_steps=0,
        )

    trainer.fit(model, combined_train_loader, combined_val_loader)

    # Record the best checkpoint path(s) once, on the main process only.
    if trainer.global_rank == 0:
        best_checkpoint_file = os.path.join(train_log_dir, 'best_checkpoint.txt')
        with open(best_checkpoint_file, 'w') as f:
            if is_mt:
                f.write(f"Best checkpoint (mood): {checkpoint_callback_mood.best_model_path}\n")
                f.write(f"Best checkpoint (va): {checkpoint_callback_va.best_model_path}\n")
            else:
                f.write(f"Best checkpoint: {checkpoint_callback.best_model_path}\n")
            f.write(f"Version: {logger.version}\n")
# Entry point: Hydra composes the config (plus CLI overrides) and calls main.
if __name__ == '__main__':
    main()