Gpagejr12 commited on
Commit
5f5c021
·
verified ·
1 Parent(s): 5662bb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -6,7 +6,6 @@ import os
6
  import numpy as np
7
  import base64
8
 
9
- genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
10
 
11
  @st.cache_resource()
12
  def load_model():
@@ -23,19 +22,24 @@ def generate_music_tensors(descriptions, duration: int):
23
  )
24
 
25
  with st.spinner("Generating Music..."):
 
 
26
  output = model.generate(
27
  descriptions=descriptions,
28
  progress=True,
29
  return_tokens=True
30
  )
31
 
32
- st.success("Music Generation Complete!")
33
  return output
34
 
35
-
36
  def save_audio(samples: torch.Tensor):
37
  sample_rate = 30000
38
- save_path = "/tmp/audio_output"
 
 
 
 
39
  assert samples.dim() == 2 or samples.dim() == 3
40
 
41
  samples = samples.detach().cpu()
@@ -44,23 +48,36 @@ def save_audio(samples: torch.Tensor):
44
 
45
  for idx, audio in enumerate(samples):
46
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
47
- torchaudio.save(audio_path, audio, sample_rate)
 
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
49
  def get_binary_file_downloader_html(bin_file, file_label='File'):
50
  with open(bin_file, 'rb') as f:
51
  data = f.read()
52
- bin_str = base64.b64encode(data).decode()
53
- href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
54
  return href
55
 
56
  st.set_page_config(
57
  page_icon= "musical_note",
58
  page_title= "Music Gen"
59
  )
60
-
61
  def main():
62
  with st.sidebar:
63
- st.header("""⚙️Generate Music ⚙️""",divider="rainbow")
64
  st.text("")
65
  st.subheader("1. Enter your music description.......")
66
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
@@ -75,7 +92,7 @@ def main():
75
 
76
  st.title("""🎵 Song Lab AI 🎵""")
77
  st.text('')
78
- left_co,right_co = st.columns(2)
79
  left_co.write("""Music Generation through a prompt""")
80
  left_co.write(("""PS : First generation may take some time ......."""))
81
 
@@ -93,7 +110,7 @@ def main():
93
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)] # Adjust the batch size (5 in this case)
94
  music_tensors = generate_music_tensors(descriptions, time_slider)
95
 
96
- # Only play the full audio for index 0
97
  idx = 0
98
  music_tensor = music_tensors[idx]
99
  save_music_file = save_audio(music_tensor)
@@ -103,8 +120,9 @@ def main():
103
 
104
  # Play the full audio
105
  st.audio(audio_bytes, format='audio/wav')
106
- st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
107
 
 
 
108
 
109
  if __name__ == "__main__":
110
  main()
 
6
  import numpy as np
7
  import base64
8
 
 
9
 
10
  @st.cache_resource()
11
  def load_model():
 
22
  )
23
 
24
  with st.spinner("Generating Music..."):
25
+ st.markdown("### Generating Music... 🎵🎶🎹")
26
+
27
  output = model.generate(
28
  descriptions=descriptions,
29
  progress=True,
30
  return_tokens=True
31
  )
32
 
33
+ st.success("Music Generation Complete! 🎉")
34
  return output
35
 
 
36
  def save_audio(samples: torch.Tensor):
37
  sample_rate = 30000
38
+ save_path = "/tmp/audio_output" # Use /tmp directory
39
+
40
+ if not os.path.exists(save_path):
41
+ os.makedirs(save_path)
42
+
43
  assert samples.dim() == 2 or samples.dim() == 3
44
 
45
  samples = samples.detach().cpu()
 
48
 
49
  for idx, audio in enumerate(samples):
50
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
51
+ try:
52
+ torchaudio.save(audio_path, audio, sample_rate)
53
+ except Exception as e:
54
+ st.error(f"Error saving audio file: {e}")
55
+ return None
56
+
57
+ return save_path
58
 
59
+
60
+
61
+
62
+ # Define the genres list
63
+ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
64
+
65
+
66
+ # Add this function for downloading binary files
67
  def get_binary_file_downloader_html(bin_file, file_label='File'):
68
  with open(bin_file, 'rb') as f:
69
  data = f.read()
70
+ bin_str = base64.b64encode(data).decode()
71
+ href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label}</a>'
72
  return href
73
 
74
  st.set_page_config(
75
  page_icon= "musical_note",
76
  page_title= "Music Gen"
77
  )
 
78
  def main():
79
  with st.sidebar:
80
+ st.header("""⚙️Generate Music ⚙️""", divider="rainbow")
81
  st.text("")
82
  st.subheader("1. Enter your music description.......")
83
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
 
92
 
93
  st.title("""🎵 Song Lab AI 🎵""")
94
  st.text('')
95
+ left_co, right_co = st.columns(2)
96
  left_co.write("""Music Generation through a prompt""")
97
  left_co.write(("""PS : First generation may take some time ......."""))
98
 
 
110
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)] # Adjust the batch size (5 in this case)
111
  music_tensors = generate_music_tensors(descriptions, time_slider)
112
 
113
+ # Only play the full audio for index 0
114
  idx = 0
115
  music_tensor = music_tensors[idx]
116
  save_music_file = save_audio(music_tensor)
 
120
 
121
  # Play the full audio
122
  st.audio(audio_bytes, format='audio/wav')
 
123
 
124
+ # Add download link
125
+ st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
126
 
127
  if __name__ == "__main__":
128
  main()