multimodalart HF staff commited on
Commit
9f88b44
·
verified ·
1 Parent(s): 6d3fdbe

make it multiplayer

Browse files
Files changed (1) hide show
  1. app.py +111 -26
app.py CHANGED
@@ -10,6 +10,13 @@ import uuid
10
  from stable_audio_tools import get_pretrained_model
11
  from stable_audio_tools.inference.generation import generate_diffusion_cond
12
 
 
 
 
 
 
 
 
13
  # Load the model outside of the GPU-decorated function
14
  def load_model():
15
 
@@ -19,7 +26,7 @@ def load_model():
19
 
20
  # Function to set up, generate, and process the audio
21
  @spaces.GPU(duration=120) # Allocate GPU only when this function is called
22
- def generate_audio(prompt, sampler_type_dropdown, seconds_total=30, steps=100, cfg_scale=7,sigma_min_slider=0.3,sigma_max_slider=500):
23
  print(f"Prompt received: {prompt}")
24
  print(f"Settings: Duration={seconds_total}s, Steps={steps}, CFG Scale={cfg_scale}")
25
 
@@ -76,34 +83,54 @@ def generate_audio(prompt, sampler_type_dropdown, seconds_total=30, steps=100, c
76
  print(f"Audio trimmed to {seconds_total} seconds.")
77
 
78
  # Generate a unique filename for the output
79
- unique_filename = f"output_{uuid.uuid4().hex}.wav"
 
 
80
  print(f"Saving audio to file: {unique_filename}")
81
 
82
  # Save to file
83
  torchaudio.save(unique_filename, output, sample_rate)
84
  print(f"Audio saved: {unique_filename}")
85
 
 
 
 
86
  # Return the path to the generated audio file
87
  return unique_filename
88
 
89
- # Setting up the Gradio Interface
90
- interface = gr.Interface(
91
- fn=generate_audio,
92
-
93
- inputs=[
94
- gr.Textbox(label="Prompt", placeholder="Enter your text prompt here"),
95
- gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde"),
96
- gr.Slider(0, 47, value=30, step=1, label="Duration in Seconds"),
97
- gr.Slider(10, 150, value=100, step=10, label="Number of Diffusion Steps"),
98
- gr.Slider(1, 15, value=7, step=0.1, label="CFG Scale"),
99
- gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.3, label="Sigma min"),
100
- gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max"),
101
-
102
- ],
103
- outputs=gr.Audio(type="filepath", label="Generated Audio"),
104
- title="Stable Audio Generator",
105
- description="Generate variable-length stereo audio at 44.1kHz from text prompts using Stable Audio Open 1.0.",
106
- examples=[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  [
108
  "Create a serene soundscape of a quiet beach at sunset.", # Text prompt
109
  "dpmpp-2m-sde", # Sampler type
@@ -157,12 +184,70 @@ interface = gr.Interface(
157
  0.3, # Sigma min
158
  500 # Sigma max
159
  ]
160
- ]
161
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- # Pre-load the model to avoid multiprocessing issues
164
  model, model_config = load_model()
165
 
166
- # Launch the Interface
167
- interface.queue(max_size=10).launch()
168
-
 
10
  from stable_audio_tools import get_pretrained_model
11
  from stable_audio_tools.inference.generation import generate_diffusion_cond
12
 
13
+ PAGE_SIZE = 10
14
+ FILE_DIR_PATH = "/data"
15
+
16
+ theme = gr.themes.Base(
17
+ font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
18
+ )
19
+
20
  # Load the model outside of the GPU-decorated function
21
  def load_model():
22
 
 
26
 
27
  # Function to set up, generate, and process the audio
28
  @spaces.GPU(duration=120) # Allocate GPU only when this function is called
29
+ def generate_audio(prompt, sampler_type_dropdown, seconds_total=30, steps=100, cfg_scale=7,sigma_min_slider=0.3,sigma_max_slider=500, progress=gr.Progress(track_tqdm=True)):
30
  print(f"Prompt received: {prompt}")
31
  print(f"Settings: Duration={seconds_total}s, Steps={steps}, CFG Scale={cfg_scale}")
32
 
 
83
  print(f"Audio trimmed to {seconds_total} seconds.")
84
 
85
  # Generate a unique filename for the output
86
+ random_uuid = uuid.uuid4().hex
87
+ unique_filename = f"/data/output_{random_uuid}.wav"
88
+ unique_textfile = f"/data/output_{random_uuid}.txt"
89
  print(f"Saving audio to file: {unique_filename}")
90
 
91
  # Save to file
92
  torchaudio.save(unique_filename, output, sample_rate)
93
  print(f"Audio saved: {unique_filename}")
94
 
95
+ with open(unique_textfile, "w") as file:
96
+ file.write(prompt)
97
+
98
  # Return the path to the generated audio file
99
  return unique_filename
100
 
101
+ def list_all_outputs(generation_history):
102
+ directory_path = FILE_DIR_PATH
103
+ files_in_directory = os.listdir(directory_path)
104
+ wav_files = [os.path.join(directory_path, file) for file in files_in_directory if file.endswith('.wav')]
105
+ wav_files.sort(key=lambda x: os.path.getmtime(os.path.join(directory_path, x)), reverse=True)
106
+ history_list = generation_history.split(',') if generation_history else []
107
+ updated_files = [file for file in wav_files if file not in history_list]
108
+ updated_history = updated_files + history_list
109
+ return ','.join(updated_history), gr.update(visible=True)
110
+
111
+ def increase_list_size(list_size):
112
+ return list_size+PAGE_SIZE
113
+
114
+ css = '''
115
+ #live_gen:before {
116
+ content: '';
117
+ animation: svelte-z7cif2-pulseStart 1s cubic-bezier(.4,0,.6,1), svelte-z7cif2-pulse 2s cubic-bezier(.4,0,.6,1) 1s infinite;
118
+ border: 2px solid var(--color-accent);
119
+ background: transparent;
120
+ z-index: var(--layer-1);
121
+ pointer-events: none;
122
+ position: absolute;
123
+ height: 100%;
124
+ width: 100%;
125
+ border-radius: 7px;
126
+ }
127
+ #live_gen_items{
128
+ max-height: 570px;
129
+ overflow-y: scroll;
130
+ }
131
+ '''
132
+
133
+ examples = [
134
  [
135
  "Create a serene soundscape of a quiet beach at sunset.", # Text prompt
136
  "dpmpp-2m-sde", # Sampler type
 
184
  0.3, # Sigma min
185
  500 # Sigma max
186
  ]
187
+ ]
188
+ with gr.Blocks(theme=theme, css=css) as demo:
189
+ gr.Markdown("# Stable Audio Multiplayer Live")
190
+ gr.Markdown("Generate audio with text, share and learn from others how to best prompt this new model")
191
+ generation_history = gr.Textbox(visible=False)
192
+ list_size = gr.Number(value=PAGE_SIZE, visible=False)
193
+ with gr.Row():
194
+ with gr.Column():
195
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
196
+ btn_run = gr.Button("Generate")
197
+ with gr.Accordion("Parameters", open=True):
198
+ with gr.Row():
199
+ duration = gr.Slider(0, 47, value=20, step=1, label="Duration in Seconds")
200
+
201
+ with gr.Accordion("Advanced parameters", open=False):
202
+ steps = gr.Slider(10, 150, value=80, step=10, label="Number of Diffusion Steps")
203
+ sampler_type = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms",
204
+ "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"],
205
+ label="Sampler type", value="dpmpp-3m-sde")
206
+ with gr.Row():
207
+ cfg_scale = gr.Slider(1, 15, value=7, step=0.1, label="CFG Scale")
208
+ sigma_min = gr.Slider(0.0, 5.0, step=0.01, value=0.3, label="Sigma min")
209
+ sigma_max = gr.Slider(0.0, 1000.0, step=0.1, value=500, label="Sigma max")
210
+ with gr.Column() as output_list:
211
+ output = gr.Audio(type="filepath", label="Generated Audio")
212
+ with gr.Column(elem_id="live_gen") as community_list:
213
+ gr.Markdown("# Community generations")
214
+ with gr.Column(elem_id="live_gen_items"):
215
+ @gr.render(inputs=[generation_history, list_size])
216
+ def show_output_list(generation_history, list_size):
217
+ history_list = generation_history.split(',') if generation_history else []
218
+ history_list_latest = history_list[:list_size]
219
+ for generation in history_list_latest:
220
+ generation_prompt_file = generation.replace('.wav', '.txt')
221
+ with open(generation_prompt_file, 'r') as file:
222
+ generation_prompt = file.read()
223
+ with gr.Group():
224
+ gr.Markdown(value=f"### {generation_prompt}")
225
+ gr.Audio(value=generation)
226
+
227
+
228
+ load_more = gr.Button("Load more")
229
+ load_more.click(fn=increase_list_size, inputs=list_size, outputs=list_size)
230
+
231
+ gr.Examples(
232
+ fn=generate_audio,
233
+ examples=examples,
234
+ inputs=[prompt, sampler_type, duration, steps, cfg_scale, sigma_min, sigma_max],
235
+ outputs=output,
236
+ cache_examples="lazy"
237
+ )
238
+ gr.on(
239
+ triggers=[btn_run.click, prompt.submit],
240
+ fn=generate_audio,
241
+ inputs=[prompt, sampler_type, duration, steps, cfg_scale, sigma_min, sigma_max],
242
+ outputs=output
243
+ )
244
+ btn_run.click(
245
+ generate_audio,
246
+ inputs=[prompt, sampler_type, duration, steps, cfg_scale, sigma_min, sigma_max],
247
+ outputs=output
248
+ )
249
+ demo.load(fn=list_all_outputs, inputs=generation_history, outputs=[generation_history, community_list], every=2)
250
 
 
251
  model, model_config = load_model()
252
 
253
+ demo.launch()