Porjaz committed
Commit 4b65bb8
1 Parent(s): 134a152

Update custom_interface.py

Files changed (1):
  1. custom_interface.py +8 -132
custom_interface.py CHANGED
@@ -1,7 +1,6 @@
 import torch
 from speechbrain.inference.interfaces import Pretrained
 import librosa
-import numpy as np
 
 
 class ASR(Pretrained):
@@ -84,139 +83,16 @@ class ASR(Pretrained):
             seq.append(token)
         output = []
         return seq
-
-
-    # def classify_file(self, path):
-    #     # waveform = self.load_audio(path)
-    #     waveform, sr = librosa.load(path, sr=16000)
-    #     waveform = torch.tensor(waveform)
-
-    #     # Fake a batch:
-    #     batch = waveform.unsqueeze(0)
-    #     rel_length = torch.tensor([1.0])
-    #     outputs = self.encode_batch(batch, rel_length)
-
-    #     return outputs
 
 
     def classify_file(self, path, device):
-        # Load the audio file
-        # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)
 
-        # Get audio length in seconds
-        audio_length = len(waveform) / sr
-
-        if audio_length >= 20:
-            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
-            # Detect non-silent segments
-            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
-
-            segments = []
-            current_segment = []
-            current_length = 0
-            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
-
-            for interval in non_silent_intervals:
-                start, end = interval
-                segment_part = waveform[start:end]
-
-                # If adding the next part exceeds max duration, store the segment and start a new one
-                if current_length + len(segment_part) > max_duration:
-                    segments.append(np.concatenate(current_segment))
-                    current_segment = []
-                    current_length = 0
-
-                current_segment.append(segment_part)
-                current_length += len(segment_part)
-
-            # Append the last segment if it's not empty
-            if current_segment:
-                segments.append(np.concatenate(current_segment))
-
-            # Process each segment
-            outputs = []
-            for i, segment in enumerate(segments):
-                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
-
-                segment_tensor = torch.tensor(segment).to(device)
-
-                # Fake a batch for the segment
-                batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
-
-                # Pass the segment through the ASR model
-                segment_output = self.encode_batch(device, batch, rel_length)
-                yield segment_output
-        else:
-            waveform = torch.tensor(waveform).to(device)
-            waveform = waveform.to(device)
-            # Fake a batch:
-            batch = waveform.unsqueeze(0)
-            rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch(device, batch, rel_length)
-            yield outputs
-
-
-    def classify_file_whisper(self, path, pipe, device):
-        waveform, sr = librosa.load(path, sr=16000)
-        transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
-        return transcription
-
-
-    def classify_file_mms(self, path, processor, model, device):
-        # Load the audio file
-        waveform, sr = librosa.load(path, sr=16000)
-
-        # Get audio length in seconds
-        audio_length = len(waveform) / sr
-
-        if audio_length >= 20:
-            print(f"MMS Audio is too long ({audio_length:.2f} seconds), splitting into segments")
-            # Detect non-silent segments
-            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
-
-            segments = []
-            current_segment = []
-            current_length = 0
-            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
-
-
-            for interval in non_silent_intervals:
-                start, end = interval
-                segment_part = waveform[start:end]
-
-                # If adding the next part exceeds max duration, store the segment and start a new one
-                if current_length + len(segment_part) > max_duration:
-                    segments.append(np.concatenate(current_segment))
-                    current_segment = []
-                    current_length = 0
-
-                current_segment.append(segment_part)
-                current_length += len(segment_part)
-
-            # Append the last segment if it's not empty
-            if current_segment:
-                segments.append(np.concatenate(current_segment))
-
-            # Process each segment
-            outputs = []
-            for i, segment in enumerate(segments):
-                print(f"MMS Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
-
-                segment_tensor = torch.tensor(segment).to(device)
-
-                # Pass the segment through the ASR model
-                inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
-                outputs = model(**inputs).logits
-                ids = torch.argmax(outputs, dim=-1)[0]
-                segment_output = processor.decode(ids)
-                yield segment_output
-        else:
-            waveform = torch.tensor(waveform).to(device)
-            inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
-            outputs = model(**inputs).logits
-            ids = torch.argmax(outputs, dim=-1)[0]
-            transcription = processor.decode(ids)
-            yield transcription
-
+        waveform = torch.tensor(waveform).to(device)
+        waveform = waveform.to(device)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0]).to(device)
+        outputs = self.encode_batch(device, batch, rel_length)
+        outputs = " ".join(outputs[0])
+        return outputs
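For reference, the silence-based chunking that this commit deletes can be kept as a standalone helper if long-recording support is needed again. The sketch below is reconstructed from the removed code and is not part of the commit; split_long_audio is a hypothetical name, the 20-second cap and top_db=20 mirror the old defaults, and the non-empty guard before flushing is a small fix (the original could call np.concatenate on an empty list when a single non-silent span exceeded the cap).

import librosa
import numpy as np

def split_long_audio(waveform, sr, max_seconds=20, top_db=20):
    # Group non-silent spans into chunks of at most max_seconds each;
    # top_db sets the silence-detection sensitivity, as in the removed code.
    intervals = librosa.effects.split(waveform, top_db=top_db)
    max_samples = max_seconds * sr
    segments, current, length = [], [], 0
    for start, end in intervals:
        part = waveform[start:end]
        # Flush the running chunk before it would exceed the cap
        # (only if it is non-empty, unlike the original).
        if current and length + len(part) > max_samples:
            segments.append(np.concatenate(current))
            current, length = [], 0
        current.append(part)
        length += len(part)
    if current:
        segments.append(np.concatenate(current))
    return segments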
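After this change, classify_file runs the whole file as one batch and returns a single space-joined transcription string instead of yielding per-segment outputs. A minimal usage sketch via SpeechBrain's foreign_class loader follows, assuming this file ships alongside the model's hyperparams.yaml; the repo id and audio path are placeholders, not part of the commit:

import torch
from speechbrain.inference.interfaces import foreign_class

device = "cuda" if torch.cuda.is_available() else "cpu"

asr = foreign_class(
    source="<user>/<asr-repo>",           # placeholder Hugging Face repo id
    pymodule_file="custom_interface.py",  # the file updated in this commit
    classname="ASR",
    run_opts={"device": device},
)

transcription = asr.classify_file("sample.wav", device)  # placeholder path
print(transcription)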