Upload 5 files
Browse files- .gitattributes +1 -0
- data_loader.py +78 -0
- framework.png +3 -0
- test.py +98 -0
- train.py +107 -0
- trainer.py +478 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
inference/input/test.mp3 filter=lfs diff=lfs merge=lfs -text
|
37 |
m2e.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
inference/input/test.mp3 filter=lfs diff=lfs merge=lfs -text
|
37 |
m2e.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
framework.png filter=lfs diff=lfs merge=lfs -text
|
data_loader.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import pickle
|
4 |
+
from torch.utils import data
|
5 |
+
import torchaudio.transforms as T
|
6 |
+
import torchaudio
|
7 |
+
import torch
|
8 |
+
import csv
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
from music2latent import EncoderDecoder
|
11 |
+
import json
|
12 |
+
import math
|
13 |
+
from sklearn.preprocessing import StandardScaler
|
14 |
+
|
15 |
+
from dataset_loaders.jamendo import JamendoDataset
|
16 |
+
from dataset_loaders.pmemo import PMEmoDataset
|
17 |
+
from dataset_loaders.deam import DEAMDataset
|
18 |
+
from dataset_loaders.emomusic import EmoMusicDataset
|
19 |
+
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
|
22 |
+
DATASET_REGISTRY = {
|
23 |
+
"jamendo": JamendoDataset,
|
24 |
+
"pmemo": PMEmoDataset,
|
25 |
+
"deam": DEAMDataset,
|
26 |
+
"emomusic": EmoMusicDataset
|
27 |
+
}
|
28 |
+
|
29 |
+
class DataModule(pl.LightningDataModule):
|
30 |
+
def __init__(self, cfg: DictConfig):
|
31 |
+
super().__init__()
|
32 |
+
self.cfg = cfg
|
33 |
+
|
34 |
+
self.train_datasets = []
|
35 |
+
self.val_datasets = []
|
36 |
+
self.test_datasets = []
|
37 |
+
|
38 |
+
def setup(self, stage=None):
|
39 |
+
# Clear previous dataset lists
|
40 |
+
self.train_datasets = []
|
41 |
+
self.val_datasets = []
|
42 |
+
self.test_datasets = []
|
43 |
+
|
44 |
+
# Register the datasets and load them
|
45 |
+
for dataset_name in self.cfg.datasets:
|
46 |
+
dataset_cfg = self.cfg.dataset[dataset_name]
|
47 |
+
|
48 |
+
if dataset_name in DATASET_REGISTRY:
|
49 |
+
train_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='train')
|
50 |
+
val_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='validation')
|
51 |
+
test_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='test')
|
52 |
+
|
53 |
+
self.train_datasets.append(train_dataset)
|
54 |
+
self.val_datasets.append(val_dataset)
|
55 |
+
self.test_datasets.append(test_dataset)
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Dataset {dataset_name} not found in registry")
|
58 |
+
|
59 |
+
def train_dataloader(self):
|
60 |
+
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
|
61 |
+
shuffle=True, num_workers=self.cfg.dataset[ds_name].num_workers,
|
62 |
+
persistent_workers=True)
|
63 |
+
for ds, ds_name in zip(self.train_datasets, self.cfg.datasets)]
|
64 |
+
|
65 |
+
def val_dataloader(self):
|
66 |
+
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
|
67 |
+
shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
|
68 |
+
persistent_workers=True)
|
69 |
+
for ds, ds_name in zip(self.val_datasets, self.cfg.datasets)]
|
70 |
+
|
71 |
+
def test_dataloader(self):
|
72 |
+
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
|
73 |
+
shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
|
74 |
+
persistent_workers=True)
|
75 |
+
for ds, ds_name in zip(self.test_datasets, self.cfg.datasets)]
|
76 |
+
|
77 |
+
|
78 |
+
|
framework.png
ADDED
![]() |
Git LFS Details
|
test.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
torch.set_float32_matmul_precision("medium")
|
6 |
+
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
9 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
10 |
+
from data_loader import DataModule
|
11 |
+
from trainer import MusicClassifier
|
12 |
+
import yaml
|
13 |
+
from omegaconf import DictConfig
|
14 |
+
import hydra
|
15 |
+
from hydra.utils import to_absolute_path
|
16 |
+
from hydra.core.hydra_config import HydraConfig
|
17 |
+
from pytorch_lightning.utilities.combined_loader import CombinedLoader
|
18 |
+
|
19 |
+
|
20 |
+
log = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
def get_latest_version(log_dir):
|
23 |
+
version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
|
24 |
+
version_dirs.sort(key=lambda x: int(x.split('_')[-1])) # Sort by version number
|
25 |
+
return version_dirs[-1] if version_dirs else None
|
26 |
+
|
27 |
+
def save_metrics_and_checkpoint(metrics, checkpoint, output_file):
|
28 |
+
data = {
|
29 |
+
'checkpoint': checkpoint,
|
30 |
+
'metrics': metrics
|
31 |
+
}
|
32 |
+
with open(output_file, 'w') as f:
|
33 |
+
yaml.dump(data, f)
|
34 |
+
|
35 |
+
def read_best_checkpoint_info(file_path, dataset_type=None):
|
36 |
+
"""Read the best checkpoint file."""
|
37 |
+
if not os.path.exists(file_path):
|
38 |
+
raise FileNotFoundError(f"Checkpoint info file not found: {file_path}")
|
39 |
+
|
40 |
+
with open(file_path, 'r') as f:
|
41 |
+
lines = f.readlines()
|
42 |
+
|
43 |
+
if dataset_type == "mood":
|
44 |
+
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (mood):")), None)
|
45 |
+
elif dataset_type == "va":
|
46 |
+
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (va):")), None)
|
47 |
+
else:
|
48 |
+
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint:")), None)
|
49 |
+
|
50 |
+
if not checkpoint_line:
|
51 |
+
raise ValueError(f"No checkpoint found for dataset type '{dataset_type}' in the file.")
|
52 |
+
|
53 |
+
return checkpoint_line.split(": ")[-1].strip()
|
54 |
+
|
55 |
+
@hydra.main(version_base=None, config_path="config", config_name="test_config")
|
56 |
+
def main(config: DictConfig):
|
57 |
+
log.info("Testing starts")
|
58 |
+
log_base_dir = 'tb_logs/train_audio_classification'
|
59 |
+
# log_base_dir = to_absolute_path('tb_logs/train_audio_classification')
|
60 |
+
|
61 |
+
latest_version = get_latest_version(log_base_dir)
|
62 |
+
if not latest_version:
|
63 |
+
raise FileNotFoundError("No version directories found in log base directory.")
|
64 |
+
version_log_dir = os.path.join(log_base_dir, latest_version)
|
65 |
+
output_file = os.path.join(version_log_dir, 'test_metrics.txt')
|
66 |
+
|
67 |
+
if config.checkpoint_latest:
|
68 |
+
if config.multitask:
|
69 |
+
dataset_type = config.dataset_type # Expecting 'mood' or 'va'
|
70 |
+
best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
|
71 |
+
ckpt = read_best_checkpoint_info(best_checkpoint_file, dataset_type)
|
72 |
+
else:
|
73 |
+
best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
|
74 |
+
ckpt = read_best_checkpoint_info(best_checkpoint_file)
|
75 |
+
else:
|
76 |
+
ckpt = config.checkpoint
|
77 |
+
if not os.path.exists(ckpt):
|
78 |
+
raise FileNotFoundError(f"Checkpoint file not found: {ckpt}")
|
79 |
+
|
80 |
+
log.info(f"Using checkpoint: {ckpt}")
|
81 |
+
data_module = DataModule( config )
|
82 |
+
data_module.setup()
|
83 |
+
|
84 |
+
testloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.test_dataloader())}
|
85 |
+
combined_test_loader = CombinedLoader(testloaders, mode="max_size")
|
86 |
+
|
87 |
+
model = MusicClassifier.load_from_checkpoint(ckpt, cfg=config, output_file=output_file)
|
88 |
+
logger = TensorBoardLogger(save_dir=log_base_dir,
|
89 |
+
name="",
|
90 |
+
version=latest_version)
|
91 |
+
trainer = pl.Trainer(**config.trainer,
|
92 |
+
logger=logger)
|
93 |
+
|
94 |
+
|
95 |
+
trainer.test(model, combined_test_loader)
|
96 |
+
|
97 |
+
if __name__ == '__main__':
|
98 |
+
main()
|
train.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
torch.set_float32_matmul_precision("medium")
|
6 |
+
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
10 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
11 |
+
from data_loader import DataModule
|
12 |
+
from trainer import MusicClassifier
|
13 |
+
from omegaconf import DictConfig
|
14 |
+
import hydra
|
15 |
+
from hydra.utils import to_absolute_path
|
16 |
+
from hydra.core.hydra_config import HydraConfig
|
17 |
+
from pytorch_lightning.callbacks import EarlyStopping
|
18 |
+
from pytorch_lightning.utilities.combined_loader import CombinedLoader
|
19 |
+
from pytorch_lightning.strategies import DDPStrategy
|
20 |
+
|
21 |
+
#from utilities.custom_early_stopping import MultiMetricEarlyStopping
|
22 |
+
def get_latest_version(log_dir):
|
23 |
+
version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
|
24 |
+
version_dirs.sort(key=lambda x: int(x.split('_')[-1])) # Sort by version number
|
25 |
+
return version_dirs[-1] if version_dirs else None
|
26 |
+
|
27 |
+
log = logging.getLogger(__name__)
|
28 |
+
@hydra.main(version_base=None, config_path="config", config_name="train_config")
|
29 |
+
def main(config: DictConfig):
|
30 |
+
|
31 |
+
log_base_dir = 'tb_logs/train_audio_classification'
|
32 |
+
# log_base_dir = to_absolute_path('tb_logs/train_audio_classification')
|
33 |
+
is_mt = False
|
34 |
+
if "mt" in config.model.classifier:
|
35 |
+
is_mt = True
|
36 |
+
|
37 |
+
logger = TensorBoardLogger("tb_logs", name="train_audio_classification")
|
38 |
+
logger.log_hyperparams(config)
|
39 |
+
train_log_dir = logger.log_dir
|
40 |
+
print(f"Logging to {train_log_dir}")
|
41 |
+
log.info("Training starts")
|
42 |
+
|
43 |
+
data_module = DataModule( config )
|
44 |
+
data_module.setup()
|
45 |
+
|
46 |
+
# Get the list of dataloaders for both train and validation, with dataset names
|
47 |
+
trainloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.train_dataloader())}
|
48 |
+
vallowers = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.val_dataloader())}
|
49 |
+
|
50 |
+
# Combine multiple loaders using CombinedLoader, now with dataset names
|
51 |
+
combined_train_loader = CombinedLoader(trainloaders, mode="max_size")
|
52 |
+
combined_val_loader = CombinedLoader(vallowers, mode="max_size")
|
53 |
+
|
54 |
+
latest_version = get_latest_version(log_base_dir)
|
55 |
+
next_version = int(latest_version.split('_')[-1]) + 1 if latest_version else 0
|
56 |
+
next_version = f"version_{next_version}"
|
57 |
+
|
58 |
+
val_epoch_file = os.path.join(log_base_dir, latest_version, 'val_epoch.txt')
|
59 |
+
|
60 |
+
model = MusicClassifier( config, output_file = val_epoch_file)
|
61 |
+
|
62 |
+
if is_mt:
|
63 |
+
checkpoint_callback_mood = ModelCheckpoint(**config.checkpoint_mood)
|
64 |
+
checkpoint_callback_va = ModelCheckpoint(**config.checkpoint_va)
|
65 |
+
early_stop_callback = EarlyStopping(**config.earlystopping)
|
66 |
+
|
67 |
+
if config.model.kd == True:
|
68 |
+
trainer = pl.Trainer(
|
69 |
+
**config.trainer,
|
70 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
71 |
+
callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
|
72 |
+
logger=logger,
|
73 |
+
num_sanity_val_steps=0
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
trainer = pl.Trainer(
|
77 |
+
**config.trainer,
|
78 |
+
strategy=DDPStrategy(find_unused_parameters=False),
|
79 |
+
callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
|
80 |
+
logger=logger,
|
81 |
+
num_sanity_val_steps=0
|
82 |
+
)
|
83 |
+
|
84 |
+
else:
|
85 |
+
checkpoint_callback = ModelCheckpoint(**config.checkpoint)
|
86 |
+
# early_stop_callback = EarlyStopping(**config.earlystopping)
|
87 |
+
trainer = pl.Trainer(
|
88 |
+
**config.trainer,
|
89 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
90 |
+
logger=logger,
|
91 |
+
num_sanity_val_steps = 0
|
92 |
+
)
|
93 |
+
|
94 |
+
trainer.fit(model, combined_train_loader, combined_val_loader)
|
95 |
+
|
96 |
+
if trainer.global_rank == 0:
|
97 |
+
best_checkpoint_file = os.path.join(train_log_dir, 'best_checkpoint.txt')
|
98 |
+
with open(best_checkpoint_file, 'w') as f:
|
99 |
+
if is_mt:
|
100 |
+
f.write(f"Best checkpoint (mood): {checkpoint_callback_mood.best_model_path}\n")
|
101 |
+
f.write(f"Best checkpoint (va): {checkpoint_callback_va.best_model_path}\n")
|
102 |
+
else:
|
103 |
+
f.write(f"Best checkpoint: {checkpoint_callback.best_model_path}\n")
|
104 |
+
f.write(f"Version: {logger.version}\n")
|
105 |
+
|
106 |
+
if __name__ == '__main__':
|
107 |
+
main()
|
trainer.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from sklearn import metrics
|
7 |
+
from transformers import AutoModelForAudioClassification
|
8 |
+
import numpy as np
|
9 |
+
from collections import OrderedDict
|
10 |
+
from torchmetrics import MeanMetric, MaxMetric, Accuracy
|
11 |
+
import torchmetrics.functional as tmf
|
12 |
+
|
13 |
+
from model.linear import FeedforwardModel
|
14 |
+
from model.linear_small import FeedforwardModelSmall
|
15 |
+
from model.linear_attn_ck import FeedforwardModelAttnCK
|
16 |
+
from model.linear_mt import FeedforwardModelMT
|
17 |
+
from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
|
18 |
+
|
19 |
+
import logging
|
20 |
+
import yaml
|
21 |
+
from omegaconf import DictConfig
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch.distributed import all_gather, get_world_size
|
25 |
+
# from lion_pytorch import Lion
|
26 |
+
from torch_optimizer import RAdam
|
27 |
+
|
28 |
+
def gather_all_results(tensor):
|
29 |
+
"""
|
30 |
+
Gather tensors from all GPUs in distributed training.
|
31 |
+
"""
|
32 |
+
gathered_tensors = [torch.zeros_like(tensor) for _ in range(get_world_size())]
|
33 |
+
all_gather(gathered_tensors, tensor)
|
34 |
+
return torch.cat(gathered_tensors, dim=0)
|
35 |
+
|
36 |
+
# torch.set_float32_matmul_precision('medium')
|
37 |
+
|
38 |
+
log = logging.getLogger(__name__)
|
39 |
+
class MusicClassifier(pl.LightningModule):
|
40 |
+
def __init__(self, cfg: DictConfig, output_file = None):
|
41 |
+
super(MusicClassifier, self).__init__()
|
42 |
+
self.cfg = cfg
|
43 |
+
self.encoder = cfg.model.encoder
|
44 |
+
self.classifier = cfg.model.classifier
|
45 |
+
self.lr = cfg.model.lr
|
46 |
+
self.output_file = output_file
|
47 |
+
self.kd = cfg.model.kd
|
48 |
+
self.kd_weight = cfg.model.kd_weight
|
49 |
+
self.kd_temperature = self.cfg.model.kd_temperature
|
50 |
+
|
51 |
+
layer_size = len(self.cfg.model.layers)
|
52 |
+
mert_dim = 768 * layer_size
|
53 |
+
|
54 |
+
self.feature_dim_dict = {
|
55 |
+
"MERT": mert_dim
|
56 |
+
}
|
57 |
+
|
58 |
+
encoders = self.encoder.split("-")
|
59 |
+
self.input_size = sum(self.feature_dim_dict[encoder] for encoder in encoders)
|
60 |
+
self.num_datasets = len(self.cfg.datasets)
|
61 |
+
|
62 |
+
if "mt" in self.classifier:
|
63 |
+
if self.num_datasets < 2:
|
64 |
+
raise Exception("Error: Dataset size >= 2 needed for MT classifier")
|
65 |
+
classifiers = {
|
66 |
+
"linear-mt-attn-ck": FeedforwardModelMTAttnCK,
|
67 |
+
}
|
68 |
+
if self.classifier in classifiers:
|
69 |
+
self.model = classifiers[self.classifier](
|
70 |
+
input_size=self.input_size,
|
71 |
+
output_size_classification=56,
|
72 |
+
output_size_regression=2
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
raise Exception(f"Unknown classifier: {self.classifier}")
|
76 |
+
else:
|
77 |
+
if self.num_datasets >= 2:
|
78 |
+
raise Exception(f"Error: Dataset size == 1 needed for classifier")
|
79 |
+
dataset_name = self.cfg.datasets[0]
|
80 |
+
self.output_size = self.cfg.dataset[dataset_name].output_size
|
81 |
+
classifiers = {
|
82 |
+
"linear": FeedforwardModel,
|
83 |
+
"linear-attn-ck": FeedforwardModelAttnCK
|
84 |
+
}
|
85 |
+
|
86 |
+
if self.classifier in classifiers:
|
87 |
+
self.model = classifiers[self.classifier](input_size=self.input_size, output_size=self.output_size)
|
88 |
+
else:
|
89 |
+
raise Exception(f"Unknown classifier: {self.classifier}")
|
90 |
+
|
91 |
+
|
92 |
+
if self.kd:
|
93 |
+
self.teacher_models = {}
|
94 |
+
|
95 |
+
for dataset in self.cfg.datasets:
|
96 |
+
self.output_size = self.cfg.dataset[dataset].output_size
|
97 |
+
teacher_model_path = getattr(self.cfg, f"checkpoint_{dataset}", None)
|
98 |
+
|
99 |
+
if teacher_model_path:
|
100 |
+
# Create a new teacher model instance
|
101 |
+
teacher_model = FeedforwardModelAttnCK(
|
102 |
+
input_size=self.input_size,
|
103 |
+
output_size=self.output_size,
|
104 |
+
)
|
105 |
+
|
106 |
+
# Load the checkpoint
|
107 |
+
checkpoint = torch.load(teacher_model_path, map_location=self.device, weights_only=False)
|
108 |
+
state_dict = checkpoint["state_dict"]
|
109 |
+
|
110 |
+
# Adjust the keys in the state_dict
|
111 |
+
state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
|
112 |
+
|
113 |
+
# Filter state_dict to match model's keys
|
114 |
+
model_keys = set(teacher_model.state_dict().keys())
|
115 |
+
filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
|
116 |
+
|
117 |
+
# Load the filtered state_dict and set the model to evaluation mode
|
118 |
+
teacher_model.load_state_dict(filtered_state_dict)
|
119 |
+
teacher_model.to(self.device)
|
120 |
+
|
121 |
+
teacher_model.eval()
|
122 |
+
|
123 |
+
# Store the teacher model in the dictionary with the dataset name as the key
|
124 |
+
self.teacher_models[dataset] = teacher_model
|
125 |
+
|
126 |
+
probas = torch.from_numpy(np.load("dataset/jamendo/meta/probas_train.npy"))
|
127 |
+
pos_weight = torch.tensor(1.) / probas
|
128 |
+
weight = torch.tensor(2.) / (torch.tensor(1.) + pos_weight)
|
129 |
+
|
130 |
+
self.loss_fn_classification = nn.BCEWithLogitsLoss(
|
131 |
+
pos_weight=pos_weight,reduction="mean",weight=weight
|
132 |
+
)
|
133 |
+
self.loss_fn_classification_eval = nn.BCEWithLogitsLoss(
|
134 |
+
pos_weight=pos_weight,reduction="none",weight=weight
|
135 |
+
)
|
136 |
+
|
137 |
+
self.loss_fn_regression = nn.MSELoss()
|
138 |
+
|
139 |
+
self.loss_kd = nn.KLDivLoss(reduction="batchmean")
|
140 |
+
|
141 |
+
self.prd_array = []
|
142 |
+
self.gt_array = []
|
143 |
+
self.song_array = []
|
144 |
+
|
145 |
+
self.prd_array_va = []
|
146 |
+
self.gt_array_va = []
|
147 |
+
self.song_array_va = []
|
148 |
+
|
149 |
+
self.validation_predictions = []
|
150 |
+
self.validation_targets = []
|
151 |
+
self.validation_results = {'preds': [], 'gt': []}
|
152 |
+
|
153 |
+
self.trn_loss = MeanMetric()
|
154 |
+
self.val_loss = MeanMetric()
|
155 |
+
|
156 |
+
def forward(self, model_input_dic, output_idx = 0):
|
157 |
+
if "mt" in self.classifier:
|
158 |
+
classification_output, regression_output = self.model(model_input_dic)
|
159 |
+
if output_idx == 0:
|
160 |
+
return classification_output
|
161 |
+
elif output_idx == 1:
|
162 |
+
return regression_output
|
163 |
+
elif output_idx == 2:
|
164 |
+
return classification_output, regression_output
|
165 |
+
else:
|
166 |
+
output = self.model(model_input_dic)
|
167 |
+
return output
|
168 |
+
|
169 |
+
def compute_classification_loss(self, model_input_dic, y_mood):
|
170 |
+
classification_logits = self(model_input_dic, 0)
|
171 |
+
loss= self.loss_fn_classification(classification_logits, y_mood)
|
172 |
+
return loss
|
173 |
+
|
174 |
+
def compute_regression_loss(self, model_input_dic, y_va):
|
175 |
+
regression_output = self(model_input_dic, 1)
|
176 |
+
loss = self.loss_fn_regression(regression_output, y_va)
|
177 |
+
return loss
|
178 |
+
|
179 |
+
def compute_mt_loss(self, model_input_dic, y_mood, y_va):
|
180 |
+
classification_logits, regression_output = self(model_input_dic, 2)
|
181 |
+
loss_classification = self.loss_fn_classification(classification_logits, y_mood)
|
182 |
+
loss_regression = self.loss_fn_regression(regression_output, y_va)
|
183 |
+
return loss_classification, loss_regression
|
184 |
+
|
185 |
+
|
186 |
+
def compute_kd_loss(self, model_input_dic, y_mood, y_va, dataset_name):
|
187 |
+
"""
|
188 |
+
Compute knowledge distillation loss for a given dataset.
|
189 |
+
"""
|
190 |
+
# Forward pass through student model
|
191 |
+
s_logits_mood, s_logits_va = self(model_input_dic, 2)
|
192 |
+
|
193 |
+
# Compute student losses
|
194 |
+
s_loss_mood = self.loss_fn_classification(s_logits_mood, y_mood)
|
195 |
+
s_loss_va = self.loss_fn_regression(s_logits_va, y_va)
|
196 |
+
|
197 |
+
# Get the corresponding teacher model for the dataset
|
198 |
+
teacher_model = self.teacher_models.get(dataset_name)
|
199 |
+
teacher_model.to(self.device)
|
200 |
+
|
201 |
+
# Ensure teacher model exists
|
202 |
+
if teacher_model is None:
|
203 |
+
raise ValueError(f"No teacher model found for dataset: {dataset_name}")
|
204 |
+
|
205 |
+
with torch.no_grad():
|
206 |
+
# Forward pass through teacher model
|
207 |
+
t_logits = teacher_model(model_input_dic)
|
208 |
+
|
209 |
+
# Compute knowledge distillation losses
|
210 |
+
t_probs = torch.softmax(t_logits / self.kd_temperature, dim=-1)
|
211 |
+
if dataset_name == "jamendo":
|
212 |
+
s_probs_mood = torch.log_softmax(s_logits_mood / self.kd_temperature, dim=-1)
|
213 |
+
kd_loss = self.loss_kd(s_probs_mood, t_probs)
|
214 |
+
else:
|
215 |
+
s_probs_va = torch.log_softmax(s_logits_va / self.kd_temperature, dim=-1)
|
216 |
+
kd_loss = self.loss_kd(s_probs_va, t_probs)
|
217 |
+
|
218 |
+
return kd_loss, s_loss_mood, s_loss_va
|
219 |
+
|
220 |
+
def handle_dataset(self, dataset_name, batch, losses, total_loss, stage):
|
221 |
+
dataset_batch = batch[dataset_name]
|
222 |
+
|
223 |
+
model_input_dic = {}
|
224 |
+
model_input_dic["x_mert"] = dataset_batch["x_mert"]
|
225 |
+
model_input_dic["x_chord"] = dataset_batch["x_chord"]
|
226 |
+
model_input_dic["x_chord_root"] = dataset_batch["x_chord_root"]
|
227 |
+
model_input_dic["x_chord_attr"] = dataset_batch["x_chord_attr"]
|
228 |
+
model_input_dic["x_key"] = dataset_batch["x_key"]
|
229 |
+
|
230 |
+
if "mt" in self.classifier:
|
231 |
+
if dataset_name == "jamendo":
|
232 |
+
y_mood = dataset_batch["y_mood"]
|
233 |
+
y_va = dataset_batch["y_va"]
|
234 |
+
if self.kd:
|
235 |
+
kd_loss, s_loss_mood, s_loss_va = self.compute_kd_loss(model_input_dic, y_mood, y_va, dataset_name)
|
236 |
+
if stage == "train":
|
237 |
+
losses['loss_mood'] = s_loss_mood
|
238 |
+
|
239 |
+
total_loss += self.kd_weight * kd_loss + (1 - self.kd_weight) * s_loss_mood
|
240 |
+
else:
|
241 |
+
losses['loss_mood'] = s_loss_mood
|
242 |
+
total_loss += s_loss_mood
|
243 |
+
else:
|
244 |
+
s_loss_mood, s_loss_va = self.compute_mt_loss(model_input_dic, y_mood, y_va)
|
245 |
+
if stage == "train":
|
246 |
+
losses['loss_mood'] = s_loss_mood
|
247 |
+
total_loss += s_loss_mood
|
248 |
+
else:
|
249 |
+
losses['loss_mood'] = s_loss_mood
|
250 |
+
total_loss += s_loss_mood
|
251 |
+
else:
|
252 |
+
y_mood = dataset_batch["y_mood"]
|
253 |
+
y_va = dataset_batch["y_va"]
|
254 |
+
|
255 |
+
if self.kd:
|
256 |
+
kd_loss, s_loss_mood, s_loss_va = self.compute_kd_loss(model_input_dic, y_mood, y_va, dataset_name)
|
257 |
+
if stage == "train":
|
258 |
+
losses['loss_va'] = s_loss_va
|
259 |
+
total_loss += self.kd_weight * kd_loss + (1 - self.kd_weight) * s_loss_va
|
260 |
+
else:
|
261 |
+
losses['loss_va'] = s_loss_va
|
262 |
+
total_loss += s_loss_va
|
263 |
+
else:
|
264 |
+
s_loss_mood, s_loss_va = self.compute_mt_loss(model_input_dic, y_mood, y_va)
|
265 |
+
if stage == "train":
|
266 |
+
losses['loss_va'] = s_loss_va
|
267 |
+
total_loss += s_loss_va
|
268 |
+
else:
|
269 |
+
losses['loss_va'] = s_loss_va
|
270 |
+
total_loss += s_loss_va
|
271 |
+
else:
|
272 |
+
if dataset_name == "jamendo":
|
273 |
+
y_mood = dataset_batch["y_mood"]
|
274 |
+
loss_classification = self.compute_classification_loss(model_input_dic, y_mood)
|
275 |
+
losses['loss_mood'] = loss_classification
|
276 |
+
total_loss += loss_classification
|
277 |
+
else:
|
278 |
+
y_va = dataset_batch["y_va"]
|
279 |
+
loss_regression = self.compute_regression_loss(model_input_dic, y_va)
|
280 |
+
losses['loss_va'] = loss_regression
|
281 |
+
total_loss += loss_regression
|
282 |
+
|
283 |
+
return total_loss
|
284 |
+
|
285 |
+
def training_step(self, batch, batch_idx):
|
286 |
+
total_loss = 0
|
287 |
+
losses = {}
|
288 |
+
datasets = ["jamendo", "deam", "emomusic", "pmemo"]
|
289 |
+
|
290 |
+
for dataset in datasets:
|
291 |
+
if dataset in batch and batch[dataset] is not None:
|
292 |
+
total_loss = self.handle_dataset(dataset, batch, losses, total_loss, "train")
|
293 |
+
|
294 |
+
batch_size = batch[next(iter(batch))]["x_mert"].size(0)
|
295 |
+
|
296 |
+
self.log('train_loss_mood', losses.get('loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
297 |
+
self.log('train_loss_va', losses.get('loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
298 |
+
self.log('train_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
299 |
+
|
300 |
+
return total_loss
|
301 |
+
|
302 |
+
def validation_step(self, batch, batch_idx):
|
303 |
+
|
304 |
+
total_loss = 0
|
305 |
+
losses = {}
|
306 |
+
datasets = ["jamendo", "deam", "emomusic", "pmemo"]
|
307 |
+
|
308 |
+
for dataset in datasets:
|
309 |
+
if dataset in batch and batch[dataset] is not None:
|
310 |
+
total_loss = self.handle_dataset(dataset, batch, losses, total_loss, "val")
|
311 |
+
|
312 |
+
batch_size = batch[next(iter(batch))]["x_mert"].size(0)
|
313 |
+
|
314 |
+
self.log('val_loss_mood', losses.get('loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
315 |
+
self.log('val_loss_va', losses.get('loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
316 |
+
self.log('val_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
317 |
+
return total_loss
|
318 |
+
|
319 |
+
def test_step(self, batch, batch_idx):
|
320 |
+
total_loss = 0
|
321 |
+
losses = {}
|
322 |
+
datasets = ["jamendo", "deam", "emomusic", "pmemo"]
|
323 |
+
|
324 |
+
for dataset in datasets:
|
325 |
+
if dataset in batch and batch[dataset] is not None:
|
326 |
+
dataset_batch = batch[dataset]
|
327 |
+
|
328 |
+
model_input_dic = {}
|
329 |
+
model_input_dic["x_mert"] = dataset_batch["x_mert"]
|
330 |
+
model_input_dic["x_chord"] = dataset_batch["x_chord"]
|
331 |
+
model_input_dic["x_chord_root"] = dataset_batch["x_chord_root"]
|
332 |
+
model_input_dic["x_chord_attr"] = dataset_batch["x_chord_attr"]
|
333 |
+
|
334 |
+
model_input_dic["x_key"] = dataset_batch["x_key"]
|
335 |
+
|
336 |
+
if dataset == "jamendo":
|
337 |
+
y_mood = dataset_batch["y_mood"]
|
338 |
+
classification_logits = self(model_input_dic, 0)
|
339 |
+
|
340 |
+
loss_classification = self.loss_fn_classification(classification_logits, y_mood)
|
341 |
+
total_loss += loss_classification
|
342 |
+
|
343 |
+
probs = torch.sigmoid(classification_logits)
|
344 |
+
if not hasattr(self, 'jamendo_results'):
|
345 |
+
self.jamendo_results = {'preds': [], 'gt': [], 'paths': []}
|
346 |
+
|
347 |
+
self.jamendo_results['preds'].extend(probs.detach().cpu().numpy())
|
348 |
+
self.jamendo_results['gt'].extend(y_mood.detach().cpu().numpy())
|
349 |
+
self.jamendo_results['paths'].extend(dataset_batch["path"])
|
350 |
+
|
351 |
+
losses['test_loss_mood'] = loss_classification
|
352 |
+
|
353 |
+
else: # Handle regression for all other datasets
|
354 |
+
if batch[dataset] is not None:
|
355 |
+
y_va = dataset_batch["y_va"]
|
356 |
+
regression_output = self(model_input_dic, 1)
|
357 |
+
|
358 |
+
loss_regression = self.loss_fn_regression(regression_output, y_va)
|
359 |
+
total_loss += loss_regression
|
360 |
+
|
361 |
+
# Track results separately for each dataset
|
362 |
+
if not hasattr(self, f'{dataset}_results'):
|
363 |
+
setattr(self, f'{dataset}_results', {'preds': [], 'gt': [], 'paths': []})
|
364 |
+
|
365 |
+
dataset_results = getattr(self, f'{dataset}_results')
|
366 |
+
dataset_results['preds'].extend(regression_output.detach().cpu().numpy())
|
367 |
+
dataset_results['gt'].extend(y_va.detach().cpu().numpy())
|
368 |
+
dataset_results['paths'].extend(dataset_batch["path"])
|
369 |
+
|
370 |
+
losses['test_loss_va'] = loss_regression
|
371 |
+
|
372 |
+
batch_size = batch[next(iter(batch))]["x_mert"].size(0)
|
373 |
+
|
374 |
+
# Log the classification and regression losses
|
375 |
+
self.log('test_loss_mood', losses.get('test_loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
376 |
+
self.log('test_loss_va', losses.get('test_loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
377 |
+
self.log('test_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
|
378 |
+
|
379 |
+
return total_loss
|
380 |
+
|
381 |
+
def on_test_end(self):
|
382 |
+
output_dic = {}
|
383 |
+
|
384 |
+
# Jamendo classification metrics (AUC and PR AUC)
|
385 |
+
if hasattr(self, 'jamendo_results') and self.jamendo_results['preds']:
|
386 |
+
roc_auc, pr_auc = self.get_auc(self.jamendo_results['preds'], self.jamendo_results['gt'])
|
387 |
+
|
388 |
+
roc_auc = roc_auc.item()
|
389 |
+
pr_auc = pr_auc.item()
|
390 |
+
|
391 |
+
log.info('*** Display ROC_AUC_MACRO scores (Jamendo) ***')
|
392 |
+
log.info(f"ROC_AUC_MACRO: {round(roc_auc, 4)}")
|
393 |
+
log.info(f"PR_AUC_MACRO: {round(pr_auc, 4)}")
|
394 |
+
|
395 |
+
if self.output_file is not None:
|
396 |
+
with open(self.output_file, 'a') as f:
|
397 |
+
f.write(f"ROC_AUC_MACRO (Jamendo): {round(roc_auc, 4)}\n")
|
398 |
+
f.write(f"PR_AUC_MACRO (Jamendo): {round(pr_auc, 4)}\n")
|
399 |
+
|
400 |
+
output_dic["test_roc_auc_jamendo"] = round(roc_auc, 4)
|
401 |
+
output_dic["test_pr_auc_jamendo"] = round(pr_auc, 4)
|
402 |
+
|
403 |
+
# Metrics for each regression dataset (DMDD, DEAM, EmoMusic, PMEmo)
|
404 |
+
for dataset in ["deam", "emomusic", "pmemo"]:
|
405 |
+
dataset_results = getattr(self, f'{dataset}_results', None)
|
406 |
+
|
407 |
+
if dataset_results and dataset_results['preds']:
|
408 |
+
preds = torch.tensor(np.array(dataset_results['preds']))
|
409 |
+
gts = torch.tensor(np.array(dataset_results['gt']))
|
410 |
+
|
411 |
+
# Assuming valence is the first column and arousal is the second
|
412 |
+
preds_valence = preds[:, 0]
|
413 |
+
preds_arousal = preds[:, 1]
|
414 |
+
gts_valence = gts[:, 0]
|
415 |
+
gts_arousal = gts[:, 1]
|
416 |
+
|
417 |
+
rmse = torch.sqrt(tmf.mean_squared_error(preds, gts))
|
418 |
+
r2 = tmf.r2_score(preds, gts)
|
419 |
+
|
420 |
+
# Calculate metrics for valence
|
421 |
+
rmse_valence = torch.sqrt(tmf.mean_squared_error(preds_valence, gts_valence))
|
422 |
+
r2_valence = tmf.r2_score(preds_valence, gts_valence)
|
423 |
+
|
424 |
+
# Calculate metrics for arousal
|
425 |
+
rmse_arousal = torch.sqrt(tmf.mean_squared_error(preds_arousal, gts_arousal))
|
426 |
+
r2_arousal = tmf.r2_score(preds_arousal, gts_arousal)
|
427 |
+
|
428 |
+
log.info(f'*** Display RMSE and R² scores ({dataset.upper()}) ***')
|
429 |
+
log.info(f"RMSE: {round(rmse.item(), 4)}")
|
430 |
+
log.info(f"R²: {round(r2.item(), 4)}")
|
431 |
+
log.info(f"Valence - RMSE: {round(rmse_valence.item(), 4)}, R²: {round(r2_valence.item(), 4)}")
|
432 |
+
log.info(f"Arousal - RMSE: {round(rmse_arousal.item(), 4)}, R²: {round(r2_arousal.item(), 4)}")
|
433 |
+
|
434 |
+
if self.output_file is not None:
|
435 |
+
with open(self.output_file, 'a') as f:
|
436 |
+
f.write(f"RMSE ({dataset.upper()}): {round(rmse.item(), 4)}\n")
|
437 |
+
f.write(f"R² ({dataset.upper()}): {round(r2.item(), 4)}\n")
|
438 |
+
f.write(f"Valence - RMSE ({dataset.upper()}): {round(rmse_valence.item(), 4)}\n")
|
439 |
+
f.write(f"Valence - R² ({dataset.upper()}): {round(r2_valence.item(), 4)}\n")
|
440 |
+
f.write(f"Arousal - RMSE ({dataset.upper()}): {round(rmse_arousal.item(), 4)}\n")
|
441 |
+
f.write(f"Arousal - R² ({dataset.upper()}): {round(r2_arousal.item(), 4)}\n")
|
442 |
+
|
443 |
+
output_dic[f"test_rmse_{dataset}"] = round(rmse.item(), 4)
|
444 |
+
output_dic[f"test_r2_{dataset}"] = round(r2.item(), 4)
|
445 |
+
output_dic[f"test_rmse_valence_{dataset}"] = round(rmse_valence.item(), 4)
|
446 |
+
output_dic[f"test_r2_valence_{dataset}"] = round(r2_valence.item(), 4)
|
447 |
+
output_dic[f"test_rmse_arousal_{dataset}"] = round(rmse_arousal.item(), 4)
|
448 |
+
output_dic[f"test_r2_arousal_{dataset}"] = round(r2_arousal.item(), 4)
|
449 |
+
|
450 |
+
# Clear results for each dataset
|
451 |
+
for dataset in ["jamendo", "deam", "emomusic", "pmemo"]:
|
452 |
+
if hasattr(self, f'{dataset}_results'):
|
453 |
+
getattr(self, f'{dataset}_results')['preds'].clear()
|
454 |
+
getattr(self, f'{dataset}_results')['gt'].clear()
|
455 |
+
getattr(self, f'{dataset}_results')['paths'].clear()
|
456 |
+
|
457 |
+
return output_dic
|
458 |
+
|
459 |
+
def configure_optimizers(self):
|
460 |
+
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
461 |
+
|
462 |
+
def get_auc(self, prd_array, gt_array):
|
463 |
+
prd_array = np.array(prd_array)
|
464 |
+
gt_array = np.array(gt_array)
|
465 |
+
|
466 |
+
prd_tensor = torch.tensor(prd_array)
|
467 |
+
gt_tensor = torch.tensor(gt_array)
|
468 |
+
|
469 |
+
try:
|
470 |
+
roc_auc = tmf.auroc(prd_tensor, gt_tensor.int(), task='multilabel', num_labels = 56 , average='macro', num_classes=gt_tensor.size(1))
|
471 |
+
pr_auc = tmf.average_precision(prd_tensor, gt_tensor.int(), task='multilabel', num_labels = 56, average='macro', num_classes=gt_tensor.size(1))
|
472 |
+
except ValueError as e:
|
473 |
+
print(f"Error computing metrics: {e}")
|
474 |
+
roc_auc = None
|
475 |
+
pr_auc = None
|
476 |
+
return roc_auc, pr_auc
|
477 |
+
|
478 |
+
|