# YourMT3-cpu / app.py
# Author: mimbres. Last commit: bd69fba ("Update app.py", verified).
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
import subprocess
from typing import Tuple, Dict, Literal
from ctypes import ArgumentError
from html_helper import *
from model_helper import *
import torch
import torchaudio
import glob
import gradio as gr
from gradio_log import Log
from pathlib import Path
# gradio_log: prepare_media() writes the yt-dlp OAuth device-login prompt to
# this file, and the Log widget in the UI displays it.
log_file = 'amt/log.txt'
# The path is relative to the current working directory (not to __file__).
# Create the parent directory first so touch() cannot fail with
# FileNotFoundError when the app is started from a directory without amt/.
Path(log_file).parent.mkdir(parents=True, exist_ok=True)
Path(log_file).touch()
# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = '16' if torch.cuda.is_available() else '32'  # @param ["32", "bf16-mixed", "16"]
project = '2024'

# CLI option groups shared by several checkpoints.
_SPEC_FRONTEND_ARGS = ['-ac', 'spec', '-hop', '300', '-atc', '1']
_MULTI_T5_ARGS = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
                  '-nl', '26', '-enc', 'perceiver-tf']
_MOE_ARGS = (_MULTI_T5_ARGS
             + ['-sqr', '1', '-ff', 'moe', '-wf', '4', '-nmoe', '8',
                '-kmoe', '2', '-act', 'silu', '-epe', 'rope', '-rp', '1']
             + _SPEC_FRONTEND_ARGS)

# model_name -> (checkpoint file, model-specific CLI options).
# Each entry is expanded below into
# [checkpoint, '-p', project, *options, '-pr', precision].
_MODEL_ZOO = {
    "YMT3+": (
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
        []),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        ['-enc', 'perceiver-tf'] + _SPEC_FRONTEND_ARGS),
    "YPTF+Multi (PS)": (
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        _MULTI_T5_ARGS + _SPEC_FRONTEND_ARGS),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        _MOE_ARGS),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        _MOE_ARGS),
}

if model_name not in _MODEL_ZOO:
    raise ValueError(model_name)
checkpoint, _model_options = _MODEL_ZOO[model_name]
args = [checkpoint, '-p', project, *_model_options, '-pr', precision]
model = load_model_checkpoint(args=args)
# @title GradIO helper
def _download_youtube_audio(url, simulate=False):
    """Download the audio of `url` with yt-dlp and return the local mp3 path.

    All yt-dlp output is echoed to stdout. Only lines that mention the Google
    device-login URL are written, colour-highlighted, to `log_file`, so the UI
    log widget can show that prompt to the user.

    Raises:
        RuntimeError: yt-dlp exited with a non-zero status.
    """
    base_path = './downloaded/yt_audio'  # yt-dlp output template (no extension)
    mp3_path = base_path + '.mp3'        # file produced by --audio-format mp3
    # Remove the previous download first. Otherwise a failed download would
    # silently fall back to transcribing the previous request's audio.
    # NOTE(review): this fixed path is shared by all requests; concurrent
    # YouTube requests would race on it.
    if os.path.exists(mp3_path):
        os.remove(mp3_path)

    command = ['yt-dlp', '-x', url, '-f', 'bestaudio',
               '-o', base_path, '--audio-format', 'mp3', '--restrict-filenames',
               '--extractor-retries', '10',
               '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
    if simulate:
        command.append('-s')  # yt-dlp simulate flag: nothing is downloaded

    device_url = "https://www.google.com/device"
    # The log is opened with 'w', so each download starts from an empty log.
    # The Popen context manager closes stdout and waits for the process on exit.
    with open(log_file, 'w') as lf, subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True) as process:
        # readline-based iteration streams output line by line while yt-dlp runs.
        for line in iter(process.stdout.readline, ''):
            print(line, end='')  # `line` already ends with a newline
            # Forward only the device-login prompt to the UI log.
            if "www.google.com/device" in line:
                tokens = line.replace(
                    device_url, "\033[93m" + device_url + "\x1b[0m").split()
                # The last token is rendered in bold red (presumably the login code).
                tokens[-1] = "\x1b[31;1m" + tokens[-1] + "\x1b[0m"
                lf.write(' '.join(tokens) + '\n')
                lf.flush()

    if process.returncode != 0:
        raise RuntimeError(
            f"yt-dlp failed with exit code {process.returncode} for {url!r}")
    return mp3_path


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate: bool = False) -> Dict:
    """Resolve the input to a local audio file and return its audio info.

    Args:
        source_path_or_url: a local audio path when `source_type` is
            'audio_filepath', or a URL string when it is 'youtube_url'.
        source_type: 'audio_filepath' or 'youtube_url'.
        delete_video: currently unused; kept for API compatibility.
        simulate: for 'youtube_url', pass yt-dlp's '-s' (simulate) flag.
            No file is then produced, so the audio probe below is expected
            to fail.

    Returns:
        A dict with the keys filepath, track_name (base name without its final
        extension), sample_rate, bits_per_sample, num_channels, num_frames,
        duration (whole seconds, truncated) and encoding (lower-cased).
        The values are read with torchaudio.info.

    Raises:
        ValueError: unknown `source_type`.
        RuntimeError: the yt-dlp download failed.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url, simulate=simulate)
    else:
        raise ValueError(source_type)

    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # splitext strips only the final extension ("my.song.wav" -> "my.song").
        "track_name": os.path.splitext(os.path.basename(audio_file))[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": info.encoding.lower(),
    }
def process_audio(audio_filepath):
    """Transcribe an uploaded or recorded audio file and return an HTML MIDI player.

    Returns None when no file was provided.
    """
    if audio_filepath is not None:
        info = prepare_media(audio_filepath, source_type='audio_filepath')
        midi_url = to_data_url(transcribe(model, info))
        return create_html_from_midi(midi_url)  # html midiplayer
    return None
def process_video(youtube_url):
    """Download audio from `youtube_url`, transcribe it, and return an HTML MIDI player.

    Returns None for an empty or missing URL, mirroring how process_audio
    handles a missing file, instead of launching yt-dlp with no target.
    Unlike play_video, no 'youtu' substring check is applied; the stripped
    URL is passed to yt-dlp as-is.
    """
    youtube_url = (youtube_url or '').strip()
    if not youtube_url:
        return None
    audio_info = prepare_media(youtube_url, source_type='youtube_url')
    midifile = to_data_url(transcribe(model, audio_info))
    return create_html_from_midi(midifile)  # html midiplayer
def play_video(youtube_url):
    """Return an embedded YouTube player for `youtube_url`.

    Returns None when the URL is empty/None or does not contain 'youtu'
    (i.e. does not look like a youtube.com / youtu.be link).
    """
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    return create_html_youtube_player(youtube_url)
# def oauth_google():
# return create_html_oauth()
# Example inputs offered under each tab.
# Sorted so the example order does not depend on the filesystem's listing
# order. The pattern has no '**', so only files directly inside examples/
# match; recursive=True had no effect and is dropped.
AUDIO_EXAMPLES = sorted(glob.glob('examples/*.*'))
YOUTUBE_EXAMPLES = ["https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
                    "https://www.youtube.com/watch?v=vMboypSkj3c",
                    "https://youtu.be/cQRtUeqmO58?si=DZKZ0t-ISKAaoHQ8",
                    "https://youtu.be/EOJ0wH6h3rE?si=a99k6BnSajvNmXcn",
                    "https://youtu.be/7mjQooXt28o?si=qqmMxCxwqBlLPDI2",
                    # These two were previously fused into one string by a
                    # missing comma.
                    "https://youtu.be/bnS-HK_lTHA?si=PQLVAab3QHMbv0S3",
                    "https://youtu.be/zJB0nnOc7bM?si=EA1DN8nHWJcpQWp_",
                    "https://youtu.be/mIWYTg55h10?si=WkbtKfL6NlNquvT8"]
theme = gr.Theme.from_hub("gradio/dracula_revamped")

# Overrides applied on top of the hub theme: font sizes and dark-mode colours.
_THEME_OVERRIDES = {
    'text_md': '10px',
    'text_lg': '12px',
    'body_background_fill_dark': '#060a1c',  # other candidates: '#372037', '#a17ba5', '#73d3ac'
    'border_color_primary_dark': '#45507328',
    'block_background_fill_dark': '#3845685c',
    'body_text_color_dark': 'white',
    'block_title_text_color_dark': 'black',
    'body_text_color_subdued_dark': '#e4e9e9',
}
for _attr_name, _attr_value in _THEME_OVERRIDES.items():
    setattr(theme, _attr_name, _attr_value)

# Page CSS: an animated gradient background for the whole container, plus a
# size limit for the element with id "mylog" (the Log widget in the layout).
css = """
.gradio-container {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
height: 100vh;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
#mylog {font-size: 12pt; line-height: 1.2; min-height: 2em; max-height: 4em;}
"""
with gr.Blocks(theme=theme, css=css) as demo:
    # Header: title, model card, usage notes and badges.
    # NOTE(review): the "model details" table is hard-coded for the
    # YPTF.MoE+Multi (noPS) configuration. Only the model name is filled in
    # from `model_name`, and the "FP16 for inference" row does not reflect
    # `precision` ('32' when CUDA is unavailable). Update it if the checkpoint
    # changes.
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
                f"""
## 🎶YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation
## Model card:
- Model name: `{model_name}`
<details>
<summary>▶model details◀</summary>
| **Component** | **Details** |
|--------------------------|--------------------------------------------------|
| Encoder backbone | Perceiver-TF + Mixture of Experts (2/8) |
| Decoder backbone | Multi-channel T5-small |
| Tokenizer | MT3 tokens with Singing extension |
| Dataset | YourMT3 dataset |
| Augmentation strategy | Intra-/Cross dataset stem augment, No Pitch-shifting |
| FP Precision | BF16-mixed for training, FP16 for inference |
</details>
## Caution:
- Currently running on CPU, and it takes longer than 3 minutes for a 30-second input. Please try [GPU-HuggingFace-demo](https://huggingface.co/spaces/mimbres/YourMT3) for fast inference.
- For academic reproduction purposes, we strongly recommend using the [Colab Demo](https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing) with multiple checkpoints.
## YouTube transcription (working🚀):
- Press the `Transcribe` button, copy the 12-digit code below, and paste it into `google.com/device`. (Only needed once.)
<div style="display: inline-block;">
<a href="https://arxiv.org/abs/2407.04822">
<img src="https://img.shields.io/badge/arXiv:2407.04822-B31B1B?logo=arxiv&logoColor=fff&style=plastic" alt="arXiv Badge"/>
</a>
</div>
<div style="display: inline-block;">
<a href="https://github.com/mimbres/YourMT3">
<img src="https://img.shields.io/badge/GitHub-181717?logo=github&logoColor=fff&style=plastic" alt="GitHub Badge"/>
</a>
</div>
<div style="display: inline-block;">
<a href="https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing">
<img src="https://img.shields.io/badge/Google%20Colab-F9AB00?logo=googlecolab&logoColor=fff&style=plastic"/>
</a>
</div>
""")

    with gr.Group():
        # Tab 1: transcribe an uploaded or recorded audio file.
        with gr.Tab("Upload audio"):
            # Input
            audio_input = gr.Audio(label="Record Audio", type="filepath",
                                   show_share_button=True, show_download_button=True)
            # Display examples
            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
            # Submit button
            transcribe_audio_button = gr.Button("Transcribe", variant="primary")
            # Transcribe
            output_tab1 = gr.HTML()
            transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)

        # Tab 2: fetch audio from a YouTube URL, preview it, and transcribe it.
        with gr.Tab("From YouTube"):
            with gr.Column(scale=4):
                # Input URL
                youtube_url = gr.Textbox(label="YouTube Link URL",
                                         placeholder="https://youtu.be/...")
                # Display examples
                gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
                # Play button
                play_video_button = gr.Button("Get Audio from YouTube", variant="primary")
                # Play youtube
                youtube_player = gr.HTML(render=True)

            with gr.Column(scale=4):
                with gr.Row():
                    # Submit button
                    transcribe_video_button = gr.Button("Transcribe", variant="primary")
                    # Links to Google's device-login page, where the user enters
                    # the code that prepare_media() writes to log_file.
                    oauth_button = gr.Button("google.com/device", variant="primary", link="https://www.google.com/device")

                with gr.Column(scale=1):
                    # Transcribe
                    output_tab2 = gr.HTML(render=True)
                    transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
                    # Play
                    play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)

                with gr.Column(scale=1):
                    # Shows the contents of log_file; styled by the #mylog CSS rule.
                    logger = Log(log_file, dark=True, xterm_font_size=12, every=None, elem_id='mylog')

demo.launch(debug=True)