jhj0517 committed on
Commit
393a9c3
·
1 Parent(s): 501c404

Fix class method attribute access

Browse files
Files changed (1) hide show
  1. modules/whisper/data_classes.py +78 -60
modules/whisper/data_classes.py CHANGED
@@ -62,25 +62,37 @@ class VadParams(BaseParams):
62
  @classmethod
63
  def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
64
  return [
65
- gr.Checkbox(label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.vad_filter),
66
- interactive=True,
67
- info=_("Enable this to transcribe only detected voice")),
68
- gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
69
- value=defaults.get("threshold", cls.threshold),
70
- info="Lower it to be more sensitive to small sounds."),
71
- gr.Number(label="Minimum Speech Duration (ms)", precision=0,
72
- value=defaults.get("min_speech_duration_ms", cls.min_speech_duration_ms),
73
- info="Final speech chunks shorter than this time are thrown out"),
74
- gr.Number(label="Maximum Speech Duration (s)",
75
- value=defaults.get("max_speech_duration_s", cls.max_speech_duration_s),
76
- info="Maximum duration of speech chunks in \"seconds\"."),
77
- gr.Number(label="Minimum Silence Duration (ms)", precision=0,
78
- value=defaults.get("min_silence_duration_ms", cls.min_silence_duration_ms),
79
- info="In the end of each speech chunk wait for this time"
80
- " before separating it"),
81
- gr.Number(label="Speech Padding (ms)", precision=0,
82
- value=defaults.get("speech_pad_ms", cls.speech_pad_ms),
83
- info="Final speech chunks are padded by this time each side")
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
 
86
 
@@ -100,18 +112,18 @@ class DiarizationParams(BaseParams):
100
  return [
101
  gr.Checkbox(
102
  label=_("Enable Diarization"),
103
- value=defaults.get("is_diarize", cls.is_diarize),
104
  info=_("Enable speaker diarization")
105
  ),
106
  gr.Textbox(
107
  label=_("HuggingFace Token"),
108
- value=defaults.get("hf_token", cls.hf_token),
109
  info=_("This is only needed the first time you download the model")
110
  ),
111
  gr.Dropdown(
112
  label=_("Device"),
113
  choices=["cpu", "cuda"] if available_devices is None else available_devices,
114
- value=defaults.get("device", cls.device),
115
  info=_("Device to run diarization model")
116
  )
117
  ]
@@ -147,36 +159,37 @@ class BGMSeparationParams(BaseParams):
147
  return [
148
  gr.Checkbox(
149
  label=_("Enable Background Music Remover Filter"),
150
- value=defaults.get("is_separate_bgm", cls.is_separate_bgm),
151
  interactive=True,
152
  info=_("Enabling this will remove background music")
153
  ),
154
  gr.Dropdown(
155
  label=_("Device"),
156
  choices=["cpu", "cuda"] if available_devices is None else available_devices,
157
- value=defaults.get("device", cls.device),
158
  info=_("Device to run UVR model")
159
  ),
160
  gr.Dropdown(
161
  label=_("Model"),
162
- choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
163
- value=defaults.get("model_size", cls.model_size),
 
164
  info=_("UVR model size")
165
  ),
166
  gr.Number(
167
  label="Segment Size",
168
- value=defaults.get("segment_size", cls.segment_size),
169
  precision=0,
170
  info="Segment size for UVR model"
171
  ),
172
  gr.Checkbox(
173
  label=_("Save separated files to output"),
174
- value=defaults.get("save_file", cls.save_file),
175
  info=_("Whether to save separated audio files")
176
  ),
177
  gr.Checkbox(
178
  label=_("Offload sub model after removing background music"),
179
- value=defaults.get("enable_offload", cls.enable_offload),
180
  info=_("Offload UVR model after transcription")
181
  )
182
  ]
@@ -283,17 +296,17 @@ class WhisperParams(BaseParams):
283
  gr.Dropdown(
284
  label="Model Size",
285
  choices=["small", "medium", "large-v2"],
286
- value=defaults.get("model_size", cls.model_size),
287
  info="Whisper model size"
288
  ),
289
  gr.Textbox(
290
  label="Language",
291
- value=defaults.get("lang", cls.lang),
292
  info="Source language of the file to transcribe"
293
  ),
294
  gr.Checkbox(
295
  label="Translate to English",
296
- value=defaults.get("is_translate", cls.is_translate),
297
  info="Translate speech to English end-to-end"
298
  ),
299
  ]
@@ -301,18 +314,18 @@ class WhisperParams(BaseParams):
301
  inputs += [
302
  gr.Number(
303
  label="Beam Size",
304
- value=defaults.get("beam_size", cls.beam_size),
305
  precision=0,
306
  info="Beam size for decoding"
307
  ),
308
  gr.Number(
309
  label="Log Probability Threshold",
310
- value=defaults.get("log_prob_threshold", cls.log_prob_threshold),
311
  info="Threshold for average log probability of sampled tokens"
312
  ),
313
  gr.Number(
314
  label="No Speech Threshold",
315
- value=defaults.get("no_speech_threshold", cls.no_speech_threshold),
316
  info="Threshold for detecting silence"
317
  ),
318
  gr.Dropdown(
@@ -323,23 +336,24 @@ class WhisperParams(BaseParams):
323
  ),
324
  gr.Number(
325
  label="Best Of",
326
- value=defaults.get("best_of", cls.best_of),
327
  precision=0,
328
  info="Number of candidates when sampling"
329
  ),
330
  gr.Number(
331
  label="Patience",
332
- value=defaults.get("patience", cls.patience),
333
  info="Beam search patience factor"
334
  ),
335
  gr.Checkbox(
336
  label="Condition On Previous Text",
337
- value=defaults.get("condition_on_previous_text", cls.condition_on_previous_text),
338
  info="Use previous output as prompt for next window"
339
  ),
340
  gr.Slider(
341
  label="Prompt Reset On Temperature",
342
- value=defaults.get("prompt_reset_on_temperature", cls.prompt_reset_on_temperature),
 
343
  minimum=0,
344
  maximum=1,
345
  step=0.01,
@@ -347,12 +361,12 @@ class WhisperParams(BaseParams):
347
  ),
348
  gr.Textbox(
349
  label="Initial Prompt",
350
- value=defaults.get("initial_prompt", cls.initial_prompt),
351
  info="Initial prompt for first window"
352
  ),
353
  gr.Slider(
354
  label="Temperature",
355
- value=defaults.get("temperature", cls.temperature),
356
  minimum=0.0,
357
  step=0.01,
358
  maximum=1.0,
@@ -360,7 +374,8 @@ class WhisperParams(BaseParams):
360
  ),
361
  gr.Number(
362
  label="Compression Ratio Threshold",
363
- value=defaults.get("compression_ratio_threshold", cls.compression_ratio_threshold),
 
364
  info="Threshold for gzip compression ratio"
365
  )
366
  ]
@@ -368,86 +383,89 @@ class WhisperParams(BaseParams):
368
  inputs += [
369
  gr.Number(
370
  label="Length Penalty",
371
- value=defaults.get("length_penalty", cls.length_penalty),
372
  info="Exponential length penalty",
373
- visible=whisper_type=="faster_whisper"
374
  ),
375
  gr.Number(
376
  label="Repetition Penalty",
377
- value=defaults.get("repetition_penalty", cls.repetition_penalty),
378
  info="Penalty for repeated tokens"
379
  ),
380
  gr.Number(
381
  label="No Repeat N-gram Size",
382
- value=defaults.get("no_repeat_ngram_size", cls.no_repeat_ngram_size),
383
  precision=0,
384
  info="Size of n-grams to prevent repetition"
385
  ),
386
  gr.Textbox(
387
  label="Prefix",
388
- value=defaults.get("prefix", cls.prefix),
389
  info="Prefix text for first window"
390
  ),
391
  gr.Checkbox(
392
  label="Suppress Blank",
393
- value=defaults.get("suppress_blank", cls.suppress_blank),
394
  info="Suppress blank outputs at start of sampling"
395
  ),
396
  gr.Textbox(
397
  label="Suppress Tokens",
398
- value=defaults.get("suppress_tokens", cls.suppress_tokens),
399
  info="Token IDs to suppress"
400
  ),
401
  gr.Number(
402
  label="Max Initial Timestamp",
403
- value=defaults.get("max_initial_timestamp", cls.max_initial_timestamp),
404
  info="Maximum initial timestamp"
405
  ),
406
  gr.Checkbox(
407
  label="Word Timestamps",
408
- value=defaults.get("word_timestamps", cls.word_timestamps),
409
  info="Extract word-level timestamps"
410
  ),
411
  gr.Textbox(
412
  label="Prepend Punctuations",
413
- value=defaults.get("prepend_punctuations", cls.prepend_punctuations),
414
  info="Punctuations to merge with next word"
415
  ),
416
  gr.Textbox(
417
  label="Append Punctuations",
418
- value=defaults.get("append_punctuations", cls.append_punctuations),
419
  info="Punctuations to merge with previous word"
420
  ),
421
  gr.Number(
422
  label="Max New Tokens",
423
- value=defaults.get("max_new_tokens", cls.max_new_tokens),
424
  precision=0,
425
  info="Maximum number of new tokens per chunk"
426
  ),
427
  gr.Number(
428
  label="Chunk Length (s)",
429
- value=defaults.get("chunk_length", cls.chunk_length),
430
  precision=0,
431
  info="Length of audio segments in seconds"
432
  ),
433
  gr.Number(
434
  label="Hallucination Silence Threshold (sec)",
435
- value=defaults.get("hallucination_silence_threshold", cls.hallucination_silence_threshold),
 
436
  info="Threshold for skipping silent periods in hallucination detection"
437
  ),
438
  gr.Textbox(
439
  label="Hotwords",
440
- value=defaults.get("hotwords", cls.hotwords),
441
  info="Hotwords/hint phrases for the model"
442
  ),
443
  gr.Number(
444
  label="Language Detection Threshold",
445
- value=defaults.get("language_detection_threshold", cls.language_detection_threshold),
 
446
  info="Threshold for language detection probability"
447
  ),
448
  gr.Number(
449
  label="Language Detection Segments",
450
- value=defaults.get("language_detection_segments", cls.language_detection_segments),
 
451
  precision=0,
452
  info="Number of segments for language detection"
453
  )
@@ -457,7 +475,7 @@ class WhisperParams(BaseParams):
457
  inputs += [
458
  gr.Number(
459
  label="Batch Size",
460
- value=defaults.get("batch_size", cls.batch_size),
461
  precision=0,
462
  info="Batch size for processing",
463
  visible=whisper_type == "insanely_fast_whisper"
 
62
  @classmethod
63
  def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
64
  return [
65
+ gr.Checkbox(
66
+ label=_("Enable Silero VAD Filter"),
67
+ value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
68
+ interactive=True,
69
+ info=_("Enable this to transcribe only detected voice")
70
+ ),
71
+ gr.Slider(
72
+ minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
73
+ value=defaults.get("threshold", cls.__fields__["threshold"].default),
74
+ info="Lower it to be more sensitive to small sounds."
75
+ ),
76
+ gr.Number(
77
+ label="Minimum Speech Duration (ms)", precision=0,
78
+ value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
79
+ info="Final speech chunks shorter than this time are thrown out"
80
+ ),
81
+ gr.Number(
82
+ label="Maximum Speech Duration (s)",
83
+ value=defaults.get("max_speech_duration_s", cls.__fields__["max_speech_duration_s"].default),
84
+ info="Maximum duration of speech chunks in \"seconds\"."
85
+ ),
86
+ gr.Number(
87
+ label="Minimum Silence Duration (ms)", precision=0,
88
+ value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
89
+ info="In the end of each speech chunk wait for this time before separating it"
90
+ ),
91
+ gr.Number(
92
+ label="Speech Padding (ms)", precision=0,
93
+ value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
94
+ info="Final speech chunks are padded by this time each side"
95
+ )
96
  ]
97
 
98
 
 
112
  return [
113
  gr.Checkbox(
114
  label=_("Enable Diarization"),
115
+ value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
116
  info=_("Enable speaker diarization")
117
  ),
118
  gr.Textbox(
119
  label=_("HuggingFace Token"),
120
+ value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
121
  info=_("This is only needed the first time you download the model")
122
  ),
123
  gr.Dropdown(
124
  label=_("Device"),
125
  choices=["cpu", "cuda"] if available_devices is None else available_devices,
126
+ value=defaults.get("device", cls.__fields__["device"].default),
127
  info=_("Device to run diarization model")
128
  )
129
  ]
 
159
  return [
160
  gr.Checkbox(
161
  label=_("Enable Background Music Remover Filter"),
162
+ value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
163
  interactive=True,
164
  info=_("Enabling this will remove background music")
165
  ),
166
  gr.Dropdown(
167
  label=_("Device"),
168
  choices=["cpu", "cuda"] if available_devices is None else available_devices,
169
+ value=defaults.get("device", cls.__fields__["device"].default),
170
  info=_("Device to run UVR model")
171
  ),
172
  gr.Dropdown(
173
  label=_("Model"),
174
+ choices=["UVR-MDX-NET-Inst_HQ_4",
175
+ "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
176
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
177
  info=_("UVR model size")
178
  ),
179
  gr.Number(
180
  label="Segment Size",
181
+ value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
182
  precision=0,
183
  info="Segment size for UVR model"
184
  ),
185
  gr.Checkbox(
186
  label=_("Save separated files to output"),
187
+ value=defaults.get("save_file", cls.__fields__["save_file"].default),
188
  info=_("Whether to save separated audio files")
189
  ),
190
  gr.Checkbox(
191
  label=_("Offload sub model after removing background music"),
192
+ value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
193
  info=_("Offload UVR model after transcription")
194
  )
195
  ]
 
296
  gr.Dropdown(
297
  label="Model Size",
298
  choices=["small", "medium", "large-v2"],
299
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
300
  info="Whisper model size"
301
  ),
302
  gr.Textbox(
303
  label="Language",
304
+ value=defaults.get("lang", cls.__fields__["lang"].default),
305
  info="Source language of the file to transcribe"
306
  ),
307
  gr.Checkbox(
308
  label="Translate to English",
309
+ value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
310
  info="Translate speech to English end-to-end"
311
  ),
312
  ]
 
314
  inputs += [
315
  gr.Number(
316
  label="Beam Size",
317
+ value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
318
  precision=0,
319
  info="Beam size for decoding"
320
  ),
321
  gr.Number(
322
  label="Log Probability Threshold",
323
+ value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
324
  info="Threshold for average log probability of sampled tokens"
325
  ),
326
  gr.Number(
327
  label="No Speech Threshold",
328
+ value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
329
  info="Threshold for detecting silence"
330
  ),
331
  gr.Dropdown(
 
336
  ),
337
  gr.Number(
338
  label="Best Of",
339
+ value=defaults.get("best_of", cls.__fields__["best_of"].default),
340
  precision=0,
341
  info="Number of candidates when sampling"
342
  ),
343
  gr.Number(
344
  label="Patience",
345
+ value=defaults.get("patience", cls.__fields__["patience"].default),
346
  info="Beam search patience factor"
347
  ),
348
  gr.Checkbox(
349
  label="Condition On Previous Text",
350
+ value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
351
  info="Use previous output as prompt for next window"
352
  ),
353
  gr.Slider(
354
  label="Prompt Reset On Temperature",
355
+ value=defaults.get("prompt_reset_on_temperature",
356
+ cls.__fields__["prompt_reset_on_temperature"].default),
357
  minimum=0,
358
  maximum=1,
359
  step=0.01,
 
361
  ),
362
  gr.Textbox(
363
  label="Initial Prompt",
364
+ value=defaults.get("initial_prompt", cls.__fields__["initial_prompt"].default),
365
  info="Initial prompt for first window"
366
  ),
367
  gr.Slider(
368
  label="Temperature",
369
+ value=defaults.get("temperature", cls.__fields__["temperature"].default),
370
  minimum=0.0,
371
  step=0.01,
372
  maximum=1.0,
 
374
  ),
375
  gr.Number(
376
  label="Compression Ratio Threshold",
377
+ value=defaults.get("compression_ratio_threshold",
378
+ cls.__fields__["compression_ratio_threshold"].default),
379
  info="Threshold for gzip compression ratio"
380
  )
381
  ]
 
383
  inputs += [
384
  gr.Number(
385
  label="Length Penalty",
386
+ value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
387
  info="Exponential length penalty",
388
+ visible=whisper_type == "faster_whisper"
389
  ),
390
  gr.Number(
391
  label="Repetition Penalty",
392
+ value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
393
  info="Penalty for repeated tokens"
394
  ),
395
  gr.Number(
396
  label="No Repeat N-gram Size",
397
+ value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
398
  precision=0,
399
  info="Size of n-grams to prevent repetition"
400
  ),
401
  gr.Textbox(
402
  label="Prefix",
403
+ value=defaults.get("prefix", cls.__fields__["prefix"].default),
404
  info="Prefix text for first window"
405
  ),
406
  gr.Checkbox(
407
  label="Suppress Blank",
408
+ value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
409
  info="Suppress blank outputs at start of sampling"
410
  ),
411
  gr.Textbox(
412
  label="Suppress Tokens",
413
+ value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default),
414
  info="Token IDs to suppress"
415
  ),
416
  gr.Number(
417
  label="Max Initial Timestamp",
418
+ value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
419
  info="Maximum initial timestamp"
420
  ),
421
  gr.Checkbox(
422
  label="Word Timestamps",
423
+ value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
424
  info="Extract word-level timestamps"
425
  ),
426
  gr.Textbox(
427
  label="Prepend Punctuations",
428
+ value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
429
  info="Punctuations to merge with next word"
430
  ),
431
  gr.Textbox(
432
  label="Append Punctuations",
433
+ value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
434
  info="Punctuations to merge with previous word"
435
  ),
436
  gr.Number(
437
  label="Max New Tokens",
438
+ value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default),
439
  precision=0,
440
  info="Maximum number of new tokens per chunk"
441
  ),
442
  gr.Number(
443
  label="Chunk Length (s)",
444
+ value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
445
  precision=0,
446
  info="Length of audio segments in seconds"
447
  ),
448
  gr.Number(
449
  label="Hallucination Silence Threshold (sec)",
450
+ value=defaults.get("hallucination_silence_threshold",
451
+ cls.__fields__["hallucination_silence_threshold"].default),
452
  info="Threshold for skipping silent periods in hallucination detection"
453
  ),
454
  gr.Textbox(
455
  label="Hotwords",
456
+ value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
457
  info="Hotwords/hint phrases for the model"
458
  ),
459
  gr.Number(
460
  label="Language Detection Threshold",
461
+ value=defaults.get("language_detection_threshold",
462
+ cls.__fields__["language_detection_threshold"].default),
463
  info="Threshold for language detection probability"
464
  ),
465
  gr.Number(
466
  label="Language Detection Segments",
467
+ value=defaults.get("language_detection_segments",
468
+ cls.__fields__["language_detection_segments"].default),
469
  precision=0,
470
  info="Number of segments for language detection"
471
  )
 
475
  inputs += [
476
  gr.Number(
477
  label="Batch Size",
478
+ value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
479
  precision=0,
480
  info="Batch size for processing",
481
  visible=whisper_type == "insanely_fast_whisper"