|
import os |
|
import logging |
|
|
|
import torch |
|
# Allow float32 matrix multiplications to trade some internal precision for
# speed; see the PyTorch docs for what "medium" permits on the current hardware.
torch.set_float32_matmul_precision(precision="medium")
|
|
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
|
from data_loader import DataModule |
|
from trainer import MusicClassifier |
|
from omegaconf import DictConfig |
|
import hydra |
|
from hydra.utils import to_absolute_path |
|
from hydra.core.hydra_config import HydraConfig |
|
from pytorch_lightning.callbacks import EarlyStopping |
|
from pytorch_lightning.utilities.combined_loader import CombinedLoader |
|
from pytorch_lightning.strategies import DDPStrategy |
|
|
|
|
|
def get_latest_version(log_dir):
    """Return the newest ``version_<N>`` run directory inside *log_dir*.

    Only subdirectories named ``version_`` followed by a decimal integer are
    considered. Other entries are ignored rather than crashing the integer
    parse: plain files, and names such as ``version_x`` or ``version_3_old``.

    Args:
        log_dir: Directory whose ``version_<N>`` subdirectories are scanned.

    Returns:
        The entry name with the largest ``N`` (e.g. ``"version_12"``), or
        ``None`` when *log_dir* does not exist or contains no such directory.
    """
    prefix = "version_"
    try:
        entries = os.listdir(log_dir)
    except FileNotFoundError:
        # Nothing has been logged here yet.
        return None

    numbered = {}
    for entry in entries:
        suffix = entry[len(prefix):]
        # isdecimal() rejects "" and non-digit suffixes, so int() below cannot raise.
        if (
            entry.startswith(prefix)
            and suffix.isdecimal()
            and os.path.isdir(os.path.join(log_dir, entry))
        ):
            numbered[entry] = int(suffix)

    # Compare numerically so that version_10 ranks above version_9.
    return max(numbered, key=numbered.get, default=None)
|
|
|
# Module-level logger; main() reports progress through it.
log = logging.getLogger(name=__name__)
|
def _build_checkpoint_callbacks(config: DictConfig, is_mt: bool) -> dict:
    """Create this run's ModelCheckpoint callback(s), keyed by report label."""
    if is_mt:
        # Multi-task runs keep one checkpoint callback per config section.
        return {
            "mood": ModelCheckpoint(**config.checkpoint_mood),
            "va": ModelCheckpoint(**config.checkpoint_va),
        }
    # Single-task runs have one unlabeled checkpoint callback.
    return {None: ModelCheckpoint(**config.checkpoint)}


def _write_best_checkpoints(path: str, checkpoint_callbacks: dict, version) -> None:
    """Write each callback's best model path, then the logger version, to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        for label, callback in checkpoint_callbacks.items():
            prefix = f"Best checkpoint ({label})" if label else "Best checkpoint"
            f.write(f"{prefix}: {callback.best_model_path}\n")
        f.write(f"Version: {version}\n")


@hydra.main(version_base=None, config_path="config", config_name="train_config")
def main(config: DictConfig):
    """Train a ``MusicClassifier`` using the Hydra config ``config/train_config``.

    Steps:
        1. Builds one train and one validation dataloader per entry in
           ``config.datasets``.
        2. Combines each group with ``CombinedLoader(mode="max_size")``.
        3. Fits the model with TensorBoard logging, checkpointing and early
           stopping. Sanity validation steps are disabled.
        4. On global rank 0, writes the best checkpoint path(s) and the logger
           version to ``best_checkpoint.txt`` in the logger's run directory.

    Run types:
        - Multi-task: ``config.model.classifier`` contains the substring
          ``"mt"``. Uses ``config.checkpoint_mood`` and
          ``config.checkpoint_va``, plus an explicit ``DDPStrategy`` whose
          ``find_unused_parameters`` follows ``config.model.kd``.
        - Otherwise: uses ``config.checkpoint``, and no strategy is passed
          explicitly.

    Both run types use ``config.earlystopping`` and ``config.trainer``.

    Args:
        config: Composed Hydra config, injected by ``@hydra.main``.
    """
    save_dir = "tb_logs"
    experiment_name = "train_audio_classification"
    # Parent directory of the logger's version_<N> run directories.
    log_base_dir = os.path.join(save_dir, experiment_name)

    # NOTE(review): substring match; any classifier name containing "mt" counts as multi-task.
    is_mt = "mt" in config.model.classifier

    logger = TensorBoardLogger(save_dir, name=experiment_name)
    logger.log_hyperparams(config)
    train_log_dir = logger.log_dir
    print(f"Logging to {train_log_dir}")
    log.info("Training starts")

    data_module = DataModule(config)
    data_module.setup()

    # Key each loader by its dataset entry. Assumes the DataModule returns loaders
    # in the same order as config.datasets (TODO confirm); zip() stops at the shorter side.
    train_loaders = dict(zip(config.datasets, data_module.train_dataloader()))
    val_loaders = dict(zip(config.datasets, data_module.val_dataloader()))
    combined_train_loader = CombinedLoader(train_loaders, mode="max_size")
    combined_val_loader = CombinedLoader(val_loaders, mode="max_size")

    # Re-scan for the newest version directory instead of using logger.log_dir.
    # NOTE(review): presumably so every DDP rank resolves the directory that rank 0
    # created in log_hyperparams(); confirm against the Lightning version in use.
    # If no version directory exists, use this logger's own run directory.
    latest_version = get_latest_version(log_base_dir)
    val_epoch_dir = (
        os.path.join(log_base_dir, latest_version) if latest_version else train_log_dir
    )
    val_epoch_file = os.path.join(val_epoch_dir, "val_epoch.txt")
    model = MusicClassifier(config, output_file=val_epoch_file)

    checkpoint_callbacks = _build_checkpoint_callbacks(config, is_mt)
    # Shared by both run types.
    early_stop_callback = EarlyStopping(**config.earlystopping)

    extra_trainer_kwargs = {}
    if is_mt:
        # config.model.kd toggles DDP's unused-parameter detection.
        # NOTE(review): presumably some parameters receive no gradient when kd is on; confirm.
        extra_trainer_kwargs["strategy"] = DDPStrategy(
            find_unused_parameters=bool(config.model.kd)
        )

    trainer = pl.Trainer(
        **config.trainer,
        **extra_trainer_kwargs,
        callbacks=[*checkpoint_callbacks.values(), early_stop_callback],
        logger=logger,
        num_sanity_val_steps=0,
    )

    trainer.fit(model, combined_train_loader, combined_val_loader)

    # Only one process writes the summary file.
    if trainer.global_rank == 0:
        _write_best_checkpoints(
            os.path.join(train_log_dir, "best_checkpoint.txt"),
            checkpoint_callbacks,
            logger.version,
        )
|
|
|
if __name__ == "__main__":
    # The @hydra.main decorator on main() supplies the config, so no arguments are passed here.
    main()
|
|