tspsram committed on
Commit
49ad2ae
·
verified ·
1 Parent(s): c90f843

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +129 -0
  2. requirements.txt +22 -0
main.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchaudio
4
+ import os
5
+ import numpy as np
6
+ import base64
7
+ from audiocraft.models import MusicGen
8
+
9
# Batch size setting. It was previously 64 and was lowered to 32 (presumably
# to ease GPU memory pressure -- TODO confirm). Nothing else in this file
# reads it.
batch_size = 32

# Release any cached CUDA allocator blocks once at start-up.
torch.cuda.empty_cache()

# Genre choices offered in the "Select Genre" drop-down.
genres = [
    "Pop",
    "Rock",
    "Jazz",
    "Electronic",
    "Hip-Hop",
    "Classical",
    "Lofi",
    "Chillpop",
]
17
+
18
@st.cache_resource()
def load_model():
    """Return the pretrained 'facebook/musicgen-small' MusicGen model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded model instead of loading it again on every call.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')
22
+
23
def generate_music_tensors(description, duration: int):
    """Generate music with the cached MusicGen model for a text prompt.

    Args:
        description: The text prompt. Either a single string or a sequence
            of prompt strings (one generated sample per entry).
        duration: Requested length of the generated audio, in seconds.

    Returns:
        The value returned by ``model.generate`` when called with
        ``return_tokens=True`` (per the audiocraft API this is a pair of
        audio tensor and generated tokens -- verify against the installed
        audiocraft version).
    """
    # ``MusicGen.generate`` takes a list of prompts. A bare string would be
    # iterated character by character, yielding one (meaningless) sample per
    # character and a needlessly huge batch, so wrap it in a list.
    if isinstance(description, str):
        descriptions = [description]
    else:
        descriptions = list(description)

    model = load_model()
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration,
    )

    with st.spinner("Generating Music..."):
        output = model.generate(
            descriptions=descriptions,
            progress=True,
            return_tokens=True,
        )

    st.success("Music Generation Complete!")
    return output
41
+
42
def save_audio(
    samples: torch.Tensor,
    sample_rate: int = 32000,
    save_path: str = "audio_output/",
):
    """Write generated audio to ``audio_<idx>.wav`` files under *save_path*.

    Args:
        samples: Indexable container whose first element is the audio tensor
            to save. That tensor must be 2-D (treated as a single clip) or
            3-D (iterated along its first axis, one file per entry).
        sample_rate: Sample rate written to the WAV header. Defaults to
            32000, the rate MusicGen models generate at; the previous
            hard-coded 30000 made playback slower and lower-pitched.
        save_path: Directory the files are written to. Created if missing.

    Raises:
        ValueError: If the selected tensor is neither 2-D nor 3-D.
    """
    sample = samples[0]
    # Validate with a real exception: ``assert`` is stripped under ``-O``.
    if sample.dim() not in (2, 3):
        raise ValueError(
            f"expected a 2-D or 3-D audio tensor, got {sample.dim()} dimension(s)"
        )

    sample = sample.detach().cpu()
    if sample.dim() == 2:
        # Promote a single clip to a batch of one so the loop below applies.
        sample = sample[None, ...]

    # torchaudio.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    for idx, audio in enumerate(sample):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
56
+
57
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML download link that embeds a file as a base64 data URI.

    Args:
        bin_file: Path of the file to embed; its base name becomes the
            suggested download file name.
        file_label: Text shown after "Download " in the link.

    Returns:
        An ``<a>`` element string whose ``href`` holds the file contents.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    filename = os.path.basename(bin_file)
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{filename}">Download {file_label}</a>'
    )
63
+
64
# Configure the browser tab title and favicon. The icon is given as a literal
# emoji: the bare string "musical_note" is neither an emoji nor a
# ":shortcode:", so it would not be treated as the musical-note icon.
st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen",
)
68
+
69
def _build_description(prompt, genre, bpm, mood, instrument, tempo, melody):
    """Join the non-empty prompt fragments into one space-separated string."""
    fragments = [
        prompt.strip() if prompt else "",
        genre,
        f"{bpm} BPM" if bpm else "",
        mood,
        instrument,
        tempo,
        melody.strip() if melody else "",
    ]
    return " ".join(fragment for fragment in fragments if fragment)


def main():
    """Render the Streamlit UI and, on button press, generate and play music.

    Collects a free-text prompt plus genre, BPM, mood, instrument, tempo and
    melody hints, merges them into a single description, generates audio via
    ``generate_music_tensors``, saves it with ``save_audio`` and then offers
    playback and a download link for ``audio_output/audio_0.wav``.
    """
    # NOTE(review): the title says "Medium-Model" while load_model() loads
    # 'facebook/musicgen-small' -- confirm which one is intended.
    st.title("🎧 AI Composer Medium-Model 🎧")

    st.subheader("Craft your perfect melody!")
    bpm = st.number_input("Enter Speed in BPM", min_value=2)

    text_area = st.text_area('Ex : 80s rock song with guitar and drums')
    st.text('')
    selected_genre = st.selectbox("Select Genre", genres)

    st.subheader("2. Select time duration (In Seconds)")
    # Minimum is 1 second: a zero-length request cannot produce any audio.
    time_slider = st.slider("Select time duration (In Seconds)", 1, 30, 10)
    # index=None leaves the optional selectors empty until the user picks one.
    mood = st.selectbox(
        "Select Mood (Optional)",
        ["Happy", "Sad", "Angry", "Relaxed", "Energetic"],
        index=None,
    )
    instrument = st.selectbox(
        "Select Instrument (Optional)",
        ["Piano", "Guitar", "Flute", "Violin", "Drums"],
        index=None,
    )
    tempo = st.selectbox(
        "Select Tempo (Optional)",
        ["Slow", "Moderate", "Fast"],
        index=None,
    )
    melody = st.text_input(
        "Enter Melody or Chord Progression (Optional) e.g: C D:min G:7 C, Twinkle Twinkle Little Star",
        " ",
    )

    if not st.button('Let\'s Generate 🎶'):
        return

    st.text('\n\n')
    st.subheader("Generated Music")

    description = _build_description(
        text_area, selected_genre, bpm, mood, instrument, tempo, melody
    )

    # Clear CUDA memory cache before generating music.
    torch.cuda.empty_cache()

    # Pass the prompt as a one-element list so exactly one clip is generated
    # for the whole description.
    music_tensors = generate_music_tensors([description], time_slider)

    save_audio(music_tensors)
    audio_filepath = 'audio_output/audio_0.wav'
    # Read under a context manager so the file handle is always closed.
    with open(audio_filepath, 'rb') as audio_file:
        audio_bytes = audio_file.read()

    # Play the full audio and offer it for download.
    st.audio(audio_bytes, format='audio/wav')
    st.markdown(
        get_binary_file_downloader_html(audio_filepath, 'Audio'),
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ av
3
+ einops
4
+ flashy>=0.0.1
5
+ hydra-core>=1.1
6
+ hydra_colorlog
7
+ julius
8
+ num2words
9
+ numpy
10
+ sentencepiece
11
+ spacy==3.5.2
12
+ torch>=2.0.0
13
+ torchaudio>=2.0.0
14
+ huggingface_hub
15
+ tqdm
16
+ transformers>=4.31.0 # need Encodec there.
17
+ xformers
18
+ demucs
19
+ librosa
20
+ gradio
21
+ torchmetrics
22
+ encodec