kjysmu committed · Commit 2dfd92b · verified · 1 Parent(s): 1adb3ce

Upload 5 files
Files changed (6)
  1. .gitattributes +1 -0
  2. data_loader.py +78 -0
  3. framework.png +3 -0
  4. test.py +98 -0
  5. train.py +107 -0
  6. trainer.py +478 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  inference/input/test.mp3 filter=lfs diff=lfs merge=lfs -text
  m2e.png filter=lfs diff=lfs merge=lfs -text
+ framework.png filter=lfs diff=lfs merge=lfs -text
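
(For reference, this is the line that `git lfs track "framework.png"` appends: once a pattern carries the filter/diff/merge attributes, Git stores the file as an LFS pointer rather than a raw blob, matching the Git LFS details shown for framework.png below.)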
data_loader.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import numpy as np
+ import pickle
+ from torch.utils import data
+ import torchaudio.transforms as T
+ import torchaudio
+ import torch
+ import csv
+ import pytorch_lightning as pl
+ from music2latent import EncoderDecoder
+ import json
+ import math
+ from sklearn.preprocessing import StandardScaler
+
+ from dataset_loaders.jamendo import JamendoDataset
+ from dataset_loaders.pmemo import PMEmoDataset
+ from dataset_loaders.deam import DEAMDataset
+ from dataset_loaders.emomusic import EmoMusicDataset
+
+ from omegaconf import DictConfig
+
+ DATASET_REGISTRY = {
+     "jamendo": JamendoDataset,
+     "pmemo": PMEmoDataset,
+     "deam": DEAMDataset,
+     "emomusic": EmoMusicDataset
+ }
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, cfg: DictConfig):
+         super().__init__()
+         self.cfg = cfg
+
+         self.train_datasets = []
+         self.val_datasets = []
+         self.test_datasets = []
+
+     def setup(self, stage=None):
+         # Clear previous dataset lists
+         self.train_datasets = []
+         self.val_datasets = []
+         self.test_datasets = []
+
+         # Look up each configured dataset in the registry and build its splits
+         for dataset_name in self.cfg.datasets:
+             dataset_cfg = self.cfg.dataset[dataset_name]
+
+             if dataset_name in DATASET_REGISTRY:
+                 train_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='train')
+                 val_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='validation')
+                 test_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='test')
+
+                 self.train_datasets.append(train_dataset)
+                 self.val_datasets.append(val_dataset)
+                 self.test_datasets.append(test_dataset)
+             else:
+                 raise ValueError(f"Dataset {dataset_name} not found in registry")
+
+     def train_dataloader(self):
+         return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
+                                 shuffle=True, num_workers=self.cfg.dataset[ds_name].num_workers,
+                                 persistent_workers=True)
+                 for ds, ds_name in zip(self.train_datasets, self.cfg.datasets)]
+
+     def val_dataloader(self):
+         return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
+                                 shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
+                                 persistent_workers=True)
+                 for ds, ds_name in zip(self.val_datasets, self.cfg.datasets)]
+
+     def test_dataloader(self):
+         return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
+                                 shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
+                                 persistent_workers=True)
+                 for ds, ds_name in zip(self.test_datasets, self.cfg.datasets)]
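
As a usage sketch, the module can be driven with a hand-built config mirroring the fields it reads (cfg.datasets, plus cfg.dataset[name].batch_size and num_workers; any remaining keys are forwarded to the dataset constructor). The keys and values below are illustrative assumptions, not the shipped Hydra config, and whether DEAMDataset accepts these exact kwargs depends on its constructor, which is not part of this commit:

    from omegaconf import OmegaConf
    from data_loader import DataModule

    cfg = OmegaConf.create({
        "datasets": ["deam"],
        "dataset": {
            "deam": {"batch_size": 32, "num_workers": 4},  # illustrative values
        },
    })

    dm = DataModule(cfg)
    dm.setup()                             # builds train/validation/test splits per dataset
    train_loaders = dm.train_dataloader()  # one DataLoader per entry in cfg.datasets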
framework.png ADDED

Git LFS Details

  • SHA256: 836c90e0d9954d853c8f9fcc45e76ae4b37b7c9a71e964a5ec904a5a6f5da39c
  • Pointer size: 131 Bytes
  • Size of remote file: 291 kB
test.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import logging
+
+ import torch
+ torch.set_float32_matmul_precision("medium")
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.loggers import TensorBoardLogger
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from data_loader import DataModule
+ from trainer import MusicClassifier
+ import yaml
+ from omegaconf import DictConfig
+ import hydra
+ from hydra.utils import to_absolute_path
+ from hydra.core.hydra_config import HydraConfig
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
+
+
+ log = logging.getLogger(__name__)
+
+ def get_latest_version(log_dir):
+     version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
+     version_dirs.sort(key=lambda x: int(x.split('_')[-1]))  # Sort by version number
+     return version_dirs[-1] if version_dirs else None
+
+ def save_metrics_and_checkpoint(metrics, checkpoint, output_file):
+     data = {
+         'checkpoint': checkpoint,
+         'metrics': metrics
+     }
+     with open(output_file, 'w') as f:
+         yaml.dump(data, f)
+
+ def read_best_checkpoint_info(file_path, dataset_type=None):
+     """Read the best checkpoint file."""
+     if not os.path.exists(file_path):
+         raise FileNotFoundError(f"Checkpoint info file not found: {file_path}")
+
+     with open(file_path, 'r') as f:
+         lines = f.readlines()
+
+     if dataset_type == "mood":
+         checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (mood):")), None)
+     elif dataset_type == "va":
+         checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (va):")), None)
+     else:
+         checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint:")), None)
+
+     if not checkpoint_line:
+         raise ValueError(f"No checkpoint found for dataset type '{dataset_type}' in the file.")
+
+     return checkpoint_line.split(": ")[-1].strip()
+
+ @hydra.main(version_base=None, config_path="config", config_name="test_config")
+ def main(config: DictConfig):
+     log.info("Testing starts")
+     log_base_dir = 'tb_logs/train_audio_classification'
+     # log_base_dir = to_absolute_path('tb_logs/train_audio_classification')
+
+     latest_version = get_latest_version(log_base_dir)
+     if not latest_version:
+         raise FileNotFoundError("No version directories found in log base directory.")
+     version_log_dir = os.path.join(log_base_dir, latest_version)
+     output_file = os.path.join(version_log_dir, 'test_metrics.txt')
+
+     if config.checkpoint_latest:
+         if config.multitask:
+             dataset_type = config.dataset_type  # Expecting 'mood' or 'va'
+             best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
+             ckpt = read_best_checkpoint_info(best_checkpoint_file, dataset_type)
+         else:
+             best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
+             ckpt = read_best_checkpoint_info(best_checkpoint_file)
+     else:
+         ckpt = config.checkpoint
+         if not os.path.exists(ckpt):
+             raise FileNotFoundError(f"Checkpoint file not found: {ckpt}")
+
+     log.info(f"Using checkpoint: {ckpt}")
+     data_module = DataModule(config)
+     data_module.setup()
+
+     testloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.test_dataloader())}
+     combined_test_loader = CombinedLoader(testloaders, mode="max_size")
+
+     model = MusicClassifier.load_from_checkpoint(ckpt, cfg=config, output_file=output_file)
+     logger = TensorBoardLogger(save_dir=log_base_dir,
+                                name="",
+                                version=latest_version)
+     trainer = pl.Trainer(**config.trainer,
+                          logger=logger)
+
+     trainer.test(model, combined_test_loader)
+
+ if __name__ == '__main__':
+     main()
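
For reference, read_best_checkpoint_info parses the best_checkpoint.txt that train.py writes after fitting; the text after the final ': ' on the matching line is taken as the checkpoint path. With multitask checkpoints the file looks like this (paths illustrative):

    Best checkpoint (mood): tb_logs/train_audio_classification/version_3/checkpoints/epoch=12.ckpt
    Best checkpoint (va): tb_logs/train_audio_classification/version_3/checkpoints/epoch=17.ckpt
    Version: 3

For single-task runs, a single "Best checkpoint: <path>" line is matched instead.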
train.py ADDED
@@ -0,0 +1,107 @@
+ import os
+ import logging
+
+ import torch
+ torch.set_float32_matmul_precision("medium")
+
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.loggers import TensorBoardLogger
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+ from data_loader import DataModule
+ from trainer import MusicClassifier
+ from omegaconf import DictConfig
+ import hydra
+ from hydra.utils import to_absolute_path
+ from hydra.core.hydra_config import HydraConfig
+ from pytorch_lightning.callbacks import EarlyStopping
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
+ from pytorch_lightning.strategies import DDPStrategy
+
+ # from utilities.custom_early_stopping import MultiMetricEarlyStopping
+ def get_latest_version(log_dir):
+     version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
+     version_dirs.sort(key=lambda x: int(x.split('_')[-1]))  # Sort by version number
+     return version_dirs[-1] if version_dirs else None
+
+ log = logging.getLogger(__name__)
+ @hydra.main(version_base=None, config_path="config", config_name="train_config")
+ def main(config: DictConfig):
+
+     log_base_dir = 'tb_logs/train_audio_classification'
+     # log_base_dir = to_absolute_path('tb_logs/train_audio_classification')
+     is_mt = "mt" in config.model.classifier
+
+     logger = TensorBoardLogger("tb_logs", name="train_audio_classification")
+     logger.log_hyperparams(config)
+     train_log_dir = logger.log_dir
+     print(f"Logging to {train_log_dir}")
+     log.info("Training starts")
+
+     data_module = DataModule(config)
+     data_module.setup()
+
+     # Get the list of dataloaders for both train and validation, keyed by dataset name
+     trainloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.train_dataloader())}
+     valloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.val_dataloader())}
+
+     # Combine multiple loaders using CombinedLoader, keyed by dataset name
+     combined_train_loader = CombinedLoader(trainloaders, mode="max_size")
+     combined_val_loader = CombinedLoader(valloaders, mode="max_size")
+
+     latest_version = get_latest_version(log_base_dir)
+     next_version = int(latest_version.split('_')[-1]) + 1 if latest_version else 0
+     next_version = f"version_{next_version}"
+
+     # Per-epoch validation output goes into this run's version directory
+     val_epoch_file = os.path.join(log_base_dir, next_version, 'val_epoch.txt')
+
+     model = MusicClassifier(config, output_file=val_epoch_file)
+
+     if is_mt:
+         checkpoint_callback_mood = ModelCheckpoint(**config.checkpoint_mood)
+         checkpoint_callback_va = ModelCheckpoint(**config.checkpoint_va)
+         early_stop_callback = EarlyStopping(**config.earlystopping)
+
+         if config.model.kd:
+             trainer = pl.Trainer(
+                 **config.trainer,
+                 strategy=DDPStrategy(find_unused_parameters=True),
+                 callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
+                 logger=logger,
+                 num_sanity_val_steps=0
+             )
+         else:
+             trainer = pl.Trainer(
+                 **config.trainer,
+                 strategy=DDPStrategy(find_unused_parameters=False),
+                 callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
+                 logger=logger,
+                 num_sanity_val_steps=0
+             )
+
+     else:
+         checkpoint_callback = ModelCheckpoint(**config.checkpoint)
+         early_stop_callback = EarlyStopping(**config.earlystopping)  # required by the callbacks list below
+         trainer = pl.Trainer(
+             **config.trainer,
+             callbacks=[checkpoint_callback, early_stop_callback],
+             logger=logger,
+             num_sanity_val_steps=0
+         )
+
+     trainer.fit(model, combined_train_loader, combined_val_loader)
+
+     if trainer.global_rank == 0:
+         best_checkpoint_file = os.path.join(train_log_dir, 'best_checkpoint.txt')
+         with open(best_checkpoint_file, 'w') as f:
+             if is_mt:
+                 f.write(f"Best checkpoint (mood): {checkpoint_callback_mood.best_model_path}\n")
+                 f.write(f"Best checkpoint (va): {checkpoint_callback_va.best_model_path}\n")
+             else:
+                 f.write(f"Best checkpoint: {checkpoint_callback.best_model_path}\n")
+             f.write(f"Version: {logger.version}\n")
+
+ if __name__ == '__main__':
+     main()
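
train.py assumes a Hydra train_config exposing the fields referenced above. A hypothetical minimal shape, written with OmegaConf for illustration (every key mirrors an access in train.py or trainer.py, but all values and the exact layout of the shipped config are assumptions):

    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "datasets": ["jamendo", "deam"],
        "dataset": {
            "jamendo": {"batch_size": 32, "num_workers": 4, "output_size": 56},
            "deam":    {"batch_size": 32, "num_workers": 4, "output_size": 2},
        },
        "model": {
            "classifier": "linear-mt-attn-ck",  # "mt" in the name triggers the multitask path
            "encoder": "MERT", "layers": [5, 6], "lr": 1e-4,
            "kd": False, "kd_weight": 0.5, "kd_temperature": 2.0,
        },
        "trainer": {"max_epochs": 100},          # unpacked into pl.Trainer(**config.trainer)
        "checkpoint_mood": {"monitor": "val_loss_mood", "mode": "min"},
        "checkpoint_va":   {"monitor": "val_loss_va", "mode": "min"},
        "earlystopping":   {"monitor": "val_loss", "mode": "min", "patience": 10},
    })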
trainer.py ADDED
@@ -0,0 +1,478 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import pytorch_lightning as pl
+ from sklearn import metrics
+ from transformers import AutoModelForAudioClassification
+ import numpy as np
+ from collections import OrderedDict
+ from torchmetrics import MeanMetric, MaxMetric, Accuracy
+ import torchmetrics.functional as tmf
+
+ from model.linear import FeedforwardModel
+ from model.linear_small import FeedforwardModelSmall
+ from model.linear_attn_ck import FeedforwardModelAttnCK
+ from model.linear_mt import FeedforwardModelMT
+ from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
+
+ import logging
+ import yaml
+ from omegaconf import DictConfig
+
+ from torch.distributed import all_gather, get_world_size
+ # from lion_pytorch import Lion
+ from torch_optimizer import RAdam
+
+ def gather_all_results(tensor):
+     """
+     Gather tensors from all GPUs in distributed training.
+     """
+     gathered_tensors = [torch.zeros_like(tensor) for _ in range(get_world_size())]
+     all_gather(gathered_tensors, tensor)
+     return torch.cat(gathered_tensors, dim=0)
+
+ # torch.set_float32_matmul_precision('medium')
+
+ log = logging.getLogger(__name__)
+ class MusicClassifier(pl.LightningModule):
+     def __init__(self, cfg: DictConfig, output_file=None):
+         super().__init__()
+         self.cfg = cfg
+         self.encoder = cfg.model.encoder
+         self.classifier = cfg.model.classifier
+         self.lr = cfg.model.lr
+         self.output_file = output_file
+         self.kd = cfg.model.kd
+         self.kd_weight = cfg.model.kd_weight
+         self.kd_temperature = self.cfg.model.kd_temperature
+
+         layer_size = len(self.cfg.model.layers)
+         mert_dim = 768 * layer_size
+
+         self.feature_dim_dict = {
+             "MERT": mert_dim
+         }
+
+         encoders = self.encoder.split("-")
+         self.input_size = sum(self.feature_dim_dict[encoder] for encoder in encoders)
+         self.num_datasets = len(self.cfg.datasets)
+
+         if "mt" in self.classifier:
+             if self.num_datasets < 2:
+                 raise Exception("Error: Dataset size >= 2 needed for MT classifier")
+             classifiers = {
+                 "linear-mt-attn-ck": FeedforwardModelMTAttnCK,
+             }
+             if self.classifier in classifiers:
+                 self.model = classifiers[self.classifier](
+                     input_size=self.input_size,
+                     output_size_classification=56,
+                     output_size_regression=2
+                 )
+             else:
+                 raise Exception(f"Unknown classifier: {self.classifier}")
+         else:
+             if self.num_datasets >= 2:
+                 raise Exception("Error: Dataset size == 1 needed for classifier")
+             dataset_name = self.cfg.datasets[0]
+             self.output_size = self.cfg.dataset[dataset_name].output_size
+             classifiers = {
+                 "linear": FeedforwardModel,
+                 "linear-attn-ck": FeedforwardModelAttnCK
+             }
+
+             if self.classifier in classifiers:
+                 self.model = classifiers[self.classifier](input_size=self.input_size, output_size=self.output_size)
+             else:
+                 raise Exception(f"Unknown classifier: {self.classifier}")
+
+         if self.kd:
+             # Plain dict, so Lightning does not register or move these models;
+             # they are moved to the right device on use in compute_kd_loss.
+             self.teacher_models = {}
+
+             for dataset in self.cfg.datasets:
+                 self.output_size = self.cfg.dataset[dataset].output_size
+                 teacher_model_path = getattr(self.cfg, f"checkpoint_{dataset}", None)
+
+                 if teacher_model_path:
+                     # Create a new teacher model instance
+                     teacher_model = FeedforwardModelAttnCK(
+                         input_size=self.input_size,
+                         output_size=self.output_size,
+                     )
+
+                     # Load the checkpoint
+                     checkpoint = torch.load(teacher_model_path, map_location=self.device, weights_only=False)
+                     state_dict = checkpoint["state_dict"]
+
+                     # Adjust the keys in the state_dict
+                     state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
+
+                     # Filter state_dict to match model's keys
+                     model_keys = set(teacher_model.state_dict().keys())
+                     filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
+
+                     # Load the filtered state_dict and set the model to evaluation mode
+                     teacher_model.load_state_dict(filtered_state_dict)
+                     teacher_model.to(self.device)
+                     teacher_model.eval()
+
+                     # Store the teacher model in the dictionary with the dataset name as the key
+                     self.teacher_models[dataset] = teacher_model
+
+         # Per-tag class balancing for the Jamendo mood tags: rare tags get a larger
+         # pos_weight, and weight rescales each tag's overall contribution to the loss.
+         probas = torch.from_numpy(np.load("dataset/jamendo/meta/probas_train.npy"))
+         pos_weight = torch.tensor(1.) / probas
+         weight = torch.tensor(2.) / (torch.tensor(1.) + pos_weight)
+
+         self.loss_fn_classification = nn.BCEWithLogitsLoss(
+             pos_weight=pos_weight, reduction="mean", weight=weight
+         )
+         self.loss_fn_classification_eval = nn.BCEWithLogitsLoss(
+             pos_weight=pos_weight, reduction="none", weight=weight
+         )
+
+         self.loss_fn_regression = nn.MSELoss()
+
+         self.loss_kd = nn.KLDivLoss(reduction="batchmean")
+
+         self.prd_array = []
+         self.gt_array = []
+         self.song_array = []
+
+         self.prd_array_va = []
+         self.gt_array_va = []
+         self.song_array_va = []
+
+         self.validation_predictions = []
+         self.validation_targets = []
+         self.validation_results = {'preds': [], 'gt': []}
+
+         self.trn_loss = MeanMetric()
+         self.val_loss = MeanMetric()
+
+     def forward(self, model_input_dic, output_idx=0):
+         if "mt" in self.classifier:
+             classification_output, regression_output = self.model(model_input_dic)
+             if output_idx == 0:
+                 return classification_output
+             elif output_idx == 1:
+                 return regression_output
+             elif output_idx == 2:
+                 return classification_output, regression_output
+         else:
+             output = self.model(model_input_dic)
+             return output
+
+     def compute_classification_loss(self, model_input_dic, y_mood):
+         classification_logits = self(model_input_dic, 0)
+         loss = self.loss_fn_classification(classification_logits, y_mood)
+         return loss
+
+     def compute_regression_loss(self, model_input_dic, y_va):
+         regression_output = self(model_input_dic, 1)
+         loss = self.loss_fn_regression(regression_output, y_va)
+         return loss
+
+     def compute_mt_loss(self, model_input_dic, y_mood, y_va):
+         classification_logits, regression_output = self(model_input_dic, 2)
+         loss_classification = self.loss_fn_classification(classification_logits, y_mood)
+         loss_regression = self.loss_fn_regression(regression_output, y_va)
+         return loss_classification, loss_regression
+
+     def compute_kd_loss(self, model_input_dic, y_mood, y_va, dataset_name):
+         """
+         Compute knowledge distillation loss for a given dataset.
+         """
+         # Forward pass through student model
+         s_logits_mood, s_logits_va = self(model_input_dic, 2)
+
+         # Compute student losses
+         s_loss_mood = self.loss_fn_classification(s_logits_mood, y_mood)
+         s_loss_va = self.loss_fn_regression(s_logits_va, y_va)
+
+         # Get the corresponding teacher model for the dataset,
+         # checking it exists before touching it
+         teacher_model = self.teacher_models.get(dataset_name)
+         if teacher_model is None:
+             raise ValueError(f"No teacher model found for dataset: {dataset_name}")
+         teacher_model.to(self.device)
+
+         with torch.no_grad():
+             # Forward pass through teacher model
+             t_logits = teacher_model(model_input_dic)
+
+         # Compute the knowledge distillation loss (outside no_grad so the
+         # student's log-probabilities keep their gradients)
+         t_probs = torch.softmax(t_logits / self.kd_temperature, dim=-1)
+         if dataset_name == "jamendo":
+             s_probs_mood = torch.log_softmax(s_logits_mood / self.kd_temperature, dim=-1)
+             kd_loss = self.loss_kd(s_probs_mood, t_probs)
+         else:
+             s_probs_va = torch.log_softmax(s_logits_va / self.kd_temperature, dim=-1)
+             kd_loss = self.loss_kd(s_probs_va, t_probs)
+
+         return kd_loss, s_loss_mood, s_loss_va
+
+     def handle_dataset(self, dataset_name, batch, losses, total_loss, stage):
+         dataset_batch = batch[dataset_name]
+
+         model_input_dic = {}
+         model_input_dic["x_mert"] = dataset_batch["x_mert"]
+         model_input_dic["x_chord"] = dataset_batch["x_chord"]
+         model_input_dic["x_chord_root"] = dataset_batch["x_chord_root"]
+         model_input_dic["x_chord_attr"] = dataset_batch["x_chord_attr"]
+         model_input_dic["x_key"] = dataset_batch["x_key"]
+
+         if "mt" in self.classifier:
+             y_mood = dataset_batch["y_mood"]
+             y_va = dataset_batch["y_va"]
+             if dataset_name == "jamendo":
+                 if self.kd:
+                     kd_loss, s_loss_mood, s_loss_va = self.compute_kd_loss(model_input_dic, y_mood, y_va, dataset_name)
+                     losses['loss_mood'] = s_loss_mood
+                     # The KD term only contributes during training
+                     if stage == "train":
+                         total_loss += self.kd_weight * kd_loss + (1 - self.kd_weight) * s_loss_mood
+                     else:
+                         total_loss += s_loss_mood
+                 else:
+                     s_loss_mood, s_loss_va = self.compute_mt_loss(model_input_dic, y_mood, y_va)
+                     losses['loss_mood'] = s_loss_mood
+                     total_loss += s_loss_mood
+             else:
+                 if self.kd:
+                     kd_loss, s_loss_mood, s_loss_va = self.compute_kd_loss(model_input_dic, y_mood, y_va, dataset_name)
+                     losses['loss_va'] = s_loss_va
+                     if stage == "train":
+                         total_loss += self.kd_weight * kd_loss + (1 - self.kd_weight) * s_loss_va
+                     else:
+                         total_loss += s_loss_va
+                 else:
+                     s_loss_mood, s_loss_va = self.compute_mt_loss(model_input_dic, y_mood, y_va)
+                     losses['loss_va'] = s_loss_va
+                     total_loss += s_loss_va
+         else:
+             if dataset_name == "jamendo":
+                 y_mood = dataset_batch["y_mood"]
+                 loss_classification = self.compute_classification_loss(model_input_dic, y_mood)
+                 losses['loss_mood'] = loss_classification
+                 total_loss += loss_classification
+             else:
+                 y_va = dataset_batch["y_va"]
+                 loss_regression = self.compute_regression_loss(model_input_dic, y_va)
+                 losses['loss_va'] = loss_regression
+                 total_loss += loss_regression
+
+         return total_loss
+
+     def training_step(self, batch, batch_idx):
+         total_loss = 0
+         losses = {}
+         datasets = ["jamendo", "deam", "emomusic", "pmemo"]
+
+         for dataset in datasets:
+             if dataset in batch and batch[dataset] is not None:
+                 total_loss = self.handle_dataset(dataset, batch, losses, total_loss, "train")
+
+         batch_size = batch[next(iter(batch))]["x_mert"].size(0)
+
+         self.log('train_loss_mood', losses.get('loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('train_loss_va', losses.get('loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('train_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+
+         return total_loss
+
+     def validation_step(self, batch, batch_idx):
+         total_loss = 0
+         losses = {}
+         datasets = ["jamendo", "deam", "emomusic", "pmemo"]
+
+         for dataset in datasets:
+             if dataset in batch and batch[dataset] is not None:
+                 total_loss = self.handle_dataset(dataset, batch, losses, total_loss, "val")
+
+         batch_size = batch[next(iter(batch))]["x_mert"].size(0)
+
+         self.log('val_loss_mood', losses.get('loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('val_loss_va', losses.get('loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('val_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         return total_loss
+
+     def test_step(self, batch, batch_idx):
+         total_loss = 0
+         losses = {}
+         datasets = ["jamendo", "deam", "emomusic", "pmemo"]
+
+         for dataset in datasets:
+             if dataset in batch and batch[dataset] is not None:
+                 dataset_batch = batch[dataset]
+
+                 model_input_dic = {}
+                 model_input_dic["x_mert"] = dataset_batch["x_mert"]
+                 model_input_dic["x_chord"] = dataset_batch["x_chord"]
+                 model_input_dic["x_chord_root"] = dataset_batch["x_chord_root"]
+                 model_input_dic["x_chord_attr"] = dataset_batch["x_chord_attr"]
+                 model_input_dic["x_key"] = dataset_batch["x_key"]
+
+                 if dataset == "jamendo":
+                     y_mood = dataset_batch["y_mood"]
+                     classification_logits = self(model_input_dic, 0)
+
+                     loss_classification = self.loss_fn_classification(classification_logits, y_mood)
+                     total_loss += loss_classification
+
+                     probs = torch.sigmoid(classification_logits)
+                     if not hasattr(self, 'jamendo_results'):
+                         self.jamendo_results = {'preds': [], 'gt': [], 'paths': []}
+
+                     self.jamendo_results['preds'].extend(probs.detach().cpu().numpy())
+                     self.jamendo_results['gt'].extend(y_mood.detach().cpu().numpy())
+                     self.jamendo_results['paths'].extend(dataset_batch["path"])
+
+                     losses['test_loss_mood'] = loss_classification
+
+                 else:  # Handle regression for all other datasets
+                     y_va = dataset_batch["y_va"]
+                     regression_output = self(model_input_dic, 1)
+
+                     loss_regression = self.loss_fn_regression(regression_output, y_va)
+                     total_loss += loss_regression
+
+                     # Track results separately for each dataset
+                     if not hasattr(self, f'{dataset}_results'):
+                         setattr(self, f'{dataset}_results', {'preds': [], 'gt': [], 'paths': []})
+
+                     dataset_results = getattr(self, f'{dataset}_results')
+                     dataset_results['preds'].extend(regression_output.detach().cpu().numpy())
+                     dataset_results['gt'].extend(y_va.detach().cpu().numpy())
+                     dataset_results['paths'].extend(dataset_batch["path"])
+
+                     losses['test_loss_va'] = loss_regression
+
+         batch_size = batch[next(iter(batch))]["x_mert"].size(0)
+
+         # Log the classification and regression losses
+         self.log('test_loss_mood', losses.get('test_loss_mood', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('test_loss_va', losses.get('test_loss_va', 0), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+         self.log('test_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
+
+         return total_loss
+
+     def on_test_end(self):
+         output_dic = {}
+
+         # Jamendo classification metrics (ROC AUC and PR AUC)
+         if hasattr(self, 'jamendo_results') and self.jamendo_results['preds']:
+             roc_auc, pr_auc = self.get_auc(self.jamendo_results['preds'], self.jamendo_results['gt'])
+
+             if roc_auc is not None and pr_auc is not None:
+                 roc_auc = roc_auc.item()
+                 pr_auc = pr_auc.item()
+
+                 log.info('*** Display ROC_AUC_MACRO scores (Jamendo) ***')
+                 log.info(f"ROC_AUC_MACRO: {round(roc_auc, 4)}")
+                 log.info(f"PR_AUC_MACRO: {round(pr_auc, 4)}")
+
+                 if self.output_file is not None:
+                     with open(self.output_file, 'a') as f:
+                         f.write(f"ROC_AUC_MACRO (Jamendo): {round(roc_auc, 4)}\n")
+                         f.write(f"PR_AUC_MACRO (Jamendo): {round(pr_auc, 4)}\n")
+
+                 output_dic["test_roc_auc_jamendo"] = round(roc_auc, 4)
+                 output_dic["test_pr_auc_jamendo"] = round(pr_auc, 4)
+
+         # Metrics for each regression dataset (DEAM, EmoMusic, PMEmo)
+         for dataset in ["deam", "emomusic", "pmemo"]:
+             dataset_results = getattr(self, f'{dataset}_results', None)
+
+             if dataset_results and dataset_results['preds']:
+                 preds = torch.tensor(np.array(dataset_results['preds']))
+                 gts = torch.tensor(np.array(dataset_results['gt']))
+
+                 # Valence is the first column and arousal the second
+                 preds_valence = preds[:, 0]
+                 preds_arousal = preds[:, 1]
+                 gts_valence = gts[:, 0]
+                 gts_arousal = gts[:, 1]
+
+                 rmse = torch.sqrt(tmf.mean_squared_error(preds, gts))
+                 r2 = tmf.r2_score(preds, gts)
+
+                 # Calculate metrics for valence
+                 rmse_valence = torch.sqrt(tmf.mean_squared_error(preds_valence, gts_valence))
+                 r2_valence = tmf.r2_score(preds_valence, gts_valence)
+
+                 # Calculate metrics for arousal
+                 rmse_arousal = torch.sqrt(tmf.mean_squared_error(preds_arousal, gts_arousal))
+                 r2_arousal = tmf.r2_score(preds_arousal, gts_arousal)
+
+                 log.info(f'*** Display RMSE and R² scores ({dataset.upper()}) ***')
+                 log.info(f"RMSE: {round(rmse.item(), 4)}")
+                 log.info(f"R²: {round(r2.item(), 4)}")
+                 log.info(f"Valence - RMSE: {round(rmse_valence.item(), 4)}, R²: {round(r2_valence.item(), 4)}")
+                 log.info(f"Arousal - RMSE: {round(rmse_arousal.item(), 4)}, R²: {round(r2_arousal.item(), 4)}")
+
+                 if self.output_file is not None:
+                     with open(self.output_file, 'a') as f:
+                         f.write(f"RMSE ({dataset.upper()}): {round(rmse.item(), 4)}\n")
+                         f.write(f"R² ({dataset.upper()}): {round(r2.item(), 4)}\n")
+                         f.write(f"Valence - RMSE ({dataset.upper()}): {round(rmse_valence.item(), 4)}\n")
+                         f.write(f"Valence - R² ({dataset.upper()}): {round(r2_valence.item(), 4)}\n")
+                         f.write(f"Arousal - RMSE ({dataset.upper()}): {round(rmse_arousal.item(), 4)}\n")
+                         f.write(f"Arousal - R² ({dataset.upper()}): {round(r2_arousal.item(), 4)}\n")
+
+                 output_dic[f"test_rmse_{dataset}"] = round(rmse.item(), 4)
+                 output_dic[f"test_r2_{dataset}"] = round(r2.item(), 4)
+                 output_dic[f"test_rmse_valence_{dataset}"] = round(rmse_valence.item(), 4)
+                 output_dic[f"test_r2_valence_{dataset}"] = round(r2_valence.item(), 4)
+                 output_dic[f"test_rmse_arousal_{dataset}"] = round(rmse_arousal.item(), 4)
+                 output_dic[f"test_r2_arousal_{dataset}"] = round(r2_arousal.item(), 4)
+
+         # Clear results for each dataset
+         for dataset in ["jamendo", "deam", "emomusic", "pmemo"]:
+             if hasattr(self, f'{dataset}_results'):
+                 getattr(self, f'{dataset}_results')['preds'].clear()
+                 getattr(self, f'{dataset}_results')['gt'].clear()
+                 getattr(self, f'{dataset}_results')['paths'].clear()
+
+         return output_dic
+
+     def configure_optimizers(self):
+         return torch.optim.Adam(self.parameters(), lr=self.lr)
+
+     def get_auc(self, prd_array, gt_array):
+         prd_tensor = torch.tensor(np.array(prd_array))
+         gt_tensor = torch.tensor(np.array(gt_array))
+
+         try:
+             roc_auc = tmf.auroc(prd_tensor, gt_tensor.int(), task='multilabel', num_labels=56, average='macro')
+             pr_auc = tmf.average_precision(prd_tensor, gt_tensor.int(), task='multilabel', num_labels=56, average='macro')
+         except ValueError as e:
+             print(f"Error computing metrics: {e}")
+             roc_auc = None
+             pr_auc = None
+         return roc_auc, pr_auc
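
As a side note on the distillation term: nn.KLDivLoss(reduction="batchmean") expects its input to be log-probabilities and its target to be probabilities, which is why compute_kd_loss applies log_softmax to the student logits and softmax to the teacher logits, each divided by the shared temperature. A minimal standalone sketch with random stand-in tensors:

    import torch
    import torch.nn as nn

    T = 2.0                                   # kd_temperature
    kd = nn.KLDivLoss(reduction="batchmean")  # input: log-probs, target: probs

    s_logits = torch.randn(8, 56)             # student logits (e.g., 56 mood tags)
    t_logits = torch.randn(8, 56)             # teacher logits, normally computed under torch.no_grad()

    s_log_probs = torch.log_softmax(s_logits / T, dim=-1)
    t_probs = torch.softmax(t_logits / T, dim=-1)

    kd_loss = kd(s_log_probs, t_probs)
    # The training objective then mixes this with the supervised loss:
    # total = kd_weight * kd_loss + (1 - kd_weight) * supervised_loss

Unlike the classic Hinton formulation, the loss here is not rescaled by T²; that only changes the relative magnitude of the distillation term against the supervised loss, which kd_weight already controls.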