jhj0517 committed on
Commit
21bbf6d
·
1 Parent(s): 0da25b6

Update visibility by whisper implementation

Browse files
Files changed (1) hide show
  1. modules/whisper/data_classes.py +111 -101
modules/whisper/data_classes.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from typing import Optional, Dict, List
4
- from pydantic import BaseModel, Field, field_validator
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
7
  from copy import deepcopy
@@ -17,6 +17,8 @@ class WhisperImpl(Enum):
17
 
18
 
19
  class BaseParams(BaseModel):
 
 
20
  def to_dict(self) -> Dict:
21
  return self.model_dump()
22
 
@@ -231,7 +233,6 @@ class WhisperParams(BaseParams):
231
  gt=0,
232
  description="Threshold for gzip compression ratio"
233
  )
234
- batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
235
  length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
236
  repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
237
  no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
@@ -271,6 +272,7 @@ class WhisperParams(BaseParams):
271
  gt=0,
272
  description="Number of segments for language detection"
273
  )
 
274
 
275
  @field_validator('lang')
276
  def validate_lang(cls, v):
@@ -375,108 +377,116 @@ class WhisperParams(BaseParams):
375
  info="Threshold for gzip compression ratio"
376
  )
377
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  if whisper_type == WhisperImpl.FASTER_WHISPER:
379
- inputs += [
380
- gr.Number(
381
- label="Length Penalty",
382
- value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
383
- info="Exponential length penalty",
384
- visible=whisper_type == "faster_whisper"
385
- ),
386
- gr.Number(
387
- label="Repetition Penalty",
388
- value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
389
- info="Penalty for repeated tokens"
390
- ),
391
- gr.Number(
392
- label="No Repeat N-gram Size",
393
- value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
394
- precision=0,
395
- info="Size of n-grams to prevent repetition"
396
- ),
397
- gr.Textbox(
398
- label="Prefix",
399
- value=defaults.get("prefix", cls.__fields__["prefix"].default),
400
- info="Prefix text for first window"
401
- ),
402
- gr.Checkbox(
403
- label="Suppress Blank",
404
- value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
405
- info="Suppress blank outputs at start of sampling"
406
- ),
407
- gr.Textbox(
408
- label="Suppress Tokens",
409
- value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default),
410
- info="Token IDs to suppress"
411
- ),
412
- gr.Number(
413
- label="Max Initial Timestamp",
414
- value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
415
- info="Maximum initial timestamp"
416
- ),
417
- gr.Checkbox(
418
- label="Word Timestamps",
419
- value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
420
- info="Extract word-level timestamps"
421
- ),
422
- gr.Textbox(
423
- label="Prepend Punctuations",
424
- value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
425
- info="Punctuations to merge with next word"
426
- ),
427
- gr.Textbox(
428
- label="Append Punctuations",
429
- value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
430
- info="Punctuations to merge with previous word"
431
- ),
432
- gr.Number(
433
- label="Max New Tokens",
434
- value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default),
435
- precision=0,
436
- info="Maximum number of new tokens per chunk"
437
- ),
438
- gr.Number(
439
- label="Chunk Length (s)",
440
- value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
441
- precision=0,
442
- info="Length of audio segments in seconds"
443
- ),
444
- gr.Number(
445
- label="Hallucination Silence Threshold (sec)",
446
- value=defaults.get("hallucination_silence_threshold",
447
- cls.__fields__["hallucination_silence_threshold"].default),
448
- info="Threshold for skipping silent periods in hallucination detection"
449
- ),
450
- gr.Textbox(
451
- label="Hotwords",
452
- value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
453
- info="Hotwords/hint phrases for the model"
454
- ),
455
- gr.Number(
456
- label="Language Detection Threshold",
457
- value=defaults.get("language_detection_threshold",
458
- cls.__fields__["language_detection_threshold"].default),
459
- info="Threshold for language detection probability"
460
- ),
461
- gr.Number(
462
- label="Language Detection Segments",
463
- value=defaults.get("language_detection_segments",
464
- cls.__fields__["language_detection_segments"].default),
465
- precision=0,
466
- info="Number of segments for language detection"
467
- )
468
- ]
469
 
470
  if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER:
471
- inputs += [
472
- gr.Number(
473
- label="Batch Size",
474
- value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
475
- precision=0,
476
- info="Batch size for processing",
477
- visible=whisper_type == "insanely_fast_whisper"
478
- )
479
- ]
480
  return inputs
481
 
482
 
 
1
  import gradio as gr
2
  import torch
3
  from typing import Optional, Dict, List
4
+ from pydantic import BaseModel, Field, field_validator, ConfigDict
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
7
  from copy import deepcopy
 
17
 
18
 
19
  class BaseParams(BaseModel):
20
+ model_config = ConfigDict(protected_namespaces=())
21
+
22
  def to_dict(self) -> Dict:
23
  return self.model_dump()
24
 
 
233
  gt=0,
234
  description="Threshold for gzip compression ratio"
235
  )
 
236
  length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
237
  repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
238
  no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
 
272
  gt=0,
273
  description="Number of segments for language detection"
274
  )
275
+ batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
276
 
277
  @field_validator('lang')
278
  def validate_lang(cls, v):
 
377
  info="Threshold for gzip compression ratio"
378
  )
379
  ]
380
+
381
+ faster_whisper_inputs = [
382
+ gr.Number(
383
+ label="Length Penalty",
384
+ value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
385
+ info="Exponential length penalty",
386
+ ),
387
+ gr.Number(
388
+ label="Repetition Penalty",
389
+ value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
390
+ info="Penalty for repeated tokens"
391
+ ),
392
+ gr.Number(
393
+ label="No Repeat N-gram Size",
394
+ value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
395
+ precision=0,
396
+ info="Size of n-grams to prevent repetition"
397
+ ),
398
+ gr.Textbox(
399
+ label="Prefix",
400
+ value=defaults.get("prefix", cls.__fields__["prefix"].default),
401
+ info="Prefix text for first window"
402
+ ),
403
+ gr.Checkbox(
404
+ label="Suppress Blank",
405
+ value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
406
+ info="Suppress blank outputs at start of sampling"
407
+ ),
408
+ gr.Textbox(
409
+ label="Suppress Tokens",
410
+ value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default),
411
+ info="Token IDs to suppress"
412
+ ),
413
+ gr.Number(
414
+ label="Max Initial Timestamp",
415
+ value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
416
+ info="Maximum initial timestamp"
417
+ ),
418
+ gr.Checkbox(
419
+ label="Word Timestamps",
420
+ value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
421
+ info="Extract word-level timestamps"
422
+ ),
423
+ gr.Textbox(
424
+ label="Prepend Punctuations",
425
+ value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
426
+ info="Punctuations to merge with next word"
427
+ ),
428
+ gr.Textbox(
429
+ label="Append Punctuations",
430
+ value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
431
+ info="Punctuations to merge with previous word"
432
+ ),
433
+ gr.Number(
434
+ label="Max New Tokens",
435
+ value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default),
436
+ precision=0,
437
+ info="Maximum number of new tokens per chunk"
438
+ ),
439
+ gr.Number(
440
+ label="Chunk Length (s)",
441
+ value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
442
+ precision=0,
443
+ info="Length of audio segments in seconds"
444
+ ),
445
+ gr.Number(
446
+ label="Hallucination Silence Threshold (sec)",
447
+ value=defaults.get("hallucination_silence_threshold",
448
+ cls.__fields__["hallucination_silence_threshold"].default),
449
+ info="Threshold for skipping silent periods in hallucination detection"
450
+ ),
451
+ gr.Textbox(
452
+ label="Hotwords",
453
+ value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
454
+ info="Hotwords/hint phrases for the model"
455
+ ),
456
+ gr.Number(
457
+ label="Language Detection Threshold",
458
+ value=defaults.get("language_detection_threshold",
459
+ cls.__fields__["language_detection_threshold"].default),
460
+ info="Threshold for language detection probability"
461
+ ),
462
+ gr.Number(
463
+ label="Language Detection Segments",
464
+ value=defaults.get("language_detection_segments",
465
+ cls.__fields__["language_detection_segments"].default),
466
+ precision=0,
467
+ info="Number of segments for language detection"
468
+ )
469
+ ]
470
+
471
+ insanely_fast_whisper_inputs = [
472
+ gr.Number(
473
+ label="Batch Size",
474
+ value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
475
+ precision=0,
476
+ info="Batch size for processing"
477
+ )
478
+ ]
479
+
480
  if whisper_type == WhisperImpl.FASTER_WHISPER:
481
+ for input_component in faster_whisper_inputs:
482
+ input_component.visible = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
  if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER:
485
+ for input_component in insanely_fast_whisper_inputs:
486
+ input_component.visible = True
487
+
488
+ inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
489
+
 
 
 
 
490
  return inputs
491
 
492