English
music
emotion
music2emo / music2emo.py
kjysmu's picture
Upload 3 files
b036e59 verified
import os
import mir_eval
import pretty_midi as pm
from utils import logger
from utils.btc_model import BTC_model
# from preprocess.BTC.btc_model import *
from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams
from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths, get_lab_paths
import argparse
import warnings
from music21 import converter
import os
from tqdm import tqdm
import json
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from transformers import Wav2Vec2FeatureExtractor, AutoModel
from utils.mert import FeatureExtractorMERT
from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
from pathlib import Path
import gradio as gr
import shutil
import warnings
import logging
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# from gradio import Markdown
# Chromatic scale spelled with sharps; a note's list index is its pitch-class number.
PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

# Note name -> pitch-class number (C = 0 ... B = 11).
pitch_num_dic = {name: index for index, name in enumerate(PITCH_CLASS)}

# Flat note names written with a trailing '-' -> the sharp spelling used above.
# NOTE(review): the '-' suffix looks like music21's flat notation -- confirm.
minor_major_dic = dict(zip(
    ('D-', 'E-', 'G-', 'A-', 'B-'),
    ('C#', 'D#', 'F#', 'G#', 'A#'),
))

# Same table with the flats written with a trailing 'b' instead of '-'.
minor_major_dic2 = {
    flat.replace('-', 'b'): sharp for flat, sharp in minor_major_dic.items()
}

# Semitone distance of each major tonic above C.
shift_major_dic = dict(pitch_num_dic)

# Semitone distance of each minor tonic above A ('A' is PITCH_CLASS[9]).
shift_minor_dic = {PITCH_CLASS[(9 + offset) % 12]: offset for offset in range(12)}

# Every flat spelling (including Cb and Fb) -> its sharp/natural equivalent.
flat_to_sharp_mapping = dict([
    ("Cb", "B"),
    ("Db", "C#"),
    ("Eb", "D#"),
    ("Fb", "E"),
    ("Gb", "F#"),
    ("Ab", "G#"),
    ("Bb", "A#"),
])

# Length, in seconds, of each audio segment produced by split_audio().
segment_duration = 30
# Sample rate (Hz) the audio is resampled to before MERT feature extraction.
resample_rate = 24000
# When True, predict() extracts MERT features per segment instead of once
# for the whole waveform.
is_split = True
def normalize_chord(file_path, key, key_type='major'):
    """Transpose the chords of a ``.lab`` file so the tonic becomes C (major) or A (minor).

    Each non-empty line of the file is expected to hold at least three
    whitespace-separated columns: start time, end time and a chord label
    (``N``, ``X``, ``<root>`` or ``<root>:<quality>``).

    Args:
        file_path: Path of the ``.lab`` file to read.
        key: Tonic of the song (e.g. ``"C"``, ``"f#"``, ``"Eb"``). ``None`` or
            the string ``"None"`` means "unknown key": chords are left
            untransposed.
        key_type: ``"major"`` shifts relative to C; any other value is treated
            as minor and shifts relative to A.

    Returns:
        A list of ``"<start> <end> <chord>\\n"`` strings, one per non-empty
        input line, with the chord root transposed and the times copied
        verbatim.

    Raises:
        KeyError: If the tonic or a chord root is not a known note name.
        IndexError: If a non-empty line has fewer than three columns.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    if key is None or key == "None":
        shift = 0
    else:
        # Capitalise only the note letter so that "bb" becomes "Bb", not "BB".
        key = key[:1].upper() + key[1:]
        # Use the complete flat table so that Cb / Fb are handled as well.
        key = flat_to_sharp_mapping.get(key, key)
        shift_table = shift_major_dic if key_type == "major" else shift_minor_dic
        shift = shift_table[key]

    def transpose(root):
        """Return the sharp-spelled root lowered by ``shift`` semitones."""
        root = flat_to_sharp_mapping.get(root, root)
        return PITCH_CLASS[(pitch_num_dic[root] - shift) % 12]

    converted_lines = []
    for line in lines:
        parts = line.split()
        if not parts:  # Skip empty lines
            continue
        start_time, end_time, chord = parts[0], parts[1], parts[2]
        if chord in ("N", "X"):
            # "No chord" / "unknown chord" markers carry no root to transpose.
            new_chord = chord
        else:
            # partition() keeps everything after the first ':' as the quality,
            # and yields an empty separator for bare roots such as "G".
            root, sep, quality = chord.partition(":")
            new_chord = transpose(root) + sep + quality
        converted_lines.append(f"{start_time} {end_time} {new_chord}\n")
    return converted_lines
def sanitize_key_signature(key):
    """Return *key* with every ``-`` rewritten as ``b`` (e.g. ``"E-"`` -> ``"Eb"``)."""
    pieces = key.split('-')
    return 'b'.join(pieces)
def resample_waveform(waveform, original_sample_rate, target_sample_rate):
    """Resample *waveform* to *target_sample_rate* when the rates differ.

    Returns:
        A ``(waveform, sample_rate)`` tuple. If the two rates are already
        equal the input waveform is returned untouched together with its
        original rate; otherwise the output of ``T.Resample`` and the target
        rate are returned.
    """
    if original_sample_rate == target_sample_rate:
        return waveform, original_sample_rate

    convert = T.Resample(original_sample_rate, target_sample_rate)
    return convert(waveform), target_sample_rate
def split_audio(waveform, sample_rate):
    """Cut *waveform* into consecutive windows of ``segment_duration`` seconds.

    Slicing is done along the first dimension of *waveform*. Only windows that
    fit completely are kept, so a trailing remainder shorter than one window
    is dropped. If not even one full window fits, the whole waveform is
    returned as the single segment.

    Returns:
        A non-empty list of waveform slices.
    """
    window = segment_duration * sample_rate
    n_samples = waveform.size(0)

    # Start offsets of every window that ends at or before the last sample.
    full_window_starts = range(0, n_samples - window + 1, window)
    pieces = [waveform[offset:offset + window] for offset in full_window_starts]

    # In case audio length is shorter than segment length.
    if not pieces:
        pieces.append(waveform)
    return pieces
class Music2emo:
    """Predict mood tags, valence and arousal for an audio file.

    The pipeline combines three inputs for the emotion model: a MERT audio
    embedding, the chord progression recognised by the BTC model (transposed
    to C major / A minor) and the mode (major/minor) of the estimated key.
    """

    # Index used for each mode in the model's "x_key" input.
    _MODE_TO_IDX = {"major": 0, "minor": 1}
    # Chord sequences are truncated / zero-padded to this many entries.
    _MAX_CHORD_SEQ_LEN = 100

    def __init__(
        self,
        model_weights = "saved_models/J_all.ckpt"
    ):
        """Build the MERT feature extractor and load the emotion model.

        Args:
            model_weights: Path of the checkpoint whose ``"state_dict"`` entry
                holds the emotion model weights.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
        self.model_weights = model_weights
        self.music2emo_model = FeedforwardModelMTAttnCK(
            input_size=768 * 2,
            output_size_classification=56,
            output_size_regression=2
        )
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
        state_dict = checkpoint["state_dict"]
        # Adjust the keys in the state_dict
        state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
        # Filter state_dict to match model's keys
        model_keys = set(self.music2emo_model.state_dict().keys())
        filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
        # Load the filtered state_dict and set the model to evaluation mode
        self.music2emo_model.load_state_dict(filtered_state_dict)
        self.music2emo_model.to(self.device)
        self.music2emo_model.eval()

    def predict(self, audio, threshold = 0.5):
        """Run the full emotion-recognition pipeline on one audio file.

        Side effects: the directories ``./temp_out`` and ``./output`` are
        deleted and recreated, intermediate files are written to
        ``./temp_out``, and all Python warnings are silenced.

        Args:
            audio: Path of the audio file to analyse.
            threshold: A mood tag is reported when its sigmoid probability is
                strictly greater than this value.

        Returns:
            A dict with keys ``"valence"`` and ``"arousal"`` (floats taken
            from the regression head) and ``"predicted_moods"`` (list of tag
            names).

        Raises:
            FileNotFoundError: If the BTC chord-model checkpoint is missing.
        """
        feature_dir = Path("./temp_out")
        output_dir = Path("./output")
        for directory in (feature_dir, output_dir):
            if directory.exists():
                shutil.rmtree(str(directory))
            directory.mkdir(parents=True)

        warnings.filterwarnings('ignore')
        logger.logging_verbosity(1)

        mert_dir = feature_dir / "mert"
        mert_dir.mkdir(parents=True)

        # --- MERT feature extract ---
        final_embedding_mert = self._extract_mert_embedding(audio, mert_dir)

        # --- Chord feature extract ---
        stem = Path(audio).stem
        lab_path = str(feature_dir / f"{stem}.lab")
        midi_path = str(feature_dir / f"{stem}.midi")
        lab_norm_path = str(feature_dir / f"{stem}_norm.lab")

        self._recognize_chords(audio, lab_path)
        self._lab_to_midi(lab_path, midi_path)

        # --- Key feature (Tonic and Mode separation) ---
        key_signature, key_type = self._estimate_key(midi_path)
        # The mode must come from key_type: key_signature only holds the
        # tonic at this point, so looking the mode up in it always fell back
        # to index 0 ("major").
        encoded_mode = self._MODE_TO_IDX.get(key_type, 0)
        mode_tensor = torch.tensor([encoded_mode], dtype=torch.long)

        converted_lines = normalize_chord(lab_path, key_signature, key_type)
        # Write the converted lines to the new file
        with open(lab_norm_path, 'w') as f:
            f.writelines(converted_lines)

        with open('inference/data/chord.json', 'r') as f:
            chord_to_idx = json.load(f)
        with open('inference/data/chord_root.json') as json_file:
            chord_root_dic = json.load(json_file)
        with open('inference/data/chord_attr.json') as json_file:
            chord_attr_dic = json.load(json_file)

        chord_labels = [line.split()[2] for line in converted_lines if line.strip()]
        encoded, encoded_root, encoded_attr = self._encode_chords(
            chord_labels, chord_to_idx, chord_root_dic, chord_attr_dic
        )

        # Truncate or pad each chord sequence to the fixed model length.
        max_len = self._MAX_CHORD_SEQ_LEN
        chords_tensor = torch.tensor(self._fit_length(encoded, max_len), dtype=torch.long)
        chords_root_tensor = torch.tensor(self._fit_length(encoded_root, max_len), dtype=torch.long)
        chords_attr_tensor = torch.tensor(self._fit_length(encoded_attr, max_len), dtype=torch.long)

        model_input_dic = {
            "x_mert": final_embedding_mert.unsqueeze(0),
            "x_chord": chords_tensor.unsqueeze(0),
            "x_chord_root": chords_root_tensor.unsqueeze(0),
            "x_chord_attr": chords_attr_tensor.unsqueeze(0),
            "x_key": mode_tensor.unsqueeze(0)
        }
        model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            classification_output, regression_output = self.music2emo_model(model_input_dic)
            probs = torch.sigmoid(classification_output)

        tag_list = np.load("./inference/data/tag_list.npy")
        # Only the entries from index 127 onward are used as mood tags.
        # NOTE(review): presumably these are the 56 "mood/theme" tags matching
        # output_size_classification -- confirm against tag_list.npy.
        tag_list = tag_list[127:]
        mood_list = [t.replace("mood/theme---", "") for t in tag_list]
        predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]

        valence, arousal = regression_output.squeeze().tolist()

        model_output_dic = {
            "valence": valence,
            "arousal": arousal,
            "predicted_moods": predicted_moods
        }
        return model_output_dic

    def _extract_mert_embedding(self, audio, mert_dir):
        """Return the averaged MERT embedding of *audio* as a float32 tensor on ``self.device``."""
        waveform, sample_rate = torchaudio.load(audio)
        # Down-mix multi-channel audio to a single channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0).unsqueeze(0)
        waveform = waveform.squeeze()
        waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)

        segments = split_audio(waveform, sample_rate) if is_split else [waveform]
        for i, segment in enumerate(segments):
            segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
            self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)

        layers_to_extract = [5, 6]
        segment_embeddings = []
        for filename in sorted(os.listdir(mert_dir)):
            file_path = os.path.join(mert_dir, filename)
            if os.path.isfile(file_path) and filename.endswith('.npy'):
                segment = np.load(file_path)
                concatenated_features = np.concatenate(
                    [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
                )
                # Shape: 768 * 2 = 1536
                segment_embeddings.append(np.squeeze(concatenated_features))

        if segment_embeddings:
            embedding = np.mean(np.array(segment_embeddings), axis=0)
        else:
            embedding = np.zeros((1536,))
        # Tensor.to() is not in-place, so its result must be kept; float()
        # also keeps the float64 zero fallback compatible with the model.
        return torch.from_numpy(embedding).float().to(self.device)

    def _recognize_chords(self, audio_path, lab_path):
        """Run the BTC chord recogniser on *audio_path* and write the result to *lab_path*."""
        config = HParams.load("./inference/data/run_config.yaml")
        config.feature['large_voca'] = True
        config.model['num_chords'] = 170
        model_file = './inference/data/btc_model_large_voca.pt'
        idx_to_chord = idx2voca_chord()
        model = BTC_model(config=config.model).to(self.device)

        if not os.path.isfile(model_file):
            # Without the checkpoint there are no weights and no mean/std.
            raise FileNotFoundError("BTC chord model checkpoint not found: %s" % model_file)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host;
        # weights_only=False matches __init__ (trusted checkpoints only).
        checkpoint = torch.load(model_file, map_location=self.device, weights_only=False)
        mean = checkpoint['mean']
        std = checkpoint['std']
        model.load_state_dict(checkpoint['model'])

        try:
            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
        except Exception:
            logger.info("audio file failed to load : %s" % audio_path)
            raise
        logger.info("audio file loaded and feature computation success : %s" % audio_path)

        feature = feature.T
        feature = (feature - mean) / std
        time_unit = feature_per_second
        n_timestep = config.model['timestep']
        # Pad so the frame count is a multiple of n_timestep. When it already
        # is, a whole extra block is added; the end-of-song test below
        # (i + num_pad == n_timestep) relies on exactly this padding.
        num_pad = n_timestep - (feature.shape[0] % n_timestep)
        feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
        num_instance = feature.shape[0] // n_timestep

        start_time = 0.0
        lines = []
        with torch.no_grad():
            model.eval()
            feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
            for t in range(num_instance):
                self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
                prediction, _ = model.output_layer(self_attn_output)
                prediction = prediction.squeeze()
                for i in range(n_timestep):
                    if t == 0 and i == 0:
                        prev_chord = prediction[i].item()
                        continue
                    if prediction[i].item() != prev_chord:
                        # Chord changed: close the segment that just ended.
                        lines.append(
                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                        start_time = time_unit * (n_timestep * t + i)
                        prev_chord = prediction[i].item()
                    if t == num_instance - 1 and i + num_pad == n_timestep:
                        # First padded frame reached: flush the open segment.
                        if start_time != time_unit * (n_timestep * t + i):
                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
                        break

        with open(lab_path, 'w') as f:
            f.writelines(lines)

    @staticmethod
    def _lab_to_midi(lab_path, midi_path):
        """Render the chords of *lab_path* as sustained notes in a MIDI file at *midi_path*."""
        starts, ends, pitchs = [], [], []
        intervals, chords = mir_eval.io.load_labeled_intervals(lab_path)

        # Encode every chord once instead of once per pitch class.
        rotated_bitmaps = []
        for chord in chords:
            root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
            rotated_bitmaps.append(mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num))

        last_index = len(intervals) - 1
        for p in range(12):
            for i, (interval, bitmap) in enumerate(zip(intervals, rotated_bitmaps)):
                tmp_label = bitmap[p]
                if i == 0:
                    start_time = interval[0]
                    label = tmp_label
                elif tmp_label != label:
                    if label == 1.0:
                        starts.append(start_time)
                        ends.append(interval[0])
                        pitchs.append(p + 48)
                    start_time = interval[0]
                    label = tmp_label
                # Also reached when the file has a single interval, so a
                # one-chord song still produces notes.
                if i == last_index and label == 1.0:
                    starts.append(start_time)
                    ends.append(interval[1])
                    pitchs.append(p + 48)

        midi = pm.PrettyMIDI()
        instrument = pm.Instrument(program=0)
        for start, end, pitch in zip(starts, ends, pitchs):
            instrument.notes.append(pm.Note(velocity=120, pitch=pitch, start=start, end=end))
        midi.instruments.append(instrument)
        midi.write(midi_path)

    @staticmethod
    def _estimate_key(midi_path):
        """Return ``(tonic, mode)`` estimated by music21, or ``("None", "major")`` on failure."""
        try:
            midi_file = converter.parse(midi_path)
            key_signature = str(midi_file.analyze('key'))
        except Exception:
            # Best effort: an unanalysable file is treated as "key unknown".
            key_signature = "None"
        key_parts = key_signature.split()
        tonic = sanitize_key_signature(key_parts[0])  # Sanitize key signature
        key_type = key_parts[1] if len(key_parts) > 1 else 'major'
        return tonic, key_type

    @staticmethod
    def _encode_chords(chord_labels, chord_to_idx, chord_root_dic, chord_attr_dic):
        """Map chord labels to parallel lists of chord, root and attribute IDs.

        A label missing from *chord_to_idx* is skipped in all three lists so
        that they always stay aligned.
        """
        encoded, encoded_root, encoded_attr = [], [], []
        for chord in chord_labels:
            if chord not in chord_to_idx:
                print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
                continue
            root, sep, quality = chord.partition(":")
            if sep:
                attr_id = chord_attr_dic[quality]
            elif root in ("N", "X"):
                attr_id = 0
            else:
                attr_id = 1
            encoded.append(chord_to_idx[chord])
            encoded_root.append(chord_root_dic[root])
            encoded_attr.append(attr_id)
        return encoded, encoded_root, encoded_attr

    @staticmethod
    def _fit_length(values, length, pad_value=0):
        """Return *values* as an int64 array truncated or right-padded to exactly *length* items."""
        fitted = np.asarray(values, dtype=np.int64)[:length]
        missing = length - fitted.shape[0]
        if missing > 0:
            fitted = np.concatenate([fitted, np.full(missing, pad_value, dtype=np.int64)])
        return fitted