import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging
import json
import os
import re
import pandas as pd
import librosa
import importlib
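# The MERT modeling code lives in a directory whose name contains hyphens ("MERT-v1-95M"),
# so it cannot be imported with a plain import statement; load it dynamically instead.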
modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")
from Prediction_Head.MTGGenre_head import MLPProberBase
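# MLPProberBase is the small MLP probe architecture reused for every task-specific
# prediction head; the trained weights are loaded per task further below.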
logger = logging.getLogger("MERT-v1-95M-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
title = "One Model for All Music Understanding Tasks"
description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as the backbone to conduct multiple music understanding tasks with its universal representation. \n Due to the hardware limitations of the machine hosting this demo (2 CPUs and 16 GB RAM), only the first 4 seconds of audio are used!"
with open('./README.md', 'r') as f:
    # skip the YAML front matter (delimited by two '---' lines)
    header_count = 0
    for line in f:
        if '---' in line:
            header_count += 1
        if header_count >= 2:
            break
    # read the rest of the content as the article shown below the demo
    article = f.read()
df_init = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe",
                                row_count=(0, "dynamic"), wrap=True)
outputs = transcription_df
# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v1-95M")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v1-95M")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
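# Each downstream task probes a different transformer layer of MERT; this mapping records
# the best-performing layer per task ('all' means the probe consumes every layer).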
MERT_BEST_LAYER_IDX = {
    'EMO': 5,
    'GS': 8,
    'GTZAN': 7,
    'MTGGenre': 7,
    'MTGInstrument': 'all',
    'MTGMood': 6,
    'MTGTop50': 6,
    'MTT': 'all',
    'NSynthI': 6,
    'NSynthP': 1,
    'VocalSetS': 2,
    'VocalSetT': 9,
}
CLASSIFIERS = {}
ID2CLASS = {}
TASKS = ['GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood',
         'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT', 'EMO']
Regression_TASKS = ['EMO']
head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
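# For every task, load the id-to-class mapping and the corresponding pretrained MLP probe.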
for task in TASKS:
    print('loading', task)
    with open(os.path.join(head_dir, f'{task}.id2class.json'), 'r') as f:
        ID2CLASS[task] = json.load(f)
    num_class = len(ID2CLASS[task].keys())
    CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
    # map_location ensures the checkpoints also load on the CPU-only demo machine
    CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt', map_location=device)['state_dict'])
    CLASSIFIERS[task].to(device)
model.to(device)
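# Inference pipeline: load the uploaded audio, resample it to the rate MERT expects,
# extract hidden states from all layers once, then apply every task-specific probe.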
def model_inference(inputs):
    waveform, sample_rate = torchaudio.load(inputs)
    resample_rate = processor.sampling_rate

    # make sure the sample rate matches what the processor expects
    if resample_rate != sample_rate:
        # print(f'setting rate from {sample_rate} to {resample_rate}')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)

    # waveform = waveform.view(-1,)  # make it (n_sample, )
    waveform = waveform[0][0:4 * resample_rate]  # keep the first channel and cut to 4 s

    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # take a look at the output shape: there are 13 layers of representation, and each
    # layer performs differently on different downstream tasks, so the layer should be
    # chosen empirically; the first hidden state (the pre-transformer embedding output)
    # is dropped here
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:, :, :].unsqueeze(0)
    print(all_layer_hidden_states.shape)  # [batch, layers, time steps, 768 feature_dim]
    all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)  # average over time steps
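    # Run every task-specific probe on the pooled representation and collect its
    # top-5 predictions for the output table.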
    task_output_texts = ""
    df = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
    df_objects = []

    for task in TASKS:
        num_class = len(ID2CLASS[task].keys())
        if MERT_BEST_LAYER_IDX[task] == 'all':
            logits = CLASSIFIERS[task](all_layer_hidden_states)  # [1, num_class]
        else:
            logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
        # print(f'task {task} logits:', logits.shape, 'num class:', num_class)

        sorted_idx = torch.argsort(logits, dim=-1, descending=True)[0]  # batch = 1
        sorted_prob, _ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
        # print(sorted_prob)
        # print(sorted_prob.shape)

        top_n_show = 5 if num_class >= 5 else num_class
        # task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
        # task_output_texts = task_output_texts + '----------------------\n'

        row_elements = [task]
        for idx in range(top_n_show):
            # print(ID2CLASS[task])
            # print('id', str(sorted_idx[idx].item()))
            output_class_name = str(ID2CLASS[task][str(sorted_idx[idx].item())])
            # strip dataset prefixes such as 'genre---' or 'mood/theme---' from the class names
            output_class_name = re.sub(r'^\w+---', '', output_class_name)
            output_class_name = re.sub(r'^\w+\/\w+---', '', output_class_name)
            # print('output name', output_class_name)
            output_prob = f' {sorted_prob[idx].item():.2%}'
            row_elements.append(output_class_name + output_prob)
        # fill empty elements so every row has the same number of columns
        for _ in range(5 + 1 - len(row_elements)):
            row_elements.append(' ')
        df_objects.append(row_elements)
    df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
    return df
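# Example of calling the pipeline directly (hypothetical local check; assumes a
# "sample.wav" file is present next to this script):
#   print(model_inference("sample.wav"))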
def convert_audio(inputs):
    # audio_data, sample_rate = librosa.load(inputs, sr=None)
    return model_inference(inputs)
def build_audio_flow(title, description, article):
    audio_file = gr.File(label="Select Audio File (*.wav)")
    demo = gr.Interface(
        fn=convert_audio,
        inputs=audio_file,
        outputs=outputs,
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )
    return demo
demo = build_audio_flow(title, description, article)
# demo.queue(concurrency_count=1, max_size=5)
demo.launch()