NotaGen / app.py
ElectricAlexis's picture
Update app.py
272ffb6 verified
# import spaces
# import zero
import gradio as gr
import sys
import threading
import queue
import time
import random
from io import TextIOBase
import datetime
import subprocess
import os
from inference import postprocess_inst_names
from inference import inference_patch
from convert import abc2xml, xml2, pdf2img
title_html = """
<div class="title-container">
<h1 class="title-text">NotaGen</h1> &nbsp;
<!-- ArXiv -->
<a href="https://arxiv.org/abs/2502.18008">
<img src="https://img.shields.io/badge/NotaGen_Paper-ArXiv-%23B31B1B?logo=arxiv&logoColor=white" alt="Paper">
</a>
&nbsp;
<!-- GitHub -->
<a href="https://github.com/ElectricAlexis/NotaGen">
<img src="https://img.shields.io/badge/NotaGen_Code-GitHub-%23181717?logo=github&logoColor=white" alt="GitHub">
</a>
&nbsp;
<!-- HuggingFace -->
<a href="https://huggingface.co/ElectricAlexis/NotaGen">
<img src="https://img.shields.io/badge/NotaGen_Weights-HuggingFace-%23FFD21F?logo=huggingface&logoColor=white" alt="Weights">
</a>
&nbsp;
<!-- Web Demo -->
<a href="https://electricalexis.github.io/notagen-demo/">
<img src="https://img.shields.io/badge/NotaGen_Demo-Web-%23007ACC?logo=google-chrome&logoColor=white" alt="Demo">
</a>
</div>
<bp>
<p style="font-size: 1.2em;">NotaGen is a model for generating sheet music in ABC notation format. Select a period, composer, and instrumentation to generate classical-style music!</p>
"""
# Read prompt combinations
with open('prompts.txt', 'r') as f:
prompts = f.readlines()
valid_combinations = set()
for prompt in prompts:
prompt = prompt.strip()
parts = prompt.split('_')
valid_combinations.add((parts[0], parts[1], parts[2]))
# Prepare dropdown options
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
# Dynamically update composer and instrument dropdown options
def update_components(period, composer):
if not period:
return [
gr.update(choices=[], value=None, interactive=False),
gr.update(choices=[], value=None, interactive=False)
]
valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
valid_instruments = sorted({i for p, c, i in valid_combinations if p == period and c == composer}) if composer else []
return [
gr.update(
choices=valid_composers,
value=composer if composer in valid_composers else None,
interactive=True
),
gr.update(
choices=valid_instruments,
value=None,
interactive=bool(valid_instruments)
)
]
# Custom realtime stream for outputting model inference process to frontend
class RealtimeStream(TextIOBase):
def __init__(self, queue):
self.queue = queue
def write(self, text):
self.queue.put(text)
return len(text)
def convert_files(abc_content, period, composer, instrumentation):
if not all([period, composer, instrumentation]):
raise gr.Error("Please complete a valid generation first before saving")
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
prompt_str = f"{period}_{composer}_{instrumentation}"
filename_base = f"{timestamp}_{prompt_str}"
abc_filename = f"{filename_base}.abc"
with open(abc_filename, "w", encoding="utf-8") as f:
f.write(abc_content)
# instrumentation replacement
postprocessed_inst_abc = postprocess_inst_names(abc_content)
filename_base_postinst = f"{filename_base}_postinst"
with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
f.write(postprocessed_inst_abc)
# Convert files
file_paths = {'abc': abc_filename}
try:
# abc2xml
abc2xml(filename_base)
abc2xml(filename_base_postinst)
# xml2pdf
xml2(filename_base, 'pdf')
# xml2mid
xml2(filename_base, 'mid')
xml2(filename_base_postinst, 'mid')
# xml2mp3
xml2(filename_base, 'mp3')
xml2(filename_base_postinst, 'mp3')
# 将PDF转为图片
images = pdf2img(filename_base)
for i, image in enumerate(images):
image.save(f"{filename_base}_page_{i+1}.png", "PNG")
file_paths.update({
'xml': f"{filename_base_postinst}.xml",
'pdf': f"{filename_base}.pdf",
'mid': f"{filename_base_postinst}.mid",
'mp3': f"{filename_base_postinst}.mp3",
'pages': len(images),
'current_page': 0,
'base': filename_base
})
except Exception as e:
raise gr.Error(f"File processing failed: {str(e)}")
return file_paths
# Page navigation control function
def update_page(direction, data):
"""
data contains three key pieces of information: 'pages', 'current_page', and 'base'
"""
if not data:
return None, gr.update(interactive=False), gr.update(interactive=False), data
if direction == "prev" and data['current_page'] > 0:
data['current_page'] -= 1
elif direction == "next" and data['current_page'] < data['pages'] - 1:
data['current_page'] += 1
current_page_index = data['current_page']
# Update image path
new_image = f"{data['base']}_page_{current_page_index+1}.png"
# When current_page==0, prev_btn is disabled; when current_page==pages-1, next_btn is disabled
prev_btn_state = gr.update(interactive=(current_page_index > 0))
next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))
return new_image, prev_btn_state, next_btn_state, data
# @spaces.GPU(duration=600)
def generate_music(period, composer, instrumentation):
"""
Must ensure each yield returns the same number of values.
We're preparing to return 5 values, corresponding to:
1) process_output (intermediate inference information)
2) final_output (final ABC)
3) pdf_image (path to the PNG of the first page of the PDF)
4) audio_player (mp3 path)
5) pdf_state (state for page navigation)
"""
# Set a different random seed each time based on current timestamp
random_seed = int(time.time()) % 10000
random.seed(random_seed)
# For numpy if you're using it
try:
import numpy as np
np.random.seed(random_seed)
except ImportError:
pass
# For torch if you're using it
try:
import torch
torch.manual_seed(random_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(random_seed)
except ImportError:
pass
if (period, composer, instrumentation) not in valid_combinations:
# If the combination is invalid, raise an error
raise gr.Error("Invalid prompt combination! Please re-select from the period options")
output_queue = queue.Queue()
original_stdout = sys.stdout
sys.stdout = RealtimeStream(output_queue)
result_container = []
def run_inference():
try:
# Use downloaded model weights path for inference
result = inference_patch(period, composer, instrumentation)
result_container.append(result)
finally:
sys.stdout = original_stdout
thread = threading.Thread(target=run_inference)
thread.start()
process_output = ""
final_output_abc = ""
pdf_image = None
audio_file = None
pdf_state = None
# First continuously read intermediate output
while thread.is_alive():
try:
text = output_queue.get(timeout=0.1)
process_output += text
# No final ABC yet, files not yet converted
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
except queue.Empty:
continue
# After thread ends, get all remaining items from the queue
while not output_queue.empty():
text = output_queue.get()
process_output += text
# Final inference result
final_result = result_container[0] if result_container else ""
# Display file conversion prompt
final_output_abc = "Converting files..."
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
# Convert files
try:
file_paths = convert_files(final_result, period, composer, instrumentation)
final_output_abc = final_result
# Get the first image and mp3 file
if file_paths['pages'] > 0:
pdf_image = f"{file_paths['base']}_page_1.png"
audio_file = file_paths['mp3']
pdf_state = file_paths # Directly use the converted information dictionary as state
# Prepare download file list
download_list = []
if 'abc' in file_paths and os.path.exists(file_paths['abc']):
download_list.append(file_paths['abc'])
if 'xml' in file_paths and os.path.exists(file_paths['xml']):
download_list.append(file_paths['xml'])
if 'pdf' in file_paths and os.path.exists(file_paths['pdf']):
download_list.append(file_paths['pdf'])
if 'mid' in file_paths and os.path.exists(file_paths['mid']):
download_list.append(file_paths['mid'])
if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
download_list.append(file_paths['mp3'])
except Exception as e:
# If conversion fails, return error message to output box
yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
return
# Final yield with all information - modify here to make component visible
yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
def get_file(file_type, period, composer, instrumentation):
"""
Returns the local file of specified type for Gradio download
"""
# Here you actually need to return based on specific file paths saved earlier, simplified for demo
# If matching by timestamp, you can store all converted files in a directory and get the latest
# This is just an example:
possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
if not possible_files:
return None
# Simply return the latest
possible_files.sort(key=os.path.getmtime)
return possible_files[-1]
css = """
/* Compact button style */
button[size="sm"] {
padding: 4px 8px !important;
margin: 2px !important;
min-width: 60px;
}
/* PDF preview area */
#pdf-preview {
border-radius: 8px; /* Rounded corners */
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* Shadow */
}
.page-btn {
padding: 12px !important; /* Increase clickable area */
margin: auto !important; /* Vertical center */
}
/* Button hover effect */
.page-btn:hover {
background: #f0f0f0 !important;
transform: scale(1.05);
}
/* Layout adjustment */
.gr-row {
gap: 10px !important; /* Element spacing */
}
/* Audio player */
.audio-panel {
margin-top: 15px !important;
max-width: 400px;
}
#audio-preview audio {
height: 200px !important;
}
/* Save functionality area */
.save-as-row {
margin-top: 15px;
padding: 10px;
border-top: 1px solid #eee;
}
/* Download files styling */
.download-files {
margin-top: 15px;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
/* Social icons styling */
.title-container {
display: flex;
align-items: center;
gap: 15px;
margin-bottom: 10px;
}
.title-text {
margin: 0;
font-size: 1.8em;
}
.social-icons {
display: flex;
gap: 10px;
}
.social-icon {
display: inline-flex;
align-items: center;
justify-content: center;
width: 32px;
height: 32px;
border-radius: 50%;
background-color: #f5f5f5;
text-decoration: none;
transition: transform 0.2s, background-color 0.2s;
}
.social-icon:hover {
transform: scale(1.1);
background-color: #e0e0e0;
}
.social-icon img {
width: 20px;
height: 20px;
}
"""
with gr.Blocks(css=css) as demo:
gr.HTML(title_html)
# For storing PDF page count, current page and other information
pdf_state = gr.State()
with gr.Column():
with gr.Row():
# Left sidebar
with gr.Column():
with gr.Row():
period_dd = gr.Dropdown(
choices=periods,
value=None,
label="Period",
interactive=True
)
composer_dd = gr.Dropdown(
choices=[],
value=None,
label="Composer",
interactive=False
)
instrument_dd = gr.Dropdown(
choices=[],
value=None,
label="Instrumentation",
interactive=False
)
generate_btn = gr.Button("Generate!", variant="primary")
process_output = gr.Textbox(
label="Generation process",
interactive=False,
lines=2,
max_lines=2,
placeholder="Generation progress will be shown here..."
)
final_output = gr.Textbox(
label="Post-processed ABC notation scores",
interactive=True,
lines=8,
max_lines=8,
placeholder="Post-processed ABC scores will be shown here..."
)
# Audio playback
audio_player = gr.Audio(
label="Audio Preview",
format="mp3",
interactive=False,
)
# Right sidebar
with gr.Column():
# Image container
pdf_image = gr.Image(
label="Sheet Music Preview",
show_label=False,
height=650,
type="filepath",
elem_id="pdf-preview",
interactive=False,
show_download_button=False
)
# Page navigation buttons
with gr.Row():
prev_btn = gr.Button(
"⬅️ Last Page",
variant="secondary",
size="sm",
elem_classes="page-btn"
)
next_btn = gr.Button(
"Next Page ➡️",
variant="secondary",
size="sm",
elem_classes="page-btn"
)
with gr.Column():
gr.Markdown("**Download Files:**")
download_files = gr.Files(
label="Generated Files",
visible=False,
elem_classes="download-files",
type="filepath" # Make sure this is set to filepath
)
# Dropdown linking
period_dd.change(
update_components,
inputs=[period_dd, composer_dd],
outputs=[composer_dd, instrument_dd]
)
composer_dd.change(
update_components,
inputs=[period_dd, composer_dd],
outputs=[composer_dd, instrument_dd]
)
# Click generate button, note outputs must match each yield in generate_music
generate_btn.click(
generate_music,
inputs=[period_dd, composer_dd, instrument_dd],
outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
)
# Page navigation
prev_signal = gr.Textbox(value="prev", visible=False)
next_signal = gr.Textbox(value="next", visible=False)
prev_btn.click(
update_page,
inputs=[prev_signal, pdf_state], # ✅ Use component
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
)
next_btn.click(
update_page,
inputs=[next_signal, pdf_state], # ✅ Use component
outputs=[pdf_image, prev_btn, next_btn, pdf_state]
)
if __name__ == "__main__":
# Configure GPU/CPU handling
demo.launch(
server_name="0.0.0.0",
server_port=7860
)