wrice committed
Commit: 6c4aae6
Parent: b7217bd

add app.py

Files changed (1)
  1. app.py +13 -2
app.py CHANGED
@@ -1,15 +1,22 @@
+from logging import getLogger
+
 import gradio as gr
 import torch
 import torchaudio
 from denoisers import WaveUNetModel
 
+LOGGER = getLogger(__name__)
 MODEL = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz")
 
+
 def denoise(inputs):
     sr, audio = inputs
     audio = torch.from_numpy(audio)[None]
     audio = audio / 32768.0
 
+    LOGGER.info(f"Audio shape: {audio.shape}")
+    LOGGER.info(f"Sample rate: {sr}")
+
     if sr != MODEL.config.sample_rate:
         audio = torchaudio.functional.resample(audio, sr, MODEL.config.sample_rate)
 
@@ -20,14 +27,18 @@ def denoise(inputs):
 
     clean = []
     for i in range(0, padded.shape[-1], chunk_size):
-        audio_chunk = padded[:, i:i + chunk_size]
+        audio_chunk = padded[:, i : i + chunk_size]
         with torch.no_grad():
             clean_chunk = MODEL(audio_chunk[None]).logits
         clean.append(clean_chunk.squeeze(0))
 
-    denoised = torch.concat(clean)[:, :audio.shape[-1]].squeeze().clamp(-1.0, 1.0)
+    denoised = torch.concat(clean)[:, : audio.shape[-1]].squeeze().clamp(-1.0, 1.0)
     denoised = (denoised * 32767.0).numpy().astype("int16")
+
+    LOGGER.info(f"Denoised shape: {denoised.shape}")
+
     return MODEL.config.sample_rate, denoised
 
+
 iface = gr.Interface(fn=denoise, inputs="audio", outputs="audio")
 iface.launch()
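
Outside the Gradio UI, the model call that app.py wraps (MODEL(audio_chunk[None]).logits) can be exercised directly. The sketch below is illustrative only and not part of the commit: it assumes the denoisers package is installed, that the wrice/waveunet-vctk-24khz checkpoint can be downloaded, and it feeds a hypothetical one-second synthetic chunk; app.py itself pads the input to a chunk_size computed in lines not shown in this diff.

# Illustrative sketch (not part of the commit): run the WaveUNet model on a synthetic
# chunk the same way app.py does inside its loop. The one-second chunk length is an
# assumption; app.py derives its own chunk_size in lines not visible in this diff.
import torch
from denoisers import WaveUNetModel

model = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz")

# (channels, samples) float waveform roughly in [-1, 1], one second at the model rate
chunk = torch.randn(1, model.config.sample_rate)

with torch.no_grad():
    clean = model(chunk[None]).logits  # add a batch dimension, as app.py does

print(clean.shape)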