d22cs051 commited on
Commit
6aca51e
1 Parent(s): 6ffa81b

Add application file

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +95 -0
  3. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.ckpt
2
+ *__
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import librosa
3
+ import torch
4
+ import soundfile as sf
5
+ from speechbrain.inference.separation import SepformerSeparation as separator
6
+ import torchaudio, torchmetrics, torch
7
+
8
+
9
+ # defineing model class
10
+ class SepformerFineTune(torch.nn.Module):
11
+ def __init__(self, model):
12
+ super(SepformerFineTune, self).__init__()
13
+ self.model = model
14
+ # disabling gradient computation
15
+ for parms in self.model.parameters():
16
+ parms.requires_grad = False
17
+
18
+ # enable gradient computation for the last layer
19
+ named_layers = dict(model.named_modules())
20
+ for name, layer in named_layers.items():
21
+ # print(f"Name: {name}, Layer: {layer}")
22
+ if name == "mods.masknet.output.0":
23
+ for param in layer.parameters():
24
+ param.requires_grad = True
25
+ if name == "mods.masknet.output_gate":
26
+ for param in layer.parameters():
27
+ param.requires_grad = True
28
+
29
+
30
+ # printing all tranble parameters
31
+ # for model_name, model_params in model.named_parameters():
32
+ # print(f"Model Layer Name: {model_name}, Model Params: {model_params.requires_grad}")
33
+ def forward(self, mix):
34
+ est_sources = self.model.separate_batch(mix)
35
+ return est_sources[:,:,0], est_sources[:,:,1] # NOTE: Working with 2 sources ONLY
36
+
37
+
38
+ class SourceSeparationApp:
39
+ def __init__(self, model_path,device="cpu"):
40
+ self.model = self.load_model(model_path)
41
+ self.device = device
42
+
43
+ def load_model(self, model_path):
44
+ model = separator.from_hparams(source="speechbrain/sepformer-wsj03mix", savedir='pretrained_models/sepformer-wsj03mix', run_opts={"device": device})
45
+ checkpoint = torch.load(model_path)
46
+ fine_tuned_model = SepformerFineTune(model)
47
+ fine_tuned_model.load_state_dict(checkpoint["model"])
48
+ return fine_tuned_model
49
+
50
+ def separate_sources(self, audio_file):
51
+ # Load input audio
52
+ # print(f"[LOG] Audio file: {audio_file}")
53
+ input_audio_tensor, sr = audio_file[1], audio_file[0]
54
+
55
+ if self.model is None:
56
+ return "Error: Model not loaded."
57
+
58
+ # sending input audio to PyTorch tensor
59
+ input_audio_tensor = torch.tensor(input_audio_tensor,dtype=torch.float).unsqueeze(0)
60
+ input_audio_tensor = input_audio_tensor.to(self.device)
61
+
62
+ # Source separation using the loaded model
63
+ self.model.to(self.device)
64
+ self.model.eval()
65
+ with torch.inference_mode():
66
+ # print(f"[LOG] mix shape: {mix.shape}, s1 shape: {s1.shape}, s2 shape: {s2.shape}, noise shape: {noise.shape}")
67
+ source1,source2 = self.model(input_audio_tensor)
68
+
69
+
70
+ # Save separated sources
71
+ sf.write("source1.wav", source1.squeeze().cpu().numpy(), sr)
72
+ sf.write("source2.wav", source2.squeeze().cpu().numpy(), sr)
73
+
74
+ return "Separation completed", "source1.wav", "source2.wav"
75
+
76
+ def run(self):
77
+ audio_input = gr.Audio(label="Upload or record audio")
78
+ output_text = gr.Label(label="Status:")
79
+ audio_output1 = gr.Audio(label="Source 1", type="filepath",)
80
+ audio_output2 = gr.Audio(label="Source 2", type="filepath",)
81
+ gr.Interface(
82
+ fn=self.separate_sources,
83
+ inputs=audio_input,
84
+ outputs=[output_text, audio_output1, audio_output2],
85
+ title="Audio Source Separation",
86
+ description="Separate sources from a mixed audio signal.",
87
+ allow_flagging=False
88
+ ).launch()
89
+
90
+
91
+ if __name__ == "__main__":
92
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
+ model_path = "fine_tuned_sepformer-wsj03mix-7sec.ckpt" # Replace with your model path
94
+ app = SourceSeparationApp(model_path, device=device)
95
+ app.run()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ soundfile>=0.10.3.post1
2
+ tqdm>=4.46.1
3
+ pysndfx>=0.3.6
4
+ pandas>=1.0.1
5
+ numpy>=1.18.1
6
+ pyloudnorm>=0.1.0
7
+ scipy>=1.4.1
8
+ matplotlib>=3.1.3
9
+ torch==2.2.1
10
+ torchaudio==2.2.1
11
+ speechbrain