# Hugging Face Spaces -- status: Running
#================================================================
# https://huggingface.co/spaces/asigalov61/MIDI-Identification
#================================================================
import os | |
import hashlib | |
import time | |
import datetime | |
from pytz import timezone | |
import copy | |
from collections import Counter | |
import random | |
import statistics | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
import TMIDIX | |
#========================================================================================================== | |
# Hugging Face API token, read once from the environment when the module
# loads; it is None when the variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')
#========================================================================================================== | |
def format_table_data(data_string):
    """Parse a pipe-delimited (Markdown-style) table into rows of cells.

    Each non-blank line becomes one row. A single leading and a single
    trailing ``|`` are treated as table borders and dropped, the rest of
    the line is split on ``|`` and every cell is whitespace-stripped.
    Interior empty cells are kept so the remaining cells stay in their
    columns. Header separator rows (every cell made only of ``-`` with
    optional ``:`` alignment markers) are skipped. All rows are finally
    right-padded with ``""`` to the width of the widest row.

    Args:
        data_string: Table text, one row per line.

    Returns:
        A list of rows, each a list of cell strings of equal length.
        When the input contains no usable rows, ``[[""]]`` is returned
        (a single empty cell), matching what an empty string has always
        produced.
    """
    table = []

    for raw_line in data_string.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        # Outer pipes are borders, not cell separators: remove at most one
        # on each side so genuinely empty edge cells are still preserved.
        if line.startswith("|"):
            line = line[1:]
        if line.endswith("|"):
            line = line[:-1]

        cells = [cell.strip() for cell in line.split("|")]

        # A separator row looks like ---|:---:|---: ; it carries no data.
        # A lone "-" inside an ordinary data row is kept as a value.
        if all(cell and "-" in cell and not cell.strip("-:") for cell in cells):
            continue

        table.append(cells)

    if not table:
        return [[""]]

    # Pad once, after the widest row is known, so the result is rectangular.
    width = max(len(row) for row in table)
    for row in table:
        row.extend([""] * (width - len(row)))

    return table
#========================================================================================================== | |
def ID_MIDI(input_midi):
    """Identify an uploaded MIDI file against the MIDID database.

    The file is hashed as uploaded, then round-tripped through
    ``TMIDIX.midi2score`` / ``TMIDIX.score2midi`` and hashed again; the
    second hash is the key looked up in the module-level
    ``MIDID_database``. On a match, one of the stored records is picked at
    random, its dataset name and path string are sent to a hosted chat
    model, and the model's reply is parsed with ``format_table_data``.

    Relies on the module globals ``PDT``, ``MIDID_database`` and
    ``HF_TOKEN``; in this file the first two are assigned inside the
    ``__main__`` block.

    Args:
        input_midi: Filesystem path of the MIDI file to identify.

    Returns:
        A 5-tuple of:
        the md5 hex digest of the file as uploaded,
        the number of database records stored under the round-tripped hash
        (0 when there is no match),
        the source dataset name of the chosen record ('Unknown' when there
        is no match),
        the path string of the chosen record ('None' when there is no
        match),
        and the results table as a list of rows.

    Raises:
        gr.Error: If no file was supplied.
    """

    # Without an upload the handler receives None; report that to the user
    # instead of failing with a TypeError inside os.path.basename.
    if input_midi is None:
        raise gr.Error('Please upload a MIDI file first')

    print('*' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()

    print('=' * 70)
    print('Loading MIDI...')

    fn = os.path.basename(input_midi)

    # Context manager so the file handle is closed deterministically.
    with open(input_midi, 'rb') as midi_file:
        fdata = midi_file.read()

    input_midi_md5hash = hashlib.md5(fdata).hexdigest()

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI md5 hash', input_midi_md5hash)
    print('=' * 70)
    print('Processing MIDI...Please wait...')

    #=======================================================
    # START PROCESSING

    # The database is keyed by the hash of the round-tripped bytes, not by
    # the hash of the file as uploaded.
    new_midi_data = TMIDIX.score2midi(TMIDIX.midi2score(fdata))
    new_midi_md5hash = hashlib.md5(new_midi_data).hexdigest()

    print('New md5 hash:', new_midi_md5hash)
    print('Done!')
    print('=' * 70)
    print('Processing...Please wait...')

    # Defaults reported when the hash is not in the database.
    output_str = 'None'
    output_midi_records_count = 0
    output_midi_src_dataset = 'Unknown'
    output_midi_path_str = 'None'

    if new_midi_md5hash in MIDID_database:

        client = InferenceClient(api_key=HF_TOKEN)

        prompt = "Please create a summary table for a MIDI file based on the following keywords strings, best possible description and best possible summary fields. Please respond with the table only. Do not say anything else. Thank you."

        matching_records = MIDID_database[new_midi_md5hash]

        output_midi_records_count = len(matching_records)

        # Several records may share one hash; one is chosen at random.
        output_entry = random.choice(matching_records)

        output_midi_src_dataset = output_entry['midi_dataset']
        output_midi_path_str = output_entry['midi_path']

        data = 'Source MIDI dataset: ' + output_midi_src_dataset + '\n' + output_midi_path_str

        messages = [
            {
                "role": "user",
                "content": prompt + "\n\n" + data
            }
        ]

        completion = client.chat.completions.create(
            #model="Qwen/Qwen2.5-72B-Instruct",
            model="mistralai/Mistral-Nemo-Instruct-2407",
            messages=messages,
            max_tokens=500
        )

        output_str = completion.choices[0].message['content']

        output_table_data = format_table_data(output_str)

    else:
        output_table_data = [['No matching MIDI ID records found', 'Unknown MIDI', 'Sorry :(']]

    print('Done!')
    print('=' * 70)
    print('Original MIDI unique records count', output_midi_records_count)
    print('Original MIDI dataset', output_midi_src_dataset)
    print('Original MIDI path string', output_midi_path_str)
    print('=' * 70)
    print(output_str)
    print('=' * 70)

    #========================================================

    # The hash returned to the UI is that of the file as uploaded.
    output_midi_md5 = str(input_midi_md5hash)

    #========================================================

    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')
    print('*' * 70)

    #========================================================

    return output_midi_md5, output_midi_records_count, output_midi_src_dataset, output_midi_path_str, output_table_data
#========================================================================================================== | |
if __name__ == "__main__":

    # Timezone for every timestamp the app logs; ID_MIDI reads this global.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    print('Loading MIDID database...')
    # Loaded once at startup; ID_MIDI looks hashes up in this global.
    MIDID_database = TMIDIX.Tegridy_Any_Pickle_File_Reader('MIDID_Basic_Database_CC_BY_NC_SA.pickle')
    print('=' * 70)

    app = gr.Blocks()

    with app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Identification</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Identify any MIDI in a comprehensive database of 1.42M+ MIDI records</h1>")

        # A space now separates the second link from "GitHub repos" so the
        # rendered sentence no longer runs the two together.
        gr.Markdown("This is a demo for tegridy-tools and Monster MIDI dataset\n\n"
                    "Please see [tegridy-tools](https://github.com/asigalov61/tegridy-tools) and [Monster MIDI Dataset](https://github.com/asigalov61/Monster-MIDI-Dataset) GitHub repos for more information\n\n"
                    )

        gr.Markdown("## Upload your MIDI")

        # type="filepath" makes the handler receive a path string.
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")

        submit = gr.Button("Identify MIDI", variant="primary")

        gr.Markdown("## MIDI identification results")

        output_midi_md5 = gr.Textbox(label="Monster MIDI dataset md5 hash")
        output_midi_records_count = gr.Textbox(label="Original MIDI unique records count")
        output_midi_src_dataset = gr.Textbox(label="Original MIDI dataset pretty name")
        output_midi_path_str = gr.Textbox(label="Original MIDI raw path string")
        output_MIDID_results_table = gr.Dataframe(label="MIDID database results table", wrap=True, col_count=(3, 'dynamic'))

        # The output components are listed in the same order as the tuple
        # returned by ID_MIDI.
        run_event = submit.click(ID_MIDI,
                                 [input_midi],
                                 [output_midi_md5,
                                  output_midi_records_count,
                                  output_midi_src_dataset,
                                  output_midi_path_str,
                                  output_MIDID_results_table
                                  ])

    app.queue().launch()