Update custom_interface_app.py
Browse files- custom_interface_app.py +43 -3
custom_interface_app.py
CHANGED
@@ -176,9 +176,7 @@ class ASR(Pretrained):
|
|
176 |
rel_length = torch.tensor([1.0]).to(device)
|
177 |
outputs = self.encode_batch_w2v2(device, batch, rel_length)
|
178 |
return [" ".join(out) for out in outputs]
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
|
183 |
def classify_file_whisper_mkd(self, waveform, device):
|
184 |
# Load the audio file
|
@@ -231,6 +229,48 @@ class ASR(Pretrained):
|
|
231 |
return outputs
|
232 |
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
def classify_file_whisper(self, waveform, pipe, device):
|
236 |
# waveform, sr = librosa.load(path, sr=16000)
|
|
|
176 |
rel_length = torch.tensor([1.0]).to(device)
|
177 |
outputs = self.encode_batch_w2v2(device, batch, rel_length)
|
178 |
return [" ".join(out) for out in outputs]
|
179 |
+
|
|
|
|
|
180 |
|
181 |
def classify_file_whisper_mkd(self, waveform, device):
|
182 |
# Load the audio file
|
|
|
229 |
return outputs
|
230 |
|
231 |
|
232 |
+
def classify_file_whisper_mkd_streaming(self, waveform, device):
    """Transcribe a (possibly long) 16 kHz waveform with the Whisper ASR model, streaming results.

    Audio of 30 seconds or longer is split into segments of at most
    30 seconds, each transcribed independently and yielded as soon as it
    is ready; shorter audio is transcribed in a single pass.

    Args:
        waveform: 1-D audio samples at 16 kHz (sequence or torch tensor).
        device: torch device the batch tensors are moved to.

    Yields:
        The output of ``self.encode_batch_whisper`` for each segment.
    """
    sample_rate = 16000

    # Get audio length in seconds.
    audio_length = len(waveform) / sample_rate

    if audio_length >= 30:
        # Split the audio every 30 seconds.
        max_duration = 30 * sample_rate  # maximum segment length in samples (30 s)
        # Integer ceiling division — same as int(np.ceil(...)) but stdlib-only.
        num_segments = (len(waveform) + max_duration - 1) // max_duration
        segments = []
        start = 0
        for _ in range(num_segments):
            end = min(start + max_duration, len(waveform))
            segment_part = waveform[start:end]
            # Drop trailing fragments shorter than one second; they carry
            # too little audio for a meaningful transcription.
            if len(segment_part) / sample_rate < 1:
                continue
            segments.append(segment_part)
            start = end

        for segment in segments:
            # as_tensor avoids the copy (and UserWarning) that
            # torch.tensor() emits when the input is already a tensor.
            segment_tensor = torch.as_tensor(segment).to(device)

            # Fake a batch for the segment.
            batch = segment_tensor.unsqueeze(0).to(device)
            rel_length = torch.tensor([1.0]).to(device)

            # Pass the segment through the ASR model.
            yield self.encode_batch_whisper(device, batch, rel_length)
    else:
        # Single pass for short audio; one .to(device) suffices
        # (the original moved the tensor to the device twice).
        batch = torch.as_tensor(waveform).to(device).unsqueeze(0)
        rel_length = torch.tensor([1.0]).to(device)
        yield self.encode_batch_whisper(device, batch, rel_length)
|
274 |
|
275 |
def classify_file_whisper(self, waveform, pipe, device):
|
276 |
# waveform, sr = librosa.load(path, sr=16000)
|