avans06 commited on
Commit
ec7cc5c
·
1 Parent(s): 8e80889

Added the "Whisper Segments Filter" option along with some configuration adjustments.

Browse files

1. Added the Whisper Segments Filter option, which, when enabled, can effectively improve the whisper hallucination, especially for the large-v3 version of the whisper model.

2. Set the Word Timestamps option to enable by default.

3. The textarea for outputting Transcription and Segments now supports displaying a scrollbar.
_________

## Whisper Filter options
**This is an experimental feature and may potentially filter out correct transcription results.**

when enabled, can effectively improve the whisper hallucination, especially for the large-v3 version of the whisper model.

Observations for transcriptions:
1. duration: calculated by subtracting start from end, it might indicate hallucinated results when inversely proportional to text length.
2. segment_last: the last result for each segment during VAD transcription has a certain probability of being a hallucinated result.
3. avg_logprob: average log probability, ranging from logprob_threshold (default: -1) to 0, is better when a larger value. A value lower than -0.9 might suggest a poor result.
4. compression_ratio: gzip compression ratio, ranging from 0 to compression_ratio_threshold (default: 2.4), a higher positive value is preferable. If it is lower than 0.9, it might indicate suboptimal results.
5. no_speech_prob: no_speech(<|nospeech|> token) probability, ranging from 0 to no_speech_threshold (default: 0.6), a smaller positive value is preferable. If it exceeds 0.1, it might suggest suboptimal results.

Four sets of filtering conditions have now been established, utilizing text length, duration length, as well as the avg_logprob, compression_ratio, and no_speech_prob parameters returned by Whisper.
1. avg_logprob < -0.9
2. (durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.5
3. (durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.07, compression_ratio < 0.9
4. (durationLen < 1.5 || segment_last), compression_ratio < 0.9, no_speech_prob > 0.1

Files changed (5) hide show
  1. app.py +139 -19
  2. config.json5 +21 -1
  3. docs/options.md +18 -0
  4. src/config.py +8 -3
  5. src/vad.py +5 -1
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from datetime import datetime
2
  import json
3
  import math
4
- from typing import Iterator, Union, List
5
  import argparse
6
 
7
  from io import StringIO
@@ -244,14 +244,38 @@ class WhisperTranscriber:
244
  microphoneData: str = decodeOptions.pop("microphoneData")
245
  task: str = decodeOptions.pop("task")
246
 
247
- vad: str = decodeOptions.pop("vad")
248
- vadMergeWindow: float = decodeOptions.pop("vadMergeWindow")
249
- vadMaxMergeSize: float = decodeOptions.pop("vadMaxMergeSize")
250
- vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
251
- vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
252
- vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
253
- self.vad_process_timeout: float = decodeOptions.pop("vadPocessTimeout", self.vad_process_timeout)
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  diarization: bool = decodeOptions.pop("diarization", False)
256
  diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
257
  diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
@@ -388,6 +412,10 @@ class WhisperTranscriber:
388
 
389
  # Transcribe
390
  result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
 
 
 
 
391
  if translationModel is not None and whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
392
  whisperLang = get_lang_from_whisper_code(result["language"])
393
  translationModel.whisperLang = whisperLang
@@ -466,8 +494,8 @@ class WhisperTranscriber:
466
  zip.write(download_file, arcname=zip_file_name)
467
 
468
  download.insert(0, downloadAllPath)
469
-
470
- return download, text, vtt
471
 
472
  finally:
473
  # Cleanup source
@@ -481,10 +509,10 @@ class WhisperTranscriber:
481
  print("Error deleting temporary source file: \n" + source.source_path + ", \n" + str(e))
482
 
483
  except ExceededMaximumDuration as e:
484
- return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
485
  except Exception as e:
486
  print(traceback.format_exc())
487
- return [], ("Error occurred during transcribe: " + str(e)), traceback.format_exc()
488
 
489
 
490
  def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, languageCode: str, task: str = None,
@@ -549,7 +577,13 @@ class WhisperTranscriber:
549
  else:
550
  # Default VAD
551
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
552
-
 
 
 
 
 
 
553
  # Diarization
554
  if self.diarization and self.diarization_kwargs:
555
  print("Diarizing ", audio_path)
@@ -564,6 +598,68 @@ class WhisperTranscriber:
564
  result = self.diarization.mark_speakers(diarization_result, result)
565
 
566
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
 
568
  def _create_progress_listener(self, progress: gr.Progress):
569
  if (progress is None):
@@ -874,6 +970,11 @@ def create_ui(app_config: ApplicationConfig):
874
  gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps, elem_id="word_timestamps"),
875
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words, elem_id="highlight_words"),
876
  }
 
 
 
 
 
877
 
878
  has_diarization_libs = Diarization.has_libraries()
879
 
@@ -889,15 +990,30 @@ def create_ui(app_config: ApplicationConfig):
889
  }
890
 
891
  common_output = lambda : [
892
- gr.File(label="Download"),
893
- gr.Text(label="Transcription", autoscroll=False),
894
- gr.Text(label="Segments", autoscroll=False),
 
895
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896
 
897
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
898
 
899
  simpleInputDict = {}
900
-
901
  with gr.Blocks() as simpleTranscribe:
902
  simpleTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
903
  simpleSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
@@ -939,6 +1055,8 @@ def create_ui(app_config: ApplicationConfig):
939
  simpleInputDict.update(common_vad_inputs())
940
  with gr.Accordion("Word Timestamps options", open=False):
941
  simpleInputDict.update(common_word_timestamps_inputs())
 
 
942
  with gr.Accordion("Diarization options", open=False):
943
  simpleInputDict.update(common_diarization_inputs())
944
  with gr.Accordion("Translation options", open=False):
@@ -957,7 +1075,7 @@ def create_ui(app_config: ApplicationConfig):
957
  gr.Markdown(readmeMd)
958
 
959
  simpleInputDict.update({simpleTranslateInput, simpleSourceInput})
960
- simpleSubmit.click(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
961
  inputs=simpleInputDict, outputs=simpleOutput)
962
 
963
  fullInputDict = {}
@@ -1032,6 +1150,8 @@ def create_ui(app_config: ApplicationConfig):
1032
  gr.Number(label="Repetition Penalty", value=app_config.repetition_penalty, elem_id = "repetition_penalty"),
1033
  gr.Number(label="No Repeat Ngram Size", value=app_config.no_repeat_ngram_size, precision=0, elem_id = "no_repeat_ngram_size")
1034
  })
 
 
1035
  with gr.Accordion("Diarization options", open=False):
1036
  fullInputDict.update(common_diarization_inputs())
1037
  with gr.Accordion("Translation options", open=False):
@@ -1051,7 +1171,7 @@ def create_ui(app_config: ApplicationConfig):
1051
  fullSubmit.click(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
1052
  inputs=fullInputDict, outputs=fullOutput)
1053
 
1054
- demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe], tab_names=["Simple", "Full"])
1055
 
1056
  # Queue up the demo
1057
  if is_queue_mode:
 
1
  from datetime import datetime
2
  import json
3
  import math
4
+ from typing import Iterator, Union, List, Dict, Any
5
  import argparse
6
 
7
  from io import StringIO
 
244
  microphoneData: str = decodeOptions.pop("microphoneData")
245
  task: str = decodeOptions.pop("task")
246
 
247
+ vad: str = decodeOptions.pop("vad")
248
+ vadMergeWindow: float = decodeOptions.pop("vadMergeWindow")
249
+ vadMaxMergeSize: float = decodeOptions.pop("vadMaxMergeSize")
250
+ vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
251
+ vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
252
+ vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
253
+ self.vad_process_timeout: float = decodeOptions.pop("vadPocessTimeout", self.vad_process_timeout)
254
 
255
+ self.whisperSegmentsFilters: List[List] = []
256
+ inputFilter: bool = decodeOptions.pop("whisperSegmentsFilter", None)
257
+ inputFilters = []
258
+ for idx in range(0,len(self.app_config.whisper_segments_filters),1):
259
+ inputFilters.append(decodeOptions.pop(f"whisperSegmentsFilter{idx}", None))
260
+ inputFilters = filter(None, inputFilters)
261
+ if inputFilter:
262
+ for inputFilter in inputFilters:
263
+ self.whisperSegmentsFilters.append([])
264
+ self.whisperSegmentsFilters[-1].append(inputFilter)
265
+ for text in inputFilter.split(","):
266
+ result = []
267
+ subFilter = [text] if "||" not in text else [strFilter_ for strFilter_ in text.lstrip("(").rstrip(")").split("||") if strFilter_]
268
+ for string in subFilter:
269
+ conditions = [condition for condition in string.split(" ") if condition]
270
+ if len(conditions) == 1 and conditions[0] == "segment_last":
271
+ pass
272
+ elif len(conditions) == 3:
273
+ conditions[-1] = float(conditions[-1])
274
+ else:
275
+ continue
276
+ result.append(conditions)
277
+ self.whisperSegmentsFilters[-1].append(result)
278
+
279
  diarization: bool = decodeOptions.pop("diarization", False)
280
  diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
281
  diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
 
412
 
413
  # Transcribe
414
  result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
415
+ filterLog = result.get("filterLog", None)
416
+ filterLogText = [gr.Text.update(visible=False)]
417
+ if filterLog:
418
+ filterLogText = [gr.Text.update(visible=True, value=filterLog)]
419
  if translationModel is not None and whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
420
  whisperLang = get_lang_from_whisper_code(result["language"])
421
  translationModel.whisperLang = whisperLang
 
494
  zip.write(download_file, arcname=zip_file_name)
495
 
496
  download.insert(0, downloadAllPath)
497
+
498
+ return [download, text, vtt] + filterLogText
499
 
500
  finally:
501
  # Cleanup source
 
509
  print("Error deleting temporary source file: \n" + source.source_path + ", \n" + str(e))
510
 
511
  except ExceededMaximumDuration as e:
512
+ return [], "[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s", "[ERROR]", ""
513
  except Exception as e:
514
  print(traceback.format_exc())
515
+ return [], "Error occurred during transcribe: " + str(e), traceback.format_exc(), ""
516
 
517
 
518
  def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, languageCode: str, task: str = None,
 
577
  else:
578
  # Default VAD
579
  result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
580
+
581
+ if self.whisperSegmentsFilters:
582
+ querySegmentsResult, filterLog = self.filterSegments(result["segments"])
583
+ result["segments"] = querySegmentsResult
584
+ if filterLog:
585
+ result["filterLog"] = filterLog
586
+
587
  # Diarization
588
  if self.diarization and self.diarization_kwargs:
589
  print("Diarizing ", audio_path)
 
598
  result = self.diarization.mark_speakers(diarization_result, result)
599
 
600
  return result
601
+
602
+ def filterSegments(self, querySegments: List[Dict[str, Any]]):
603
+ try:
604
+ if not self.whisperSegmentsFilters: return
605
+
606
+ filterIdx = 0
607
+ filterLog = []
608
+ querySegmentsResult = querySegments.copy()
609
+ for idx in range(len(querySegmentsResult),0,-1):
610
+ currentID = idx - 1
611
+ querySegment = querySegmentsResult[currentID]
612
+ for segmentsFilter in self.whisperSegmentsFilters:
613
+ isFilter: bool = True
614
+ for idx, strFilter in enumerate(segmentsFilter):
615
+ if not isFilter: break
616
+ if idx == 0:
617
+ filterCondition = strFilter
618
+ continue
619
+
620
+ isFilter = True
621
+ for subFilter in strFilter:
622
+ key: str = subFilter[0]
623
+
624
+ if key == "segment_last":
625
+ isFilter = querySegment.get(key, None)
626
+ if isFilter: break
627
+ continue
628
+
629
+ sign: str = subFilter[1]
630
+ threshold: float = subFilter[2]
631
+
632
+ if key == "durationLen":
633
+ value = querySegment["end"] - querySegment["start"]
634
+ elif key == "textLen":
635
+ value = len(querySegment["text"])
636
+ else:
637
+ value = querySegment[key]
638
+
639
+ if sign == "=" or sign == "==":
640
+ isFilter = value == threshold
641
+ elif sign == ">":
642
+ isFilter = value > threshold
643
+ elif sign == ">=":
644
+ isFilter = value >= threshold
645
+ elif sign == "<":
646
+ isFilter = value < threshold
647
+ elif sign == "<=":
648
+ isFilter = value <= threshold
649
+ else: isFilter = False
650
+
651
+ if isFilter: break
652
+ if isFilter: break
653
+ if isFilter:
654
+ filterIdx += 1
655
+ filterLog.append(f"filter{filterIdx:03d} [{filterCondition}]:")
656
+ filterLog.append(f"\t{querySegment}\n")
657
+ del querySegmentsResult[currentID]
658
+
659
+ return querySegmentsResult, "\n".join(filterLog)
660
+ except Exception as e:
661
+ print(traceback.format_exc())
662
+ print("Error filter segments: " + str(e))
663
 
664
  def _create_progress_listener(self, progress: gr.Progress):
665
  if (progress is None):
 
970
  gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps, elem_id="word_timestamps"),
971
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words, elem_id="highlight_words"),
972
  }
973
+
974
+ common_segments_filter_inputs = lambda : {
975
+ gr.Checkbox(label="Whisper Segments Filter", value=app_config.whisper_segments_filter, elem_id="whisperSegmentsFilter") if idx == 0 else
976
+ gr.Text(label=f"Filter {idx}", value=filterStr, elem_id=f"whisperSegmentsFilter{idx}") for idx, filterStr in enumerate([""] + app_config.whisper_segments_filters)
977
+ }
978
 
979
  has_diarization_libs = Diarization.has_libraries()
980
 
 
990
  }
991
 
992
  common_output = lambda : [
993
+ gr.File(label="Download", elem_id="outputDownload"),
994
+ gr.Text(label="Transcription", autoscroll=False, show_copy_button=True, interactive=True, elem_id="outputTranscription", elem_classes="scroll-show"),
995
+ gr.Text(label="Segments", autoscroll=False, show_copy_button=True, interactive=True, elem_id="outputSegments", elem_classes="scroll-show"),
996
+ gr.Text(label="Filtered segment items", autoscroll=False, visible=False, show_copy_button=True, interactive=True, elem_id="outputFiltered", elem_classes="scroll-show"),
997
  ]
998
+
999
+ css = """
1000
+ .scroll-show textarea {
1001
+ overflow-y: auto !important;
1002
+ }
1003
+ .scroll-show textarea::-webkit-scrollbar {
1004
+ all: initial !important;
1005
+ background: #f1f1f1 !important;
1006
+ }
1007
+ .scroll-show textarea::-webkit-scrollbar-thumb {
1008
+ all: initial !important;
1009
+ background: #a8a8a8 !important;
1010
+ }
1011
+ """
1012
 
1013
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
1014
 
1015
  simpleInputDict = {}
1016
+
1017
  with gr.Blocks() as simpleTranscribe:
1018
  simpleTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
1019
  simpleSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
 
1055
  simpleInputDict.update(common_vad_inputs())
1056
  with gr.Accordion("Word Timestamps options", open=False):
1057
  simpleInputDict.update(common_word_timestamps_inputs())
1058
+ with gr.Accordion("Whisper Filter options", open=False):
1059
+ simpleInputDict.update(common_segments_filter_inputs())
1060
  with gr.Accordion("Diarization options", open=False):
1061
  simpleInputDict.update(common_diarization_inputs())
1062
  with gr.Accordion("Translation options", open=False):
 
1075
  gr.Markdown(readmeMd)
1076
 
1077
  simpleInputDict.update({simpleTranslateInput, simpleSourceInput})
1078
+ simpleSubmit.click(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
1079
  inputs=simpleInputDict, outputs=simpleOutput)
1080
 
1081
  fullInputDict = {}
 
1150
  gr.Number(label="Repetition Penalty", value=app_config.repetition_penalty, elem_id = "repetition_penalty"),
1151
  gr.Number(label="No Repeat Ngram Size", value=app_config.no_repeat_ngram_size, precision=0, elem_id = "no_repeat_ngram_size")
1152
  })
1153
+ with gr.Accordion("Whisper Segments Filter options", open=False):
1154
+ fullInputDict.update(common_segments_filter_inputs())
1155
  with gr.Accordion("Diarization options", open=False):
1156
  fullInputDict.update(common_diarization_inputs())
1157
  with gr.Accordion("Translation options", open=False):
 
1171
  fullSubmit.click(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
1172
  inputs=fullInputDict, outputs=fullOutput)
1173
 
1174
+ demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe], tab_names=["Simple", "Full"], css=css)
1175
 
1176
  # Queue up the demo
1177
  if is_queue_mode:
config.json5 CHANGED
@@ -317,9 +317,13 @@
317
  "logprob_threshold": -1.0,
318
  // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
319
  "no_speech_threshold": 0.6,
 
 
 
 
320
 
321
  // (experimental) extract word-level timestamps and refine the results based on them
322
- "word_timestamps": false,
323
  // if word_timestamps is True, merge these punctuation symbols with the next word
324
  "prepend_punctuations": "\"\'“¿([{-",
325
  // if word_timestamps is True, merge these punctuation symbols with the previous word
@@ -339,4 +343,20 @@
339
  "diarization_max_speakers": 8,
340
  // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
341
  "diarization_process_timeout": 60,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  }
 
317
  "logprob_threshold": -1.0,
318
  // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
319
  "no_speech_threshold": 0.6,
320
+ // [faster-whisper] The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
321
+ "repetition_penalty": 1.0,
322
+ // [faster-whisper] The model ensures that a sequence of words of no_repeat_ngram_size isn’t repeated in the output sequence. If specified, it must be a positive integer greater than 1.
323
+ "no_repeat_ngram_size": 0,
324
 
325
  // (experimental) extract word-level timestamps and refine the results based on them
326
+ "word_timestamps": true,
327
  // if word_timestamps is True, merge these punctuation symbols with the next word
328
  "prepend_punctuations": "\"\'“¿([{-",
329
  // if word_timestamps is True, merge these punctuation symbols with the previous word
 
343
  "diarization_max_speakers": 8,
344
  // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
345
  "diarization_process_timeout": 60,
346
+
347
+ // Whisper Segments Filter
348
+ "whisper_segments_filter": false,
349
+ "whisper_segments_filters": [
350
+ "avg_logprob < -0.9",
351
+ "(durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.5",
352
+ "(durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.07, compression_ratio < 0.9",
353
+ "(durationLen < 1.5 || segment_last), compression_ratio < 0.9, no_speech_prob > 0.1"
354
+ ],
355
+
356
+ // Translation - The maximum batch size.
357
+ "translation_batch_size": 2,
358
+ // Translation - Prevent repetitions of ngrams with this size (set 0 to disable).
359
+ "translation_no_repeat_ngram_size": 3,
360
+ // Translation - Beam size (1 for greedy search).
361
+ "translation_num_beams": 2,
362
  }
docs/options.md CHANGED
@@ -166,6 +166,24 @@ Penalty applied to the score of previously generated tokens (set > 1 to penalize
166
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
167
  Prevent repetitions of ngrams with this size (set 0 to disable).
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  ## Translation - Batch Size
170
  - transformers: batch_size
171
  When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
 
166
  This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
167
  Prevent repetitions of ngrams with this size (set 0 to disable).
168
 
169
+ ## Whisper Filter options
170
+ **This is an experimental feature and may potentially filter out correct transcription results.**
171
+
172
+ when enabled, can effectively improve the whisper hallucination, especially for the large-v3 version of the whisper model.
173
+
174
+ Observations for transcriptions:
175
+ 1. duration: calculated by subtracting start from end, it might indicate hallucinated results when inversely proportional to text length.
176
+ 1. segment_last: the last result for each segment during VAD transcription has a certain probability of being a hallucinated result.
177
+ 1. avg_logprob: average log probability, ranging from logprob_threshold (default: -1) to 0, is better when a larger value. A value lower than -0.9 might suggest a poor result.
178
+ 1. compression_ratio: gzip compression ratio, ranging from 0 to compression_ratio_threshold (default: 2.4), a higher positive value is preferable. If it is lower than 0.9, it might indicate suboptimal results.
179
+ 1. no_speech_prob: no_speech(<|nospeech|> token) probability, ranging from 0 to no_speech_threshold (default: 0.6), a smaller positive value is preferable. If it exceeds 0.1, it might suggest suboptimal results.
180
+
181
+ Four sets of filtering conditions have now been established, utilizing text length, duration length, as well as the avg_logprob, compression_ratio, and no_speech_prob parameters returned by Whisper.
182
+ 1. avg_logprob < -0.9
183
+ 1. (durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.5
184
+ 1. (durationLen < 1.5 || segment_last), textLen > 5, avg_logprob < -0.4, no_speech_prob > 0.07, compression_ratio < 0.9
185
+ 1. (durationLen < 1.5 || segment_last), compression_ratio < 0.9, no_speech_prob > 0.1
186
+
187
  ## Translation - Batch Size
188
  - transformers: batch_size
189
  When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
src/config.py CHANGED
@@ -54,7 +54,7 @@ class ApplicationConfig:
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
57
- default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
58
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
59
  auto_parallel: bool = False, output_dir: str = None,
60
  model_dir: str = None, device: str = None,
@@ -71,7 +71,7 @@ class ApplicationConfig:
71
  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
72
  repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
73
  # Word timestamp settings
74
- word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
75
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
76
  highlight_words: bool = False,
77
  # Diarization
@@ -82,6 +82,9 @@ class ApplicationConfig:
82
  translation_batch_size: int = 2,
83
  translation_no_repeat_ngram_size: int = 3,
84
  translation_num_beams: int = 2,
 
 
 
85
  ):
86
 
87
  self.models = models
@@ -96,7 +99,6 @@ class ApplicationConfig:
96
 
97
  self.whisper_implementation = whisper_implementation
98
  self.default_model_name = default_model_name
99
- self.default_nllb_model_name = default_nllb_model_name
100
  self.default_vad = default_vad
101
  self.vad_parallel_devices = vad_parallel_devices
102
  self.vad_cpu_cores = vad_cpu_cores
@@ -148,6 +150,9 @@ class ApplicationConfig:
148
  self.translation_batch_size = translation_batch_size
149
  self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
150
  self.translation_num_beams = translation_num_beams
 
 
 
151
 
152
  def get_model_names(self, name: str):
153
  return [ x.name for x in self.models[name] ]
 
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
57
+ default_vad: str = "silero-vad",
58
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
59
  auto_parallel: bool = False, output_dir: str = None,
60
  model_dir: str = None, device: str = None,
 
71
  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
72
  repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
73
  # Word timestamp settings
74
+ word_timestamps: bool = True, prepend_punctuations: str = "\"\'“¿([{-",
75
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
76
  highlight_words: bool = False,
77
  # Diarization
 
82
  translation_batch_size: int = 2,
83
  translation_no_repeat_ngram_size: int = 3,
84
  translation_num_beams: int = 2,
85
+ # Whisper Segments Filter
86
+ whisper_segments_filter: bool = False,
87
+ whisper_segments_filters: List[str] = [],
88
  ):
89
 
90
  self.models = models
 
99
 
100
  self.whisper_implementation = whisper_implementation
101
  self.default_model_name = default_model_name
 
102
  self.default_vad = default_vad
103
  self.vad_parallel_devices = vad_parallel_devices
104
  self.vad_cpu_cores = vad_cpu_cores
 
150
  self.translation_batch_size = translation_batch_size
151
  self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
152
  self.translation_num_beams = translation_num_beams
153
+ # Whisper Segments Filter
154
+ self.whisper_segments_filter = whisper_segments_filter
155
+ self.whisper_segments_filters = whisper_segments_filters
156
 
157
  def get_model_names(self, name: str):
158
  return [ x.name for x in self.models[name] ]
src/vad.py CHANGED
@@ -219,7 +219,11 @@ class AbstractTranscription(ABC):
219
  perf_end_time = time.perf_counter()
220
  print("\tWhisper took {} seconds".format(perf_end_time - perf_start_time))
221
 
222
- adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
 
 
 
 
223
 
224
  # Propagate expand amount to the segments
225
  if (segment_expand_amount > 0):
 
219
  perf_end_time = time.perf_counter()
220
  print("\tWhisper took {} seconds".format(perf_end_time - perf_start_time))
221
 
222
+ adjusted_segments: List[Dict[str, Any]] = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
223
+
224
+ if len(adjusted_segments) > 0:
225
+ adjusted_segments[0]["segment_first"] = True
226
+ adjusted_segments[-1]["segment_last"] = True
227
 
228
  # Propagate expand amount to the segments
229
  if (segment_expand_amount > 0):