wrice commited on
Commit
76842df
1 Parent(s): 6b54c4d

Add UNet1DModel

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -3,15 +3,22 @@ import gradio as gr
3
  import numpy as np
4
  import torch
5
  import torchaudio
6
- from denoisers import WaveUNetModel
7
  from tqdm import tqdm
8
 
9
- MODELS = ["wrice/waveunet-vctk-48khz", "wrice/waveunet-vctk-24khz"]
 
 
 
 
10
 
11
 
12
  def denoise(model_name, inputs):
13
  """Denoise audio."""
14
- model = WaveUNetModel.from_pretrained(model_name)
 
 
 
15
  sr, audio = inputs
16
  audio = torch.from_numpy(audio)[None]
17
  audio = audio / 32768.0
 
3
  import numpy as np
4
  import torch
5
  import torchaudio
6
+ from denoisers import UNet1DModel, WaveUNetModel
7
  from tqdm import tqdm
8
 
9
+ MODELS = [
10
+ "wrice/unet1d-vctk-48khz",
11
+ "wrice/waveunet-vctk-48khz",
12
+ "wrice/waveunet-vctk-24khz",
13
+ ]
14
 
15
 
16
  def denoise(model_name, inputs):
17
  """Denoise audio."""
18
+ if "unet1d" in model_name:
19
+ model = UNet1DModel.from_pretrained(model_name)
20
+ else:
21
+ model = WaveUNetModel.from_pretrained(model_name)
22
  sr, audio = inputs
23
  audio = torch.from_numpy(audio)[None]
24
  audio = audio / 32768.0