# https://huggingface.co/spaces/asigalov61/MIDI-Search

import os

import time as reqtime
import datetime
from pytz import timezone

from sentence_transformers import SentenceTransformer
from sentence_transformers import util

import numpy as np

from datasets import load_dataset

import gradio as gr

import copy
import random
import pickle
import zlib

from midi_to_colab_audio import midi_to_colab_audio

import TMIDIX

import matplotlib.pyplot as plt

#==========================================================================================================

def _decode_score_tokens(tokens):
    """Turn a flat sequence of integer score tokens into a list of note events.

    Token ranges, exactly as handled below:

    * ``0..255``   -- advance the running start time by ``token * 16``
    * ``256..511`` -- set the current duration to ``(token - 256) * 16``
    * ``512..640`` -- set the current patch to ``token - 512``; patch 128 is
      routed to channel 9, any other patch gets (or reuses) a channel slot
    * ``641..767`` -- set the current pitch to ``token - 640``
    * ``769..895`` -- set the velocity to ``token - 768`` and emit a note

    Any other token value is ignored.

    Args:
        tokens: Iterable of ints.

    Returns:
        A list of ``['note', time, dur, channel, pitch, vel, patch]`` lists,
        one per velocity token, in input order.
    """

    notes = []

    time = 0
    dur = 0
    vel = 90
    # pitch and patch get defaults so that a velocity token arriving before
    # any pitch/patch token cannot hit an unbound local.
    pitch = 0
    patch = 0
    channel = 0

    # patches[i] is the patch assigned to channel i (-1 = unassigned).
    patches = [-1] * 16

    # channels[i] is 1 once channel i is taken. Channel 9 is reserved up
    # front because it is only ever used for patch 128.
    channels = [0] * 16
    channels[9] = 1

    for tok in tokens:

        if 0 <= tok < 256:
            time += tok * 16

        elif 256 <= tok < 512:
            dur = (tok - 256) * 16

        elif 512 <= tok <= 640:
            patch = tok - 512

            if patch == 128:
                channel = 9

            else:
                if patch not in patches:
                    # Take the first free channel; when all are taken, fall
                    # back to channel 15 and overwrite its patch.
                    cha = channels.index(0) if 0 in channels else 15
                    channels[cha] = 1
                    patches[cha] = patch

                channel = patches.index(patch)

        elif 640 < tok < 768:
            pitch = tok - 640

        elif 768 < tok < 896:
            vel = tok - 768
            notes.append(['note', time, dur, channel, pitch, vel, patch])

    return notes

#==========================================================================================================

def find_midi(title, artist):
    """Search the MIDI corpus for the best match and render it.

    Builds a text query from ``title`` and ``artist``, embeds it with the
    sentence-transformer model, picks the most similar corpus entry by cosine
    similarity, loads that entry's token score from its zlib-compressed
    pickle file, decodes it to notes, writes a MIDI file, renders audio and
    builds a score plot.

    Relies on module-level names that are assigned in the ``__main__``
    section of this file: ``PDT``, ``model``, ``corpus_embeddings``,
    ``all_MIDI_files_names`` and ``soundfont``.

    Args:
        title: Song title text (may be an empty string).
        artist: Song artist text (may be an empty string).

    Returns:
        A tuple ``(output_midi_title, output_midi_summary, output_midi,
        output_audio, output_plot)`` where ``output_midi_summary`` is the
        newline-separated names of the ten best matches, ``output_midi`` is
        the written MIDI file name and ``output_audio`` is
        ``(16000, audio)``.

    Raises:
        gr.Error: If the matched entry has no score data.
        ValueError: If the matched name is missing from its score file.
    """

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('-' * 70)
    print('Req title:', title)
    print('Req artist:', artist)
    print('-' * 70)

    input_text = ''

    if title != '':
        input_text += title

    if artist != '':
        input_text += ' by ' + artist

    print('Searching...')

    query_embedding = model.encode([input_text])

    # Compute cosine similarity between query and each sentence in the corpus
    similarities = util.cos_sim(query_embedding, corpus_embeddings)

    # Work on one NumPy row from here on; iterating the similarity matrix
    # element by element in Python (e.g. with builtin max()) is very slow
    # for a large corpus.
    scores = np.asarray(similarities)[0]

    top_ten_matches_idxs = np.argsort(-scores)[:10].tolist()

    # Find the index of the most similar sentence
    closest_index = int(np.argmax(scores))
    closest_index_match_ratio = float(scores[closest_index])

    best_corpus_match = all_MIDI_files_names[closest_index]

    top_ten_matches = ''.join(
        str(all_MIDI_files_names[idx][0]).title() + '\n'
        for idx in top_ten_matches_idxs
    )

    print('Done!')
    print('=' * 70)
    print('Match corpus index', closest_index)
    print('Match corpus ratio', closest_index_match_ratio)
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    song_artist = best_corpus_match[0]
    song_artist_title = str(song_artist).title()
    zlib_file_name = best_corpus_match[1]

    print('Fetching MIDI score...')

    with open(zlib_file_name, 'rb') as f:
        compressed_data = f.read()

    # Decompress the data
    decompressed_data = zlib.decompress(compressed_data)

    # Convert the bytes back to a list using pickle
    # NOTE(security): pickle.loads executes arbitrary code from its input;
    # this is only safe while the score files ship with the app itself.
    scores_data = pickle.loads(decompressed_data)

    fnames = [entry[0] for entry in scores_data]

    fnameidx = fnames.index(song_artist)

    MIDI_score_data = scores_data[fnameidx][1]

    print('Rendering results...')

    print('=' * 70)
    print('MIDi Title:', song_artist_title)
    print('Sample INTs', MIDI_score_data[:12])
    print('=' * 70)

    if len(MIDI_score_data) == 0:
        # Nothing to decode or render; report it to the UI explicitly
        # instead of failing later on undefined results.
        raise gr.Error('No MIDI score data found for the selected match')

    song_f = _decode_score_tokens(MIDI_score_data)

    print('=' * 70)

    #===============================================================================

    output_score, patches, _overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

    # Writes the MIDI file named song_artist_title + '.mid' (the name that is
    # read back below); the returned stats are not needed here.
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature = 'Los Angeles MIDI Dataset Search',
                                             output_file_name = song_artist_title,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches
                                             )

    new_fn = song_artist_title + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(song_artist_title)
    output_midi_summary = str(top_ten_matches)
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi_title, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot

#==========================================================================================================

if __name__ == "__main__":

    # Time zone used for all request/start timestamps in the logs.
    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading MidiCaps dataset...')
    # NOTE(review): mc_dataset is not referenced anywhere else in this file;
    # kept because loading it is existing start-up behavior.
    mc_dataset = load_dataset("amaai-lab/MidiCaps")
    print('=' * 70)

    print('Loading files list...')
    all_MIDI_files_names = TMIDIX.Tegridy_Any_Pickle_File_Reader('LAKH_all_files_names')
    print('=' * 70)

    print('Loading MIDI corpus embeddings...')
    corpus_embeddings = np.load('MIDI_corpus_embeddings_all-MiniLM-L6-v2.npz')['data']
    print('Done!')
    print('=' * 70)

    print('Loading Sentence Transformer model...')
    model = SentenceTransformer('all-MiniLM-L6-v2')
    print('Done!')
    print('=' * 70)

    app = gr.Blocks()

    with app:

        # The two header strings use explicit newline escapes; a plain
        # double-quoted literal cannot span physical lines.
        gr.Markdown("\n\nLAKH MIDI Dataset Search\n\n")
        gr.Markdown("\n\nSearch and explore LAKH MIDI dataset with sentence transformer\n\n")

        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.LAKH-MIDI-Dataset-Search&style=flat)\n\n"
                    "This is a demo for MidiCaps dataset\n\n"
                    "Check out [MidiCaps Dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps) on Hugging Face!\n\n"
                    )

        gr.Markdown("# Enter any desired title, artist or both\n\n")

        title = gr.Textbox(label="Song Title", value="Family Guy")
        artist = gr.Textbox(label="Song Artist", value="TV Themes")
        submit = gr.Button(value='Search')
        gr.ClearButton(components=[title, artist])

        gr.Markdown("# Search results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Top ten MIDI matches")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = submit.click(find_midi, [title, artist],
                                 [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot ])

    app.launch()