Porjaz committed on
Commit
628e831
1 Parent(s): 519fba9

Update custom_interface_app.py

Browse files
Files changed (1) hide show
  1. custom_interface_app.py +43 -3
custom_interface_app.py CHANGED
@@ -176,9 +176,7 @@ class ASR(Pretrained):
176
  rel_length = torch.tensor([1.0]).to(device)
177
  outputs = self.encode_batch_w2v2(device, batch, rel_length)
178
  return [" ".join(out) for out in outputs]
179
-
180
-
181
-
182
 
183
  def classify_file_whisper_mkd(self, waveform, device):
184
  # Load the audio file
@@ -231,6 +229,48 @@ class ASR(Pretrained):
231
  return outputs
232
 
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  def classify_file_whisper(self, waveform, pipe, device):
236
  # waveform, sr = librosa.load(path, sr=16000)
 
176
  rel_length = torch.tensor([1.0]).to(device)
177
  outputs = self.encode_batch_w2v2(device, batch, rel_length)
178
  return [" ".join(out) for out in outputs]
179
+
 
 
180
 
181
  def classify_file_whisper_mkd(self, waveform, device):
182
  # Load the audio file
 
229
  return outputs
230
 
231
 
232
+ def classify_file_whisper_mkd_streaming(self, waveform, device):
233
+ # Load the audio file
234
+ # waveform, sr = librosa.load(path, sr=16000)
235
+
236
+ # Get audio length in seconds
237
+ audio_length = len(waveform) / 16000
238
+
239
+ if audio_length >= 30:
240
+ # split audio every 30 seconds
241
+ segments = []
242
+ max_duration = 30 * 16000 # Maximum segment duration in samples (20 seconds)
243
+ num_segments = int(np.ceil(len(waveform) / max_duration))
244
+ start = 0
245
+ for i in range(num_segments):
246
+ end = start + max_duration
247
+ if end > len(waveform):
248
+ end = len(waveform)
249
+ segment_part = waveform[start:end]
250
+ segment_len = len(segment_part) / 16000
251
+ if segment_len < 1:
252
+ continue
253
+ segments.append(segment_part)
254
+ start = end
255
+
256
+ for segment in segments:
257
+ segment_tensor = torch.tensor(segment).to(device)
258
+
259
+ # Fake a batch for the segment
260
+ batch = segment_tensor.unsqueeze(0).to(device)
261
+ rel_length = torch.tensor([1.0]).to(device)
262
+
263
+ # Pass the segment through the ASR model
264
+ segment_output = self.encode_batch_whisper(device, batch, rel_length)
265
+ yield segment_output
266
+ else:
267
+ waveform = torch.tensor(waveform).to(device)
268
+ waveform = waveform.to(device)
269
+ batch = waveform.unsqueeze(0)
270
+ rel_length = torch.tensor([1.0]).to(device)
271
+ outputs = self.encode_batch_whisper(device, batch, rel_length)
272
+ yield outputs
273
+
274
 
275
  def classify_file_whisper(self, waveform, pipe, device):
276
  # waveform, sr = librosa.load(path, sr=16000)