# asigalov61's picture
# Update app.py
# 0c2c10c verified
#================================================================
# https://huggingface.co/spaces/asigalov61/MIDI-Identification
#================================================================
import os
import hashlib
import time
import datetime
from pytz import timezone
import copy
from collections import Counter
import random
import statistics
import re
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
import TMIDIX
#==========================================================================================================
# Hugging Face API token taken from the environment (None when unset);
# handed to InferenceClient when querying the chat model.
HF_TOKEN = os.environ.get('HF_TOKEN')
#==========================================================================================================
def format_table_data(data_string):
    """Parse a markdown (pipe-delimited) table into a rectangular list of rows.

    Each non-blank line becomes a row. Cells are split on ``|`` and stripped.
    The empty edge cells produced by a line's leading/trailing pipe are
    dropped, but interior empty cells are kept so columns stay aligned.
    Markdown separator rows (e.g. ``|---|:---:|``), code-fence lines
    (lines starting with three backticks) and rows with no content are
    skipped. Finally every row is padded with ``""`` to the width of the
    widest row.

    Args:
        data_string: Table text, one row per line (e.g. an LLM reply).

    Returns:
        A list of rows (lists of str), all of equal length; ``[]`` when the
        input contains no table rows.
    """
    formatted_data = []

    for line in data_string.strip().split("\n"):
        line = line.strip()

        # Skip blank lines and markdown code fences the model may wrap the table in.
        if not line or line.startswith("```"):
            continue

        # Remove only the outer pipes, so their empty edge cells do not become columns.
        if line.startswith("|"):
            line = line[1:]
        if line.endswith("|"):
            line = line[:-1]

        cells = [cell.strip() for cell in line.split("|")]

        # Skip rows that carry no content at all.
        if not any(cells):
            continue

        # Skip markdown header/body separator rows such as |---|:---:|.
        if all(re.fullmatch(r":?-+:?", cell) for cell in cells):
            continue

        formatted_data.append(cells)

    # Pad ragged rows so the result is rectangular.
    max_columns = max((len(row) for row in formatted_data), default=0)
    for row in formatted_data:
        row.extend([""] * (max_columns - len(row)))

    return formatted_data
#==========================================================================================================
# Model display name (shown in the UI dropdown) -> Hugging Face model id
# passed to the chat-completion endpoint.
MODELS = {
    'Mistral Nemo Instruct 2407': 'mistralai/Mistral-Nemo-Instruct-2407',
}
#==========================================================================================================
def _describe_song(escore_notes, md5_hash):
    """Return TMIDIX's text description of escore_notes, titled when md5_hash is known."""
    titles_dict = monster_midi_titles['md5_hashes_titles_dict']

    if md5_hash not in titles_dict:
        return TMIDIX.escore_notes_to_text_description(escore_notes)

    # Titles are split as 'song --- artist'; a hash may map to several, so pick one.
    title = random.choice(titles_dict[md5_hash]).split(' --- ')

    if len(title) < 2:
        # No artist part: indexing title[1] would raise IndexError, so name the song only.
        return TMIDIX.escore_notes_to_text_description(escore_notes, song_name=title[0])

    return TMIDIX.escore_notes_to_text_description(escore_notes,
                                                   song_name=title[0],
                                                   artist_name=title[1]
                                                   )


def _lookup_midid(md5_hash):
    """Return (records_count, source_dataset, cleaned_path_str) for md5_hash in MIDID.

    Falls back to (0, 'unknown', 'none') when the hash has no MIDID record.
    """
    try:
        # A single scan of the hash column instead of `in` followed by `.index()`.
        entry_idx = midid_md5_hashes.index(md5_hash)
    except ValueError:
        return 0, 'unknown', 'none'

    midid_record = midid_dataset[entry_idx]['midid']

    if not midid_record:
        return 0, 'unknown', 'none'

    # Entries are read as [source dataset name, path string]; pick one at random.
    entry = random.choice(midid_record)
    path_str = TMIDIX.clean_string(entry[1], regex=r'[^a-zA-Z0-9.() \n]')

    return len(midid_record), entry[0], path_str


def ID_MIDI(input_midi, input_model):
    """Identify an uploaded MIDI file and summarize it with an LLM.

    The MIDI is re-serialized through TMIDIX and md5-hashed. That hash is
    looked up in the Monster MIDI titles database and in the MIDID dataset.
    The matches plus a TMIDIX-generated music description are sent to the
    selected chat model, and its reply is parsed with format_table_data.

    Relies on module-level names that are only created in the ``__main__``
    section: ``PDT``, ``midid_dataset``, ``midid_md5_hashes`` and
    ``monster_midi_titles``.

    Args:
        input_midi: Path to the uploaded MIDI file (Gradio File, type="filepath").
        input_model: Display name of the model; must be a key of ``MODELS``.

    Returns:
        Tuple (md5 hash of the re-serialized MIDI, number of MIDID records
        for it, source dataset name, full text sent to the model, parsed
        table rows).

    Raises:
        gr.Error: If no MIDI file was uploaded.
    """
    if not input_midi:
        raise gr.Error('Please upload a MIDI file first.')

    print('*' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()

    print('=' * 70)
    print('Loading MIDI...')

    fn = os.path.basename(input_midi)

    # Context manager so the upload's file handle is closed promptly.
    with open(input_midi, 'rb') as f:
        fdata = f.read()

    input_midi_md5hash = hashlib.md5(fdata).hexdigest()

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI md5 hash', input_midi_md5hash)
    print('Input model:', input_model)
    print('=' * 70)
    print('Processing MIDI...Please wait...')

    #=======================================================
    # START PROCESSING

    # Round-trip through TMIDIX before hashing; presumably so the hash matches
    # how the database hashes were computed — confirm against the dataset docs.
    new_midi_data = TMIDIX.score2midi(TMIDIX.midi2score(fdata))
    new_midi_md5hash = hashlib.md5(new_midi_data).hexdigest()

    print('New md5 hash:', new_midi_md5hash)
    print('Done!')
    print('=' * 70)
    print('Processing...Please wait...')

    raw_score = TMIDIX.midi2single_track_ms_score(fdata)
    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, sort_drums_last=True)

    song_description = _describe_song(escore_notes, new_midi_md5hash)

    (output_midi_records_count,
     output_midi_src_dataset,
     output_midi_path_str) = _lookup_midid(new_midi_md5hash)

    prompt = "Please create a summary table for a MIDI file based on the following keywords strings, best possible description and best possible summary fields. Please respond with the table only. Do not say anything else. Thank you."

    data = ('Source MIDI dataset: ' + output_midi_src_dataset + '\n\n'
            + 'MIDI keywords strings:' + '\n'
            + output_midi_path_str + '\n\n'
            + 'Music description:' + '\n'
            + song_description)

    client = InferenceClient(api_key=HF_TOKEN)

    completion = client.chat.completions.create(
        model=MODELS[input_model],
        messages=[{"role": "user", "content": prompt + "\n\n" + data}],
        max_tokens=500,
    )

    # The message content may be None; parse an empty string instead of crashing.
    output_str = completion.choices[0].message['content'] or ''
    output_table_data = format_table_data(output_str)

    print('Done!')
    print('=' * 70)
    print('Original MIDI unique records count', output_midi_records_count)
    print('Original MIDI dataset', output_midi_src_dataset)
    print('Original MIDI path string', data)
    print('=' * 70)
    print(output_str)
    print('=' * 70)

    #========================================================

    output_midi_md5 = str(new_midi_md5hash)

    #========================================================

    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')
    print('*' * 70)

    #========================================================

    # NOTE(review): the 4th output feeds the "Original MIDI raw path string"
    # textbox but is the full model input `data`, not just output_midi_path_str.
    return output_midi_md5, output_midi_records_count, output_midi_src_dataset, data, output_table_data
#==========================================================================================================
if __name__ == "__main__":

    # Timezone for app/request timestamps; ID_MIDI reads this module-level PDT.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    print('Loading MIDID database...')
    midid_dataset = load_dataset("asigalov61/MIDID")['train']
    # md5 hash column; ID_MIDI finds a row index in it to read the 'midid' records.
    midid_md5_hashes = midid_dataset['midi_hash']
    print('Done!')
    print('=' * 70)

    print('Loading Monster MIDI titles database...')
    # Loaded from a local pickle file (presumably shipped with the Space);
    # ID_MIDI reads its 'md5_hashes_titles_dict' entry.
    monster_midi_titles = TMIDIX.Tegridy_Any_Pickle_File_Reader('Monster_MIDI_Titles_Database_CC_BY_NC_SA.pickle')
    print('Done!')
    print('=' * 70)

    app = gr.Blocks()

    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Identification</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Identify any MIDI in a comprehensive database of 2.32M+ MIDI records</h1>")
        gr.Markdown("This is a demo for tegridy-tools, MIDID and Monster MIDI dataset\n\n"
                    "Please see [tegridy-tools](https://github.com/asigalov61/tegridy-tools), [MIDID](https://huggingface.co/datasets/asigalov61/MIDID) and [Monster MIDI Dataset](https://github.com/asigalov61/Monster-MIDI-Dataset) repos for more information\n\n"
                    )

        gr.Markdown("## Upload your MIDI")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")

        # Choices come from MODELS so each model is listed exactly once and
        # every choice is a valid key for ID_MIDI's MODELS lookup.
        input_model = gr.Dropdown(list(MODELS),
                                  value='Mistral Nemo Instruct 2407',
                                  label='Select model'
                                  )

        submit = gr.Button("Identify MIDI", variant="primary")

        gr.Markdown("## MIDI identification results")

        output_midi_md5 = gr.Textbox(label="Monster MIDI dataset md5 hash")
        output_midi_records_count = gr.Textbox(label="Original MIDI unique records count")
        output_midi_src_dataset = gr.Textbox(label="Original MIDI dataset pretty name")
        output_midi_path_str = gr.Textbox(label="Original MIDI raw path string")
        output_MIDID_results_table = gr.Dataframe(label="MIDID database results table", wrap=True, col_count=(3, 'dynamic'))

        # Output order must match ID_MIDI's return tuple.
        run_event = submit.click(ID_MIDI,
                                 [input_midi,
                                  input_model
                                  ],
                                 [output_midi_md5,
                                  output_midi_records_count,
                                  output_midi_src_dataset,
                                  output_midi_path_str,
                                  output_MIDID_results_table
                                  ])

    app.queue().launch()