unpairedelectron07 committed on
Commit
288248e
1 Parent(s): 68d5a3a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from audiocraft.models import MusicGen
2
+ import streamlit as st
3
+ import os
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+ import base64
8
+ from dotenv import load_dotenv
9
+ import google.generativeai as genai
10
# Load settings from a local .env file (python-dotenv) so that API_KEY can be
# read from the process environment below.
load_dotenv()

# Configure the Gemini client with the key stored under API_KEY (None when the
# variable is unset) and create the shared "gemini-pro" model handle.
genai.configure(api_key=os.environ.get("API_KEY"))
llm = genai.GenerativeModel("gemini-pro")
14
+
15
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is decorated with ``st.cache_resource``, so Streamlit is
    expected to reuse the returned object instead of reloading the checkpoint
    on every script rerun.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
19
+
20
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text description with the cached model.

    Args:
        description: Text prompt; passed to the model as a one-item batch.
        duration: Value forwarded as the ``duration`` generation parameter.

    Returns:
        The first element of whatever ``model.generate`` returns when called
        with ``return_tokens=True`` (presumably the generated audio tensor --
        confirm against the audiocraft documentation).
    """
    # Echo the request to the server console for debugging.
    print(f"Description: {description}")
    print(f"Duration: {duration}")

    musicgen = load_model()

    # Sampling settings are applied to the model before generating.
    generation_params = {
        "use_sampling": True,
        "top_k": 250,
        "duration": duration,
    }
    musicgen.set_generation_params(**generation_params)

    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    return generated[0]
38
+
39
def save_audio(samples: torch.Tensor, sample_rate: int = 32000, save_path: str = "saved_audio/"):
    """Write each clip in ``samples`` to ``<save_path>/audio_<idx>.wav``.

    Args:
        samples: A 2-D tensor (treated as one clip) or a 3-D tensor (iterated
            along its first dimension, one file per entry).
        sample_rate: Sample rate handed to ``torchaudio.save``. Defaults to
            32000, the value that was previously hard-coded.
        save_path: Directory that receives the files. It is created when it
            does not exist yet. Defaults to ``"saved_audio/"``.

    Raises:
        ValueError: If ``samples`` is neither 2-D nor 3-D.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {samples.dim()} dimension(s)"
        )

    # Make sure the target directory exists; without this the very first save
    # on a fresh checkout fails because nothing else creates it.
    os.makedirs(save_path, exist_ok=True)

    batch = samples.detach().cpu()
    if batch.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        batch = batch[None, ...]

    for idx, audio in enumerate(batch):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
52
+
53
def download_music(bin_file, file_label="File"):
    """Build an HTML download link that embeds ``bin_file`` as a data URI.

    Args:
        bin_file: Path of the file to embed; it is read in binary mode.
        file_label: Text shown after "Download " in the link.

    Returns:
        An ``<a>`` element whose ``href`` is a base64 ``data:`` URI of the file
        contents and whose ``download`` attribute is the file's base name.
    """
    import html  # local import: only this function needs it

    with open(bin_file, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")

    # Escape everything interpolated into markup: the caller renders this
    # string as raw HTML, so a quote or angle bracket in the file name or
    # label would otherwise break (or inject into) the tag.
    filename = html.escape(os.path.basename(bin_file), quote=True)
    label = html.escape(str(file_label))

    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{filename}">Download {label}</a>'
    )
60
+
61
# Page-level settings for the app: the page title and a musical-note icon.
st.set_page_config(page_title="MusicGen", page_icon=":musical_note:")
65
+
66
def main():
    """Render the page: collect a description, enhance it, generate and play audio.

    Flow:
        1. Show the title, an info expander, a description box and a duration
           slider.
        2. Once a non-blank description is entered, ask the Gemini model for an
           enriched prompt.
        3. Generate audio from that prompt, save it, then offer playback and a
           download link.
    """
    st.title("Text to Music Generation")

    with st.expander("View Details..."):
        st.write("This was built by https://github.com/ishan-kshirsagar0-7 using Meta's Audiocraft library. Enter the description of the music you want to generate, and set the duration with the slider given below. The longer the duration slider, the longer it will take to generate the music.")

    text_area = st.text_area("Enter your description...")
    time_slider = st.slider("Select time duration (in seconds)", 2, 20, 5)

    description = text_area.strip()
    if not description or not time_slider:
        # Nothing to generate yet. Returning here also avoids sending a Gemini
        # request with an empty prompt on every script rerun.
        return

    llm_result = llm.generate_content(_build_enhancement_prompt(description))
    try:
        prompt = llm_result.text
    except ValueError:
        # The `.text` accessor raises ValueError when the response carries no
        # text part (for example a blocked reply); fall back to the user's own
        # wording instead of crashing the page.
        st.warning("Could not enhance the description; using it as entered.")
        prompt = description

    st.json({"Description": prompt, "Duration": time_slider})

    st.subheader("Generated Music")

    music_tensors = generate_music_tensors(prompt, time_slider)
    print(f"Music Tensors: {music_tensors}")

    # The clip is written to and read back from this directory; make sure it
    # exists before saving.
    os.makedirs("saved_audio", exist_ok=True)
    save_audio(music_tensors)

    audio_filepath = "saved_audio/audio_0.wav"
    # Context manager so the file handle is closed after reading.
    with open(audio_filepath, "rb") as audio_file:
        audio_bytes = audio_file.read()

    st.audio(audio_bytes)
    st.markdown(download_music(audio_filepath, "Audio"), unsafe_allow_html=True)


def _build_enhancement_prompt(description):
    """Return the instruction text asking the LLM to enrich ``description``."""
    instructions = "Given the basic description of a prompt for a text-to-music generator below, enhance that prompt by using specific, direct, accurate and relevant vocabulary. This enhanced prompt must clearly assert and describe the kind of music user wants to generate, with the help of appropriate musical terminology or taxonomy. Craft a creative prompt that clearly explains the text-to-music model what music the user desires. DO NOT respond with anything other than the output prompt. You can be as creative as you like with the descriptions, but DO NOT make up details that the original prompt did not ask for. Also, make sure the description is not too lengthy, keep it concise. Your prompt must explain the flow of the music from start through the middle towards the finish, explicitly mentioning the way instruments are played and what they should sound like."
    return f"{instructions}\n\nORIGINAL PROMPT : {description}\nYOUR OUTPUT PROMPT :\n"
105
+
106
# Build the page only when this file is run as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()