Shokoufehhh commited on
Commit
b8ab735
·
verified ·
1 Parent(s): b3a65d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -4
app.py CHANGED
@@ -1,7 +1,72 @@
 
 
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
  import gradio as gr
4
+ from sgmse.model import ScoreModel
5
+ from sgmse.util.other import pad_spec
6
+ import time # Import the time module
7
+ import os
8
 
9
+ # Define parameters based on the configuration in enhancement.py
10
+ args = {
11
+ "test_dir": "./test_data", # example directory, adjust as needed
12
+ "enhanced_dir": "./enhanced_data", # example directory, adjust as needed
13
+ "ckpt": "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt",
14
+ "corrector": "ald",
15
+ "corrector_steps": 1,
16
+ "snr": 0.5,
17
+ "N": 30,
18
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
19
+ }
20
 
21
+ # Ensure the model is loaded to the correct device
22
+ model = ScoreModel.load_from_checkpoint(args["ckpt"]).to(args["device"])
23
+
24
+ def enhance_speech(audio_file):
25
+ start_time = time.time() # Start the timer
26
+
27
+ # Load and process the audio file
28
+ y, sr = torchaudio.load(audio_file) # Gradio passes the file path
29
+ print(f"Loaded audio in {time.time() - start_time:.2f}s")
30
+ T_orig = y.size(1)
31
+
32
+ # Normalize
33
+ norm_factor = y.abs().max()
34
+ y = y / norm_factor
35
+
36
+ # Prepare DNN input
37
+ Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args["device"]))), 0)
38
+ print(f"Transformed input in {time.time() - start_time:.2f}s")
39
+
40
+ Y = pad_spec(Y, mode="zero_pad") # Use "zero_pad" mode for padding
41
+
42
+ # Reverse sampling
43
+ sampler = model.get_pc_sampler(
44
+ 'reverse_diffusion', args["corrector"], Y.to(args["device"]),
45
+ N=args["N"], corrector_steps=args["corrector_steps"], snr=args["snr"]
46
+ )
47
+ sample, _ = sampler()
48
+
49
+ # Backward transform in time domain
50
+ x_hat = model.to_audio(sample.squeeze(), T_orig)
51
+
52
+ # Renormalize
53
+ x_hat = x_hat * norm_factor
54
+
55
+ # Create a temporary path for saving the enhanced audio in Hugging Face Space
56
+ output_file = "/tmp/enhanced_output.wav" # Use a temporary directory
57
+ torchaudio.save(output_file, x_hat.cpu(), sr)
58
+
59
+ print(f"Processed audio in {time.time() - start_time:.2f}s")
60
+
61
+ # Return the path to the enhanced file for Gradio to handle
62
+ return output_file
63
+
64
+ # Gradio interface setup
65
+ inputs = gr.Audio(label="Input Audio", type="filepath") # Adjusted to 'filepath'
66
+ outputs = gr.Audio(label="Enhanced Audio", type="filepath") # Output as filepath
67
+ title = "Speech Enhancement using SGMSE"
68
+ description = "This Gradio demo uses the SGMSE model for speech enhancement. Upload your audio file to enhance it."
69
+ article = "<p style='text-align: center'><a href='https://huggingface.co/SP-UHH/speech-enhancement-sgmse' target='_blank'>Model Card</a></p>"
70
+
71
+ # Launch the Gradio interface
72
+ gr.Interface(fn=enhance_speech, inputs=inputs, outputs=outputs, title=title, description=description, article=article).launch()