kgout committed on
Commit 8778797 · verified · 1 Parent(s): 6f64722

Update app.py

Files changed (1)
  1. app.py +425 -2
app.py CHANGED
@@ -1,8 +1,364 @@
  import gradio as gr
- from audiosr import super_resolution, build_model
  import torch
  import gc  # free up memory
  import spaces
+ import gc
+ import os
+ import random
+ import numpy as np
+ from scipy.signal.windows import hann
+ import soundfile as sf
+ import torch
+ import librosa
+ from audiosr import build_model, super_resolution
+ from scipy import signal
+ import pyloudnorm as pyln
+ import tempfile
+
+ class AudioUpscaler:
+     """
+     Upscales audio using the AudioSR model.
+     """
+
+     def __init__(self, model_name="basic", device="auto"):
+         """
+         Initializes the AudioUpscaler.
+
+         Args:
+             model_name (str, optional): Name of the AudioSR model to use. Defaults to "basic".
+             device (str, optional): Device to use for inference. Defaults to "auto".
+         """
+
+         self.model_name = model_name
+         self.device = device
+         self.sr = 48000
+         self.audiosr = None  # Model will be loaded in setup()
+
+     def setup(self):
+         """
+         Loads the AudioSR model.
+         """
+
+         print("Loading Model...")
+         self.audiosr = build_model(model_name=self.model_name, device=self.device)
+         print("Model loaded!")
+
+     def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray):
+         """
+         Matches array_1 to the shape of array_2 by trimming or zero-padding it.
+
+         Args:
+             array_1 (np.ndarray): First array.
+             array_2 (np.ndarray): Second array.
+
+         Returns:
+             np.ndarray: The first array with a shape matching the second array.
+         """
+
+         if (len(array_1.shape) == 1) & (len(array_2.shape) == 1):
+             if array_1.shape[0] > array_2.shape[0]:
+                 array_1 = array_1[: array_2.shape[0]]
+             elif array_1.shape[0] < array_2.shape[0]:
+                 array_1 = np.pad(
+                     array_1,
+                     (array_2.shape[0] - array_1.shape[0], 0),
+                     "constant",
+                     constant_values=0,
+                 )
+         else:
+             if array_1.shape[1] > array_2.shape[1]:
+                 array_1 = array_1[:, : array_2.shape[1]]
+             elif array_1.shape[1] < array_2.shape[1]:
+                 padding = array_2.shape[1] - array_1.shape[1]
+                 array_1 = np.pad(
+                     array_1, ((0, 0), (0, padding)), "constant", constant_values=0
+                 )
+         return array_1
+
+     def _lr_filter(
+         self, audio, cutoff, filter_type, order=12, sr=48000
+     ):
+         """
+         Applies a low-pass or high-pass filter to the audio.
+
+         Args:
+             audio (np.ndarray): Audio data.
+             cutoff (int): Cutoff frequency.
+             filter_type (str): Filter type ("lowpass" or "highpass").
+             order (int, optional): Filter order. Defaults to 12.
+             sr (int, optional): Sample rate. Defaults to 48000.
+
+         Returns:
+             np.ndarray: Filtered audio data.
+         """
+
+         audio = audio.T
+         nyquist = 0.5 * sr
+         normal_cutoff = cutoff / nyquist
+         b, a = signal.butter(
+             order // 2, normal_cutoff, btype=filter_type, analog=False
+         )
+         sos = signal.tf2sos(b, a)
+         filtered_audio = signal.sosfiltfilt(sos, audio)
+         return filtered_audio.T
+
+     def _process_audio(
+         self,
+         input_file,
+         chunk_size=5.12,
+         overlap=0.1,
+         seed=None,
+         guidance_scale=3.5,
+         ddim_steps=50,
+         multiband_ensemble=True,
+         input_cutoff=14000,
+     ):
+         """
+         Processes the audio in chunks and performs upsampling.
+
+         Args:
+             input_file (str): Path to the input audio file.
+             chunk_size (float, optional): Chunk size in seconds. Defaults to 5.12.
+             overlap (float, optional): Overlap between chunks as a fraction of the chunk length. Defaults to 0.1.
+             seed (int, optional): Random seed. Defaults to None.
+             guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
+             ddim_steps (int, optional): Number of inference steps. Defaults to 50.
+             multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
+             input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000.
+
+         Returns:
+             np.ndarray: Upsampled audio data.
+         """
+
+         audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
+         audio = audio.T
+         sr = input_cutoff * 2
+
+         is_stereo = len(audio.shape) == 2
+         if is_stereo:
+             audio_ch1, audio_ch2 = audio[:, 0], audio[:, 1]
+         else:
+             audio_ch1 = audio
+
+         chunk_samples = int(chunk_size * sr)
+         overlap_samples = int(overlap * chunk_samples)
+
+         output_chunk_samples = int(chunk_size * self.sr)
+         output_overlap_samples = int(overlap * output_chunk_samples)
+         enable_overlap = True if overlap > 0 else False
+
+         def process_chunks(audio):
+             chunks = []
+             original_lengths = []
+             start = 0
+             while start < len(audio):
+                 end = min(start + chunk_samples, len(audio))
+                 chunk = audio[start:end]
+                 if len(chunk) < chunk_samples:
+                     original_lengths.append(len(chunk))
+                     pad = np.zeros(chunk_samples - len(chunk))
+                     chunk = np.concatenate([chunk, pad])
+                 else:
+                     original_lengths.append(chunk_samples)
+                 chunks.append(chunk)
+                 start += (
+                     chunk_samples - overlap_samples
+                     if enable_overlap
+                     else chunk_samples
+                 )
+             return chunks, original_lengths
+
+         chunks_ch1, original_lengths_ch1 = process_chunks(audio_ch1)
+         if is_stereo:
+             chunks_ch2, original_lengths_ch2 = process_chunks(audio_ch2)
+
+         sample_rate_ratio = self.sr / sr
+         total_length = (
+             len(chunks_ch1) * output_chunk_samples
+             - (len(chunks_ch1) - 1)
+             * (output_overlap_samples if enable_overlap else 0)
+         )
+         reconstructed_ch1 = np.zeros((1, total_length))
+
+         meter_before = pyln.Meter(sr)
+         meter_after = pyln.Meter(self.sr)
+
+         for i, chunk in enumerate(chunks_ch1):
+             loudness_before = meter_before.integrated_loudness(chunk)
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
+                 sf.write(temp_wav.name, chunk, sr)
+
+                 out_chunk = super_resolution(
+                     self.audiosr,
+                     temp_wav.name,
+                     seed=seed,
+                     guidance_scale=guidance_scale,
+                     ddim_steps=ddim_steps,
+                     latent_t_per_second=12.8,
+                 )
+             out_chunk = out_chunk[0]
+             num_samples_to_keep = int(
+                 original_lengths_ch1[i] * sample_rate_ratio
+             )
+             out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
+
+             loudness_after = meter_after.integrated_loudness(out_chunk)
+             out_chunk = pyln.normalize.loudness(
+                 out_chunk, loudness_after, loudness_before
+             )
+
+             if enable_overlap:
+                 actual_overlap_samples = min(
+                     output_overlap_samples, num_samples_to_keep
+                 )
+                 fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
+                 fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)
+
+                 if i == 0:
+                     out_chunk[-actual_overlap_samples:] *= fade_out
+                 elif i < len(chunks_ch1) - 1:
+                     out_chunk[:actual_overlap_samples] *= fade_in
+                     out_chunk[-actual_overlap_samples:] *= fade_out
+                 else:
+                     out_chunk[:actual_overlap_samples] *= fade_in
+
+             start = i * (
+                 output_chunk_samples - output_overlap_samples
+                 if enable_overlap
+                 else output_chunk_samples
+             )
+             end = start + out_chunk.shape[0]
+             reconstructed_ch1[0, start:end] += out_chunk.flatten()
+
+         if is_stereo:
+             reconstructed_ch2 = np.zeros((1, total_length))
+             for i, chunk in enumerate(chunks_ch2):
+                 loudness_before = meter_before.integrated_loudness(chunk)
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
+                     sf.write(temp_wav.name, chunk, sr)
+
+                     out_chunk = super_resolution(
+                         self.audiosr,
+                         temp_wav.name,
+                         seed=seed,
+                         guidance_scale=guidance_scale,
+                         ddim_steps=ddim_steps,
+                         latent_t_per_second=12.8,
+                     )
+                 out_chunk = out_chunk[0]
+                 num_samples_to_keep = int(
+                     original_lengths_ch2[i] * sample_rate_ratio
+                 )
+                 out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
+
+                 loudness_after = meter_after.integrated_loudness(out_chunk)
+                 out_chunk = pyln.normalize.loudness(
+                     out_chunk, loudness_after, loudness_before
+                 )
+
+                 if enable_overlap:
+                     actual_overlap_samples = min(
+                         output_overlap_samples, num_samples_to_keep
+                     )
+                     fade_out = np.linspace(1.0, 0.0, actual_overlap_samples)
+                     fade_in = np.linspace(0.0, 1.0, actual_overlap_samples)
+
+                     if i == 0:
+                         out_chunk[-actual_overlap_samples:] *= fade_out
+                     elif i < len(chunks_ch2) - 1:
+                         out_chunk[:actual_overlap_samples] *= fade_in
+                         out_chunk[-actual_overlap_samples:] *= fade_out
+                     else:
+                         out_chunk[:actual_overlap_samples] *= fade_in
+
+                 start = i * (
+                     output_chunk_samples - output_overlap_samples
+                     if enable_overlap
+                     else output_chunk_samples
+                 )
+                 end = start + out_chunk.shape[0]
+                 reconstructed_ch2[0, start:end] += out_chunk.flatten()
+
+             reconstructed_audio = np.stack(
+                 [reconstructed_ch1, reconstructed_ch2], axis=-1
+             )
+         else:
+             reconstructed_audio = reconstructed_ch1
+
+         if multiband_ensemble:
+             low, _ = librosa.load(input_file, sr=48000, mono=False)
+             output = self._match_array_shapes(
+                 reconstructed_audio[0].T, low
+             )
+             crossover_freq = input_cutoff - 1000
+             low = self._lr_filter(
+                 low.T, crossover_freq, "lowpass", order=10
+             )
+             high = self._lr_filter(
+                 output.T, crossover_freq, "highpass", order=10
+             )
+             high = self._lr_filter(
+                 high, 23000, "lowpass", order=2
+             )
+             output = low + high
+         else:
+             output = reconstructed_audio[0]
+
+         return output
+
+     def predict(
+         self,
+         input_file,
+         output_folder,
+         ddim_steps=50,
+         guidance_scale=3.5,
+         overlap=0.04,
+         chunk_size=10.24,
+         seed=None,
+         multiband_ensemble=True,
+         input_cutoff=14000,
+     ):
+         """
+         Upscales the audio and saves the result.
+
+         Args:
+             input_file (str): Path to the input audio file.
+             output_folder (str): Path to the output folder.
+             ddim_steps (int, optional): Number of inference steps. Defaults to 50.
+             guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
+             overlap (float, optional): Overlap between chunks. Defaults to 0.04.
+             chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
+             seed (int, optional): Random seed. Defaults to None.
+             multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
+             input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000.
+         """
+         if seed == 0:
+             seed = random.randint(0, 2**32 - 1)
+
+         os.makedirs(output_folder, exist_ok=True)
+         waveform = self._process_audio(
+             input_file,
+             chunk_size=chunk_size,
+             overlap=overlap,
+             seed=seed,
+             guidance_scale=guidance_scale,
+             ddim_steps=ddim_steps,
+             multiband_ensemble=multiband_ensemble,
+             input_cutoff=input_cutoff,
+         )
+
+         filename = os.path.splitext(os.path.basename(input_file))[0]
+         output_file = f"{output_folder}/SR_{filename}.wav"
+         sf.write(output_file, data=waveform, samplerate=48000, subtype="PCM_16")
+         print(f"File created: {output_file}")
+
+         # Cleanup
+         del waveform
+         gc.collect()
+         torch.cuda.empty_cache()
+         return output_file
+
+
 
  @spaces.GPU(duration=300)
  def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
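The chunk stitching inside _process_audio above fades each upscaled chunk out at its tail and in at its head across the overlap region, then sums the chunks into the output buffer, so overlapping samples add back to unity. A tiny self-contained numpy sketch of that crossfade-and-sum idea (the chunk length and overlap here are toy values, not the ones used in the diff):

import numpy as np

# Two "chunks" of 8 samples that overlap by 3 samples.
chunk_len, overlap = 8, 3
a = np.ones(chunk_len)             # stands in for the first upscaled chunk
b = np.ones(chunk_len)             # stands in for the next upscaled chunk

fade_out = np.linspace(1.0, 0.0, overlap)
fade_in = np.linspace(0.0, 1.0, overlap)
a[-overlap:] *= fade_out           # first chunk fades out at its tail
b[:overlap] *= fade_in             # next chunk fades in at its head

out = np.zeros(2 * chunk_len - overlap)
out[:chunk_len] += a
out[chunk_len - overlap:] += b     # overlapping region sums back to ~1.0
print(out)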
 
@@ -26,12 +382,79 @@ def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
          ddim_steps=ddim_steps
      )
 
+
+
+     return (48000, waveform)
+
+
+ def upscale_audio(
+     input_file,
+     output_folder,
+     ddim_steps=20,
+     guidance_scale=3.5,
+     overlap=0.04,
+     chunk_size=10.24,
+     seed=0,
+     multiband_ensemble=True,
+     input_cutoff=14000,
+ ):
+     """
+     Upscales the audio using the AudioSR model.
+
+     Args:
+         input_file (str): Path to the input audio file.
+         output_folder (str): Path to the output folder.
+         ddim_steps (int, optional): Number of inference steps. Defaults to 20.
+         guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
+         overlap (float, optional): Overlap between chunks. Defaults to 0.04.
+         chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
+         seed (int, optional): Random seed. Defaults to 0.
+         multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
+         input_cutoff (int, optional): Input cutoff frequency for multiband ensemble. Defaults to 14000.
+
+     Returns:
+         str: Path to the upscaled audio file.
+     """
+     upscaler = AudioUpscaler()
+     upscaler.setup()
+
+     output_file = upscaler.predict(
+         input_file,
+         output_folder,
+         ddim_steps=ddim_steps,
+         guidance_scale=guidance_scale,
+         overlap=overlap,
+         chunk_size=chunk_size,
+         seed=seed,
+         multiband_ensemble=multiband_ensemble,
+         input_cutoff=input_cutoff,
+     )
+
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
 
      gc.collect()
 
-     return (48000, waveform)
+     return output_file
+
+ os.getcwd()
+ gr.Textbox
+
+ iface = gr.Interface(
+     fn=upscale_audio,
+     inputs=[
+         gr.Audio(type="filepath", label="Input Audio"),
+         gr.Textbox(".", label="Out-dir"),
+         gr.Slider(10, 500, value=20, step=1, label="DDIM Steps", info="Number of inference steps (quality/speed)"),
+         gr.Slider(1.0, 20.0, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (creativity/fidelity)"),
+         gr.Slider(0.0, 0.5, value=0.04, step=0.01, label="Overlap (s)", info="Overlap between chunks (smooth transitions)"),
+         gr.Slider(5.12, 20.48, value=5.12, step=0.64, label="Chunk Size (s)", info="Chunk size (memory/artifact balance)"),
+         gr.Number(value=0, precision=0, label="Seed", info="Random seed (0 for random)"),
+         gr.Checkbox(label="Multiband Ensemble", value=False, info="Enhance high frequencies"),
+         gr.Slider(500, 15000, value=9000, step=500, label="Crossover Frequency (Hz)", info="For multiband processing", visible=True)
+     ],
+
+
 
  iface = gr.Interface(
      fn=inference,
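For reference, a minimal sketch of driving the new upscale path directly, without the Gradio UI. It assumes the AudioUpscaler class and upscale_audio function from the diff above are in scope (for example, pasted into a script or notebook), that the added dependencies (audiosr, librosa, soundfile, scipy, pyloudnorm) are installed, and that the input file and output folder below are placeholder names, not anything from the commit:

# Sketch only: "song_16k.wav" and "upscaled" are placeholder paths.
output_path = upscale_audio(
    "song_16k.wav",          # band-limited input file
    "upscaled",              # predict() creates this folder and writes SR_song_16k.wav there
    ddim_steps=20,           # more DDIM steps: slower, higher quality
    guidance_scale=3.5,
    overlap=0.04,            # fraction of each chunk crossfaded with its neighbour
    chunk_size=10.24,        # seconds of audio per AudioSR call
    seed=0,                  # 0 makes predict() draw a random seed
    multiband_ensemble=True,
    input_cutoff=14000,      # input is loaded at 2x this rate; the crossover sits 1 kHz below it
)
print("Wrote", output_path)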