Gpagejr12 commited on
Commit
7159bbe
·
verified ·
1 Parent(s): 503ccd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -6,13 +6,6 @@ import os
6
  import numpy as np
7
  import base64
8
 
9
- # Add this function for downloading binary files
10
- def get_binary_file_downloader_html(bin_file, file_label='File'):
11
- with open(bin_file, 'rb') as f:
12
- data = f.read()
13
- bin_str = base64.b64encode(data).decode()
14
- href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label}</a>'
15
- return href
16
 
17
  @st.cache_resource()
18
  def load_model():
@@ -44,10 +37,10 @@ def save_audio(samples: torch.Tensor):
44
  sample_rate = 30000
45
  save_path = "/tmp/audio_output" # Use /tmp directory
46
 
47
- if not os.path.exists(save_path):
48
- os.makedirs(save_path)
49
 
50
- assert samples.dim() == 2 or samples.dim() == 3
51
 
52
  samples = samples.detach().cpu()
53
  if samples.dim() == 2:
@@ -63,9 +56,21 @@ def save_audio(samples: torch.Tensor):
63
 
64
  return save_path
65
 
 
 
 
66
  # Define the genres list
67
  genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
68
 
 
 
 
 
 
 
 
 
 
69
  st.set_page_config(
70
  page_icon= "musical_note",
71
  page_title= "Music Gen"
@@ -109,7 +114,7 @@ def main():
109
  idx = 0
110
  music_tensor = music_tensors[idx]
111
  save_music_file = save_audio(music_tensor)
112
- audio_filepath = f'audio_output/audio_{idx}.wav'
113
  audio_file = open(audio_filepath, 'rb')
114
  audio_bytes = audio_file.read()
115
 
 
6
  import numpy as np
7
  import base64
8
 
 
 
 
 
 
 
 
9
 
10
  @st.cache_resource()
11
  def load_model():
 
37
  sample_rate = 30000
38
  save_path = "/tmp/audio_output" # Use /tmp directory
39
 
40
+ # if not os.path.exists(save_path):
41
+ # os.makedirs(save_path)
42
 
43
+ # assert samples.dim() == 2 or samples.dim() == 3
44
 
45
  samples = samples.detach().cpu()
46
  if samples.dim() == 2:
 
56
 
57
  return save_path
58
 
59
+
60
+
61
+
62
  # Define the genres list
63
  genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
64
 
65
+
66
+ # Add this function for downloading binary files
67
+ def get_binary_file_downloader_html(bin_file, file_label='File'):
68
+ with open(bin_file, 'rb') as f:
69
+ data = f.read()
70
+ bin_str = base64.b64encode(data).decode()
71
+ href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label}</a>'
72
+ return href
73
+
74
  st.set_page_config(
75
  page_icon= "musical_note",
76
  page_title= "Music Gen"
 
114
  idx = 0
115
  music_tensor = music_tensors[idx]
116
  save_music_file = save_audio(music_tensor)
117
+ audio_filepath = f'/tmp/audio_output/audio_{idx}.wav'
118
  audio_file = open(audio_filepath, 'rb')
119
  audio_bytes = audio_file.read()
120