Spaces:
Runtime error
Runtime error
File size: 3,982 Bytes
e0c1514 b4432a0 e0c1514 766f359 58fc0c2 766f359 e0c1514 b52703a bce68e6 766f359 e0c1514 6f1d0eb e0c1514 766f359 e0c1514 3b9debd e0c1514 3b9debd 766f359 e0c1514 8f641db f0c1424 e0c1514 766f359 e0c1514 766f359 2cae516 7159bbe 3b9debd 7159bbe e0c1514 7c39bf5 e0c1514 3b9debd e0c1514 766f359 e0c1514 7c39bf5 bce68e6 766f359 0c3e3fb ecc5750 e0c1514 766f359 e0c1514 7c39bf5 a888933 a22f91a a888933 766f359 57cdeb5 766f359 a888933 766f359 a888933 5662bb5 a888933 5f5c021 4363c0a 3b9debd e0c1514 4a79b1a 766f359 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import base64
import html
import os

import numpy as np
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
# Genre names offered in the sidebar dropdown; the selected one is appended
# verbatim to the text prompt, so spelling matters.
genres = [
    "Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
    "Lofi", "Chillpop", "Country", "R&B", "Folk", "Heavy Metal",
    "EDM", "Soul", "Funk", "Reggae", "Disco", "Punk Rock", "House",
    "Techno", "Indie Rock", "Grunge", "Ambient", "Gospel", "Latin Music",
    "Grime", "Trap", "Psychedelic Rock",
]
@st.cache_resource()
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is wrapped in ``st.cache_resource`` so the loaded model
    object can be reused instead of being fetched again on every call.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music_tensors(descriptions, duration: int):
    """Run MusicGen on the given text prompts and return its raw output.

    Args:
        descriptions: Text prompts passed to ``model.generate``.
        duration: Requested clip length, forwarded to
            ``set_generation_params`` as ``duration``.

    Returns:
        Whatever ``model.generate`` returns when called with
        ``return_tokens=True``.
        NOTE(review): presumably an ``(audio, tokens)`` pair; confirm
        against the audiocraft API.
    """
    generator = load_model()
    generator.set_generation_params(use_sampling=True, top_k=250, duration=duration)

    # Show a spinner in the UI for as long as generation is running.
    with st.spinner("Generating Music..."):
        result = generator.generate(
            descriptions=descriptions,
            progress=True,
            return_tokens=True,
        )

    st.success("Music Generation Complete!")
    return result
def save_audio(samples: torch.Tensor, sample_rate: int = 32000,
               save_path: str = "audio_output"):
    """Write each clip in ``samples`` to ``<save_path>/audio_<idx>.wav``.

    Args:
        samples: Audio tensor with 2 dimensions (a single clip) or
            3 dimensions (a batch of clips, iterated along the first axis).
        sample_rate: Sample rate written into the WAV header. Defaults to
            32000 to match the rate MusicGen generates at; writing any other
            value makes the clip play back at the wrong speed and pitch.
        save_path: Directory that receives the files. It is created if it
            does not exist yet.

    Raises:
        ValueError: If ``samples`` has neither 2 nor 3 dimensions.
    """
    # Validate with a real exception: ``assert`` is stripped under ``-O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"samples must have 2 or 3 dimensions, got {samples.dim()}"
        )

    # torchaudio.save does not create missing parent directories.
    os.makedirs(save_path, exist_ok=True)

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        samples = samples[None, ...]

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML ``<a>`` tag that downloads ``bin_file`` as a data URI.

    The file is read in full, base64-encoded and embedded in the link's
    ``href`` so no separate download endpoint is needed.

    Args:
        bin_file: Path of the file to embed.
        file_label: Human-readable name shown after "Download" in the link
            text.

    Returns:
        The anchor tag as a string. The file name and label are HTML-escaped
        so characters such as ``"``, ``&`` or ``<`` cannot break the markup.
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()

    # Escape everything interpolated into the markup; base64 output is
    # already restricted to attribute-safe characters.
    download_name = html.escape(os.path.basename(bin_file), quote=True)
    link_text = html.escape(str(file_label))
    href = (
        f'<a href="data:application/octet-stream;base64,{bin_str}" '
        f'download="{download_name}">Download {link_text}</a>'
    )
    return href
# Browser-tab title and icon. The icon is given as a colon-delimited emoji
# shortcode, which is the documented form for ``page_icon``; the bare name
# "musical_note" is not a documented value.
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen",
)
def main():
    """Render the Song Lab AI page and generate a clip on request.

    The sidebar collects a tempo in BPM, a free-text description, a genre and
    a duration in seconds. When the "Generate !" button is pressed, those
    inputs are combined into one prompt, passed to
    ``generate_music_tensors``, saved through ``save_audio`` and then read
    back from ``audio_output/audio_0.wav`` to feed an audio player and a
    download link in the left column.
    """
    with st.sidebar:
        st.header("""⚙️Generate Music ⚙️""", divider="rainbow")
        st.text("")
        st.subheader("1. Enter your music description.......")
        bpm = st.number_input("Enter Speed in BPM", min_value=60)
        text_area = st.text_area('Ex : 80s rock song with guitar and drums')
        st.text('')
        selected_genre = st.selectbox("Select Genre", genres)
        st.subheader("2. Select time duration (In Seconds)")
        # The minimum is 1 rather than 0: a zero-second clip is not a
        # meaningful request for the generator.
        time_slider = st.slider("Select time duration (In Seconds)", 1, 60, 10)

    st.title("""🎵 Song Lab AI 🎵""")
    st.text('')
    left_co, _right_co = st.columns(2)
    left_co.write("""Music Generation through a prompt""")
    left_co.write("""PS : First generation may take some time .......""")

    # Nothing more to draw until the user asks for a clip.
    if not st.sidebar.button('Generate !'):
        return

    with left_co:
        # Vertical spacing above the results.
        for _ in range(6):
            st.text('')
        st.text('\n\n')
        st.subheader("Generated Music")

        # A single prompt, i.e. a batch of one clip.
        descriptions = [f"{text_area} {selected_genre} {bpm} BPM"]
        music_tensors = generate_music_tensors(descriptions, time_slider)

        # Only the first element of the generator output is saved and played.
        # NOTE(review): generate_music_tensors requests return_tokens=True, so
        # element 0 is presumably the audio batch rather than one clip;
        # save_audio accepts either shape. Confirm against audiocraft.
        idx = 0
        music_tensor = music_tensors[idx]
        save_audio(music_tensor)

        audio_filepath = f'audio_output/audio_{idx}.wav'
        # Read inside a context manager so the handle is always closed.
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()

        st.audio(audio_bytes, format='audio/wav')
        st.markdown(
            get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'),
            unsafe_allow_html=True,
        )
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|