jhj0517 committed
Commit f3ecc7a (unverified)
2 Parents: 2d3dbbe 4fe08f3

Merge pull request #306 from jhj0517/feature/add-tests

.github/workflows/{shell-scrpit-test.yml → ci-shell.yml} RENAMED
@@ -1,38 +1,42 @@
-name: Shell Script Test
+name: CI-Shell Script
 
 on:
+  workflow_dispatch:
+
   push:
-    branches: ["feature/shell-script"]
-
-env:
-  PYTHON_VERSION: '3.9'
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
 
 jobs:
   test-shell-script:
+
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: [ "3.10" ]
+
     steps:
-      - name: 'Checkout GitHub Action'
-        uses: actions/checkout@v3
+      - name: Clean up space for action
+        run: rm -rf /opt/hostedtoolcache
 
-      - name: Setup Python ${{ env.PYTHON_VERSION }} Environment
-        uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - name: Setup Python
+        uses: actions/setup-python@v5
         with:
-          python-version: ${{ env.PYTHON_VERSION }}
+          python-version: ${{ matrix.python }}
 
-      - name: 'Setup FFmpeg'
-        uses: FedericoCarboni/setup-ffmpeg@v3
-        id: setup-ffmpeg
-        with:
-          ffmpeg-version: release
-          architecture: 'arm64'
-          linking-type: static
+      - name: Install git and ffmpeg
+        run: sudo apt-get update && sudo apt-get install -y git ffmpeg
 
-      - name: 'Execute Install.sh'
+      - name: Execute Install.sh
        run: |
          chmod +x ./Install.sh
          ./Install.sh
 
-      - name: 'Execute start-webui.sh'
+      - name: Execute start-webui.sh
        run: |
          chmod +x ./start-webui.sh
          timeout 60s ./start-webui.sh || true
.github/workflows/ci.yml ADDED
@@ -0,0 +1,41 @@
+name: CI
+
+on:
+  workflow_dispatch:
+
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python: ["3.10"]
+
+    env:
+      DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
+
+    steps:
+      - name: Clean up space for action
+        run: rm -rf /opt/hostedtoolcache
+
+      - uses: actions/checkout@v4
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Install git and ffmpeg
+        run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+      - name: Install dependencies
+        run: pip install -r requirements.txt pytest
+
+      - name: Run test
+        run: python -m pytest -rs tests
modules/translation/deepl_api.py CHANGED
@@ -98,8 +98,8 @@ class DeepLAPI:
                         fileobjs: list,
                         source_lang: str,
                         target_lang: str,
-                        is_pro: bool,
-                        add_timestamp: bool,
+                        is_pro: bool = False,
+                        add_timestamp: bool = True,
                         progress=gr.Progress()) -> list:
         """
         Translate subtitle files using DeepL API
@@ -126,6 +126,9 @@ class DeepLAPI:
            String to return to gr.Textbox()
            Files to return to gr.Files()
        """
+        if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+            fileobjs = [fileobj.name for fileobj in fileobjs]
+
        self.cache_parameters(
            api_key=auth_key,
            is_pro=is_pro,
@@ -136,37 +139,28 @@ class DeepLAPI:
 
        files_info = {}
        for fileobj in fileobjs:
-           file_path = fileobj.name
-           file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+           file_path = fileobj
+           file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
 
            if file_ext == ".srt":
                parsed_dicts = parse_srt(file_path=file_path)
 
-               batch_size = self.max_text_batch_size
-               for batch_start in range(0, len(parsed_dicts), batch_size):
-                   batch_end = min(batch_start + batch_size, len(parsed_dicts))
-                   sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
-                   translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
-                                                                   target_lang, is_pro)
-                   for i, translated_text in enumerate(translated_texts):
-                       parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                   progress(batch_end / len(parsed_dicts), desc="Translating..")
-
-               subtitle = get_serialized_srt(parsed_dicts)
-
            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)
 
-               batch_size = self.max_text_batch_size
-               for batch_start in range(0, len(parsed_dicts), batch_size):
-                   batch_end = min(batch_start + batch_size, len(parsed_dicts))
-                   sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
-                   translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
-                                                                   target_lang, is_pro)
-                   for i, translated_text in enumerate(translated_texts):
-                       parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-                   progress(batch_end / len(parsed_dicts), desc="Translating..")
+           batch_size = self.max_text_batch_size
+           for batch_start in range(0, len(parsed_dicts), batch_size):
+               batch_end = min(batch_start + batch_size, len(parsed_dicts))
+               sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+               translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
+                                                               target_lang, is_pro)
+               for i, translated_text in enumerate(translated_texts):
+                   parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
+               progress(batch_end / len(parsed_dicts), desc="Translating..")
 
+           if file_ext == ".srt":
+               subtitle = get_serialized_srt(parsed_dicts)
+           elif file_ext == ".vtt":
                subtitle = get_serialized_vtt(parsed_dicts)
 
            if add_timestamp:
@@ -193,8 +187,14 @@ class DeepLAPI:
                                text: list,
                                source_lang: str,
                                target_lang: str,
-                               is_pro: bool):
+                               is_pro: bool = False):
        """Request API response to DeepL server"""
+       if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()):
+           raise ValueError(f"Source language {source_lang} is not supported."
+                            f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}")
+       if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()):
+           raise ValueError(f"Target language {target_lang} is not supported."
+                            f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}")
 
        url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
        headers = {
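
Note: the consolidated loop above parses whichever format matched and then runs a single shared batching pass, instead of duplicating the translate-and-write-back logic per extension. A minimal standalone sketch of that batching pattern follows; translate_in_batches and the translate callable are illustrative names, not part of the module:

    def translate_in_batches(parsed_dicts, translate, batch_size):
        # Walk the parsed subtitle dicts in fixed-size windows, send each
        # window as one request, and write the translations back in place.
        for batch_start in range(0, len(parsed_dicts), batch_size):
            batch_end = min(batch_start + batch_size, len(parsed_dicts))
            sentences = [d["sentence"] for d in parsed_dicts[batch_start:batch_end]]
            results = translate(sentences)  # expected shape: [{"text": ...}, ...]
            for i, result in enumerate(results):
                parsed_dicts[batch_start + i]["sentence"] = result["text"]
        return parsed_dicts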
modules/translation/nllb_inference.py CHANGED
@@ -37,6 +37,17 @@ class NLLBInference(TranslationBase):
                  tgt_lang: str,
                  progress: gr.Progress = gr.Progress()
                  ):
+        def validate_language(lang: str) -> str:
+            if lang in NLLB_AVAILABLE_LANGS:
+                return NLLB_AVAILABLE_LANGS[lang]
+            elif lang not in NLLB_AVAILABLE_LANGS.values():
+                raise ValueError(
+                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+            return lang
+
+        src_lang = validate_language(src_lang)
+        tgt_lang = validate_language(tgt_lang)
+
         if model_size != self.current_model_size or self.model is None:
             print("\nInitializing NLLB Model..\n")
             progress(0, desc="Initializing NLLB Model..")
@@ -48,8 +59,7 @@ class NLLBInference(TranslationBase):
             self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                            cache_dir=os.path.join(self.model_dir, "tokenizers"),
                                                            local_files_only=local_files_only)
-        src_lang = NLLB_AVAILABLE_LANGS[src_lang]
-        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
+
         self.pipeline = pipeline("translation",
                                  model=self.model,
                                  tokenizer=self.tokenizer,
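
Note: validate_language accepts either a display name (a key of NLLB_AVAILABLE_LANGS) or an already-resolved NLLB code (a value of the mapping), which is what lets test_translation.py pass "eng_Latn" and "kor_Hang" directly. A self-contained sketch with a two-entry stand-in for the real mapping:

    # Stand-in for the project's NLLB_AVAILABLE_LANGS (display name -> code).
    NLLB_AVAILABLE_LANGS = {"English": "eng_Latn", "Korean": "kor_Hang"}

    def validate_language(lang: str) -> str:
        if lang in NLLB_AVAILABLE_LANGS:                 # display name given
            return NLLB_AVAILABLE_LANGS[lang]
        elif lang not in NLLB_AVAILABLE_LANGS.values():  # neither name nor code
            raise ValueError(f"Language '{lang}' is not supported. "
                             f"Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
        return lang                                      # already a valid NLLB code

    assert validate_language("English") == "eng_Latn"
    assert validate_language("kor_Hang") == "kor_Hang"   # codes pass through unchanged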
modules/translation/translation_base.py CHANGED
@@ -46,8 +46,8 @@ class TranslationBase(ABC):
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
-                      max_length: int,
-                      add_timestamp: bool,
+                      max_length: int = 200,
+                      add_timestamp: bool = True,
                       progress=gr.Progress()) -> list:
        """
        Translate subtitle file from source language to target language
@@ -77,6 +77,9 @@ class TranslationBase(ABC):
            Files to return to gr.Files()
        """
        try:
+           if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+               fileobjs = [file.name for file in fileobjs]
+
            self.cache_parameters(model_size=model_size,
                                  src_lang=src_lang,
                                  tgt_lang=tgt_lang,
@@ -90,10 +93,9 @@ class TranslationBase(ABC):
 
            files_info = {}
            for fileobj in fileobjs:
-               file_path = fileobj.name
-               file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+               file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
                if file_ext == ".srt":
-                   parsed_dicts = parse_srt(file_path=file_path)
+                   parsed_dicts = parse_srt(file_path=fileobj)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
@@ -102,7 +104,7 @@ class TranslationBase(ABC):
                    subtitle = get_serialized_srt(parsed_dicts)
 
                elif file_ext == ".vtt":
-                   parsed_dicts = parse_vtt(file_path=file_path)
+                   parsed_dicts = parse_vtt(file_path=fileobj)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
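
Note: the gr.utils.NamedString check is what lets the same entry points serve both the Gradio UI (which passes file objects whose .name is a path on disk) and the tests (which pass plain path strings). Factored out as a hypothetical helper, the normalization is roughly:

    import gradio as gr

    def normalize_file_paths(fileobjs: list) -> list:
        # Gradio file components hand back NamedString objects carrying the
        # temp-file path in .name; plain str paths pass through untouched.
        if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
            return [f.name for f in fileobjs]
        return fileobjs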
modules/utils/subtitle_manager.py CHANGED
@@ -119,11 +119,8 @@ def get_serialized_vtt(dicts):
 
 
 def safe_filename(name):
-    from app import _args
     INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
     safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
-    if not _args.colab:
-        return safe_name
     # Truncate the filename if it exceeds the max_length (20)
     if len(safe_name) > 20:
         file_extension = safe_name.split('.')[-1]
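
Note: with the `from app import _args` import gone, safe_filename no longer depends on CLI state and always sanitizes and truncates. A sketch of the resulting behavior; the truncation tail continues past the lines shown in this hunk, so the extension-preserving part below is a plausible completion rather than the module's verbatim code:

    import re

    INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'

    def safe_filename_sketch(name: str, max_length: int = 20) -> str:
        # Replace characters that are invalid in filenames with underscores.
        safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
        # Truncate long names while keeping the extension (assumed completion).
        if len(safe_name) > max_length:
            extension = safe_name.split('.')[-1]
            stem = safe_name[: max_length - len(extension) - 1]
            safe_name = f"{stem}.{extension}"
        return safe_name

    print(safe_filename_sketch('a<very>long:subtitle|name.srt'))  # a_very_long_subt.srt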
modules/whisper/whisper_base.py CHANGED
@@ -104,7 +104,9 @@ class WhisperBase(ABC):
            add_timestamp=add_timestamp
        )
 
-       if params.lang == "Automatic Detection":
+       if params.lang is None:
+           pass
+       elif params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
@@ -133,7 +135,7 @@ class WhisperBase(ABC):
 
        if params.vad_filter:
            # Explicit value set for float('inf') from gr.Number()
-           if params.max_speech_duration_s >= 9999:
+           if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')
 
            vad_options = VadOptions(
@@ -208,18 +210,21 @@ class WhisperBase(ABC):
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
-               files = format_gradio_files(files)
+           if isinstance(files, str):
+               files = [files]
+           if files and isinstance(files[0], gr.utils.NamedString):
+               files = [file.name for file in files]
 
            files_info = {}
            for file in files:
                transcribed_segments, time_for_task = self.run(
-                   file.name,
+                   file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )
 
-               file_name, file_ext = os.path.splitext(os.path.basename(file.name))
+               file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
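
Note: the widened guard covers two sentinel cases at once: gr.Number() can yield None when the field is left empty, and the UI uses 9999 as its stand-in for "no limit". Expressed as a hypothetical helper:

    def resolve_max_speech_duration(value):
        # None (empty gr.Number) and the 9999 UI sentinel both mean "unlimited".
        if value is None or value >= 9999:
            return float('inf')
        return value

    assert resolve_max_speech_duration(None) == float('inf')
    assert resolve_max_speech_duration(9999) == float('inf')
    assert resolve_max_speech_duration(30) == 30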
modules/whisper/whisper_parameter.py CHANGED
@@ -357,3 +357,13 @@ class WhisperValues:
            },
        }
        return data
+
+    def as_list(self) -> list:
+        """
+        Converts the data class attributes into a list
+
+        Returns
+        ----------
+        A list of Whisper parameters
+        """
+        return [getattr(self, f.name) for f in fields(self)]
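
Note: as_list() leans on dataclasses.fields preserving declaration order, which is what lets the tests splat a WhisperValues instance into *whisper_params positionally. The same pattern on a toy dataclass:

    from dataclasses import dataclass, fields

    @dataclass
    class Params:
        model_size: str = "tiny"
        vad_filter: bool = False
        is_diarize: bool = False

        def as_list(self) -> list:
            # fields() yields the attributes in declaration order.
            return [getattr(self, f.name) for f in fields(self)]

    print(Params(vad_filter=True).as_list())  # ['tiny', True, False]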
requirements.txt CHANGED
@@ -12,6 +12,6 @@ transformers==4.42.3
 gradio==4.43.0
 pytubefix
 ruamel.yaml==0.18.6
-pyannote.audio==3.3.1
+pyannote.audio==3.3.1;
 git+https://github.com/jhj0517/ultimatevocalremover_api.git
 git+https://github.com/jhj0517/pyrubberband.git
tests/test_bgm_separation.py ADDED
@@ -0,0 +1,53 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import torch
+import os
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, True, False),
+        ("faster-whisper", False, True, False),
+        ("insanely_fast_whisper", False, True, False)
+    ]
+)
+def test_bgm_separation_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", True, True, False),
+        ("faster-whisper", True, True, False),
+        ("insanely_fast_whisper", True, True, False)
+    ]
+)
+def test_bgm_separation_with_vad_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
tests/test_config.py ADDED
@@ -0,0 +1,17 @@
+from modules.utils.paths import *
+
+import os
+import torch
+
+TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
+TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
+TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
+TEST_WHISPER_MODEL = "tiny"
+TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
+TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
+TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
+TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
+
+
+def is_cuda_available():
+    return torch.cuda.is_available()
tests/test_diarization.py ADDED
@@ -0,0 +1,31 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import os
+
+
+@pytest.mark.skipif(
+    not is_cuda_available(),
+    reason="Skipping because the test only works on GPU"
+)
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, False, True),
+        ("faster-whisper", False, False, True),
+        ("insanely_fast_whisper", False, False, True)
+    ]
+)
+def test_diarization_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
tests/test_srt.srt ADDED
@@ -0,0 +1,7 @@
+1
+00:00:00,000 --> 00:00:02,240
+You've got
+
+2
+00:00:02,240 --> 00:00:04,160
+a friend in me.
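
Note: this two-cue fixture is just enough to exercise parse_srt and the batched translation loop end to end. The parsed form consumed by the translation modules is a list of dicts keyed at least by "sentence"; the other keys in this sketch are assumptions for illustration, since only "sentence" appears in the code in this PR:

    # Hypothetical parsed form of the fixture above; only the "sentence" key
    # is confirmed by the translation code in this PR.
    parsed_dicts = [
        {"index": "1", "timestamp": "00:00:00,000 --> 00:00:02,240", "sentence": "You've got"},
        {"index": "2", "timestamp": "00:00:02,240 --> 00:00:04,160", "sentence": "a friend in me."},
    ]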
tests/test_transcription.py ADDED
@@ -0,0 +1,97 @@
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from modules.utils.paths import WEBUI_DIR
+from test_config import *
+
+import requests
+import pytest
+import gradio as gr
+import os
+
+
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", False, False, False),
+        ("faster-whisper", False, False, False),
+        ("insanely_fast_whisper", False, False, False)
+    ]
+)
+def test_transcribe(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    audio_path_dir = os.path.join(WEBUI_DIR, "tests")
+    audio_path = os.path.join(audio_path_dir, "jfk.wav")
+    if not os.path.exists(audio_path):
+        download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
+
+    whisper_inferencer = WhisperFactory.create_whisper_inference(
+        whisper_type=whisper_type,
+    )
+    print(
+        f"""Whisper Device : {whisper_inferencer.device}\n"""
+        f"""BGM Separation Device: {whisper_inferencer.music_separator.device}\n"""
+        f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
+    )
+
+    hparams = WhisperValues(
+        model_size=TEST_WHISPER_MODEL,
+        vad_filter=vad_filter,
+        is_bgm_separate=bgm_separation,
+        compute_type=whisper_inferencer.current_compute_type,
+        uvr_enable_offload=True,
+        is_diarize=diarization,
+    ).as_list()
+
+    subtitle_str, file_path = whisper_inferencer.transcribe_file(
+        [audio_path],
+        None,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+    whisper_inferencer.transcribe_youtube(
+        TEST_YOUTUBE_URL,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+    whisper_inferencer.transcribe_mic(
+        audio_path,
+        "SRT",
+        False,
+        gr.Progress(),
+        *hparams,
+    )
+    assert isinstance(subtitle_str, str) and subtitle_str
+    assert isinstance(file_path[0], str) and file_path
+
+
+def download_file(url, save_dir):
+    if os.path.exists(TEST_FILE_PATH):
+        return
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    file_name = url.split("/")[-1]
+    file_path = os.path.join(save_dir, file_name)
+
+    response = requests.get(url)
+
+    with open(file_path, "wb") as file:
+        file.write(response.content)
+
+    print(f"File downloaded to: {file_path}")
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.translation.deepl_api import DeepLAPI
2
+ from modules.translation.nllb_inference import NLLBInference
3
+ from test_config import *
4
+
5
+ import os
6
+ import pytest
7
+
8
+
9
+ @pytest.mark.parametrize("model_size, file_path", [
10
+ (TEST_NLLB_MODEL, TEST_SUBTITLE_SRT_PATH),
11
+ (TEST_NLLB_MODEL, TEST_SUBTITLE_VTT_PATH),
12
+ ])
13
+ def test_nllb_inference(
14
+ model_size: str,
15
+ file_path: str
16
+ ):
17
+ nllb_inferencer = NLLBInference()
18
+ print(f"NLLB Device : {nllb_inferencer.device}")
19
+
20
+ result_str, file_paths = nllb_inferencer.translate_file(
21
+ fileobjs=[file_path],
22
+ model_size=model_size,
23
+ src_lang="eng_Latn",
24
+ tgt_lang="kor_Hang",
25
+ )
26
+
27
+ assert isinstance(result_str, str)
28
+ assert isinstance(file_paths[0], str)
29
+
30
+
31
+ @pytest.mark.parametrize("file_path", [
32
+ TEST_SUBTITLE_SRT_PATH,
33
+ TEST_SUBTITLE_VTT_PATH,
34
+ ])
35
+ def test_deepl_api(
36
+ file_path: str
37
+ ):
38
+ deepl_api = DeepLAPI()
39
+
40
+ api_key = os.getenv("DEEPL_API_KEY")
41
+
42
+ result_str, file_paths = deepl_api.translate_deepl(
43
+ auth_key=api_key,
44
+ fileobjs=[file_path],
45
+ source_lang="English",
46
+ target_lang="Korean",
47
+ is_pro=False,
48
+ add_timestamp=True,
49
+ )
50
+
51
+ assert isinstance(result_str, str)
52
+ assert isinstance(file_paths[0], str)
tests/test_vad.py ADDED
@@ -0,0 +1,26 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import os
+
+
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", True, False, False),
+        ("faster-whisper", True, False, False),
+        ("insanely_fast_whisper", True, False, False)
+    ]
+)
+def test_vad_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
tests/test_vtt.vtt ADDED
@@ -0,0 +1,6 @@
+WEBVTT
+00:00:00.500 --> 00:00:02.000
+You've got
+
+00:00:02.500 --> 00:00:04.300
+a friend in me.