Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
from transformer_wrapper import TransformerWrapper | |
from omegaconf import OmegaConf | |
def get_file_content_as_string(path):
    """Return the entire contents of the text file at *path*, decoded as UTF-8.

    The file is opened in a ``with`` block so the handle is closed
    deterministically, instead of being left for the garbage collector.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
def model_load():
    """Load the checkpoint and config the app needs to run generation.

    Returns:
        A ``(wrapper, model_id, config)`` tuple:
        - ``wrapper``: the model restored from the remote checkpoint on CPU,
          switched to eval mode;
        - ``model_id``: the string forwarded as ``model=`` to
          ``wrapper.generate``;
        - ``config``: the OmegaConf config loaded from ``config.yaml``.
    """
    config = OmegaConf.load("config.yaml")
    # Call load_from_checkpoint on the class itself. The previous code built a
    # TransformerWrapper(config) instance only to discard it. Assuming
    # TransformerWrapper is a PyTorch Lightning module (TODO confirm),
    # Lightning 2.x also raises TypeError when this classmethod is invoked on
    # an instance.
    wrapper = TransformerWrapper.load_from_checkpoint(
        "https://huggingface.co/sweetcocoa/pop2piano/resolve/main/model-1999-val_0.67311615.ckpt",
        config=config,
        map_location="cpu",
    )
    model_id = "dpipqxiy"
    wrapper.eval()
    return wrapper, model_id, config
def _save_upload(uploaded_file, dest_dir):
    """Write a Streamlit upload into *dest_dir* and return the saved file's path."""
    # The file name is supplied by the client, so keep only its final path
    # component; otherwise a name such as "../../app.py" could write outside
    # dest_dir.
    target_file = os.path.join(dest_dir, os.path.basename(uploaded_file.name))
    with open(target_file, "wb") as f:
        f.write(uploaded_file.getvalue())
    return target_file


def main():
    """Render the Streamlit UI and convert an uploaded song.

    The user picks an arranger (a key of ``config.composer_to_feature_token``)
    and uploads an mp3/wav file. When "convert" is pressed:
    - the upload is saved under ``ytsamples/`` and passed to
      ``wrapper.generate``;
    - the MIDI file it writes is offered for download;
    - the mix file it writes is played back.
    """
    wrapper, model_id, config = model_load()
    composers = list(config.composer_to_feature_token.keys())
    dest_dir = "ytsamples"
    os.makedirs(dest_dir, exist_ok=True)

    composer = st.selectbox(label="Arranger", options=composers)
    file_up = st.file_uploader("Upload an audio", type=["mp3", "wav"])

    if not st.button("convert"):
        return
    if file_up is None:
        # Tell the user why nothing happens instead of silently ignoring the click.
        st.warning("Please upload an audio file first.")
        return

    target_file = _save_upload(file_up, dest_dir)

    with st.spinner("Wait for it..."):
        # Only the two output paths are needed here; the first two return
        # values of generate() are ignored.
        _, _, mix_path, midi_path = wrapper.generate(
            audio_path=target_file,
            composer=composer,
            model=model_id,
            ignore_duplicate=True,
            show_plot=False,
            save_midi=True,
            save_mix=True,
        )

    with open(midi_path, "rb") as midi_f:
        st.download_button(
            "Download midi",
            data=midi_f,
            file_name=os.path.basename(midi_path),
        )
    with open(mix_path, "rb") as audio_f:
        st.audio(audio_f.read(), format="audio/wav")
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()