aadnk commited on
Commit
cb9ee50
1 Parent(s): 479b187

Add support for multiple input files and output files

Browse files
Files changed (4) hide show
  1. app.py +77 -29
  2. cli.py +0 -3
  3. src/download.py +9 -3
  4. src/source.py +70 -0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import math
2
  from typing import Iterator
3
  import argparse
@@ -6,9 +7,11 @@ from io import StringIO
6
  import os
7
  import pathlib
8
  import tempfile
 
9
 
10
  import torch
11
  from src.modelCache import ModelCache
 
12
  from src.vadParallel import ParallelContext, ParallelTranscription
13
 
14
  # External programs
@@ -78,9 +81,9 @@ class WhisperTranscriber:
78
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
79
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
80
 
81
- def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
82
  try:
83
- source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
84
 
85
  try:
86
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
@@ -88,22 +91,84 @@ class WhisperTranscriber:
88
 
89
  model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
90
 
91
- # Execute whisper
92
- result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
 
 
 
93
 
94
  # Write result
95
  downloadDirectory = tempfile.mkdtemp()
96
-
97
- filePrefix = slugify(sourceName, allow_unicode=True)
98
- download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  return download, text, vtt
101
 
102
  finally:
103
  # Cleanup source
104
  if self.deleteUploadedFiles:
105
- print("Deleting source file " + source)
106
- os.remove(source)
 
 
 
 
 
 
107
 
108
  except ExceededMaximumDuration as e:
109
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
@@ -222,25 +287,8 @@ class WhisperTranscriber:
222
  self.model_cache.clear()
223
  self.vad_model = None
224
 
225
- def __get_source(self, urlData, uploadFile, microphoneData):
226
- if urlData:
227
- # Download from YouTube
228
- source = download_url(urlData, self.inputAudioMaxDuration)[0]
229
- else:
230
- # File input
231
- source = uploadFile if uploadFile is not None else microphoneData
232
-
233
- if self.inputAudioMaxDuration > 0:
234
- # Calculate audio length
235
- audioDuration = ffmpeg.probe(source)["format"]["duration"]
236
-
237
- if float(audioDuration) > self.inputAudioMaxDuration:
238
- raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")
239
-
240
- file_path = pathlib.Path(source)
241
- sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
242
-
243
- return source, sourceName
244
 
245
  def __get_max_line_width(self, language: str) -> int:
246
  if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
@@ -304,7 +352,7 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
304
  gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
305
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
306
  gr.Text(label="URL (YouTube, etc.)"),
307
- gr.Audio(source="upload", type="filepath", label="Upload Audio"),
308
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
309
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
310
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
 
1
+ from datetime import datetime
2
  import math
3
  from typing import Iterator
4
  import argparse
 
7
  import os
8
  import pathlib
9
  import tempfile
10
+ import zipfile
11
 
12
  import torch
13
  from src.modelCache import ModelCache
14
+ from src.source import get_audio_source_collection
15
  from src.vadParallel import ParallelContext, ParallelTranscription
16
 
17
  # External programs
 
81
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
82
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
83
 
84
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
85
  try:
86
+ sources = self.__get_source(urlData, multipleFiles, microphoneData)
87
 
88
  try:
89
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
 
91
 
92
  model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
93
 
94
+ # Result
95
+ download = []
96
+ zip_file_lookup = {}
97
+ text = ""
98
+ vtt = ""
99
 
100
  # Write result
101
  downloadDirectory = tempfile.mkdtemp()
102
+ source_index = 0
103
+
104
+ # Execute whisper
105
+ for source in sources:
106
+ source_prefix = ""
107
+
108
+ if (len(sources) > 1):
109
+ # Prefix (minimum 2 digits)
110
+ source_index += 1
111
+ source_prefix = str(source_index).zfill(2) + "_"
112
+ print("Transcribing ", source.source_path)
113
+
114
+ # Transcribe
115
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
116
+ filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
117
+
118
+ source_download, source_text, source_vtt = self.write_result(result, filePrefix, downloadDirectory)
119
+
120
+ if len(sources) > 1:
121
+ # Add new line separators
122
+ if (len(source_text) > 0):
123
+ source_text += os.linesep + os.linesep
124
+ if (len(source_vtt) > 0):
125
+ source_vtt += os.linesep + os.linesep
126
+
127
+ # Append file name to source text too
128
+ source_text = source.get_full_name() + ":" + os.linesep + source_text
129
+ source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
130
+
131
+ # Add to result
132
+ download.extend(source_download)
133
+ text += source_text
134
+ vtt += source_vtt
135
+
136
+ if (len(sources) > 1):
137
+ # Zip files support at least 260 characters, but we'll play it safe and use 200
138
+ zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
139
+
140
+ # File names in ZIP file can be longer
141
+ for source_download_file in source_download:
142
+ # Get file postfix (after last -)
143
+ filePostfix = os.path.basename(source_download_file).split("-")[-1]
144
+ zip_file_name = zipFilePrefix + "-" + filePostfix
145
+ zip_file_lookup[source_download_file] = zip_file_name
146
+
147
+ # Create zip file from all sources
148
+ if len(sources) > 1:
149
+ downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
150
+
151
+ with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
152
+ for download_file in download:
153
+ # Get file name from lookup
154
+ zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
155
+ zip.write(download_file, arcname=zip_file_name)
156
+
157
+ download.insert(0, downloadAllPath)
158
 
159
  return download, text, vtt
160
 
161
  finally:
162
  # Cleanup source
163
  if self.deleteUploadedFiles:
164
+ for source in sources:
165
+ print("Deleting source file " + source.source_path)
166
+
167
+ try:
168
+ os.remove(source.source_path)
169
+ except Exception as e:
170
+ # Ignore error - it's just a cleanup
171
+ print("Error deleting source file " + source.source_path + ": " + str(e))
172
 
173
  except ExceededMaximumDuration as e:
174
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
 
287
  self.model_cache.clear()
288
  self.vad_model = None
289
 
290
+ def __get_source(self, urlData, multipleFiles, microphoneData):
291
+ return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  def __get_max_line_width(self, language: str) -> int:
294
  if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
 
352
  gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
353
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
354
  gr.Text(label="URL (YouTube, etc.)"),
355
+ gr.File(label="Upload Files", file_count="multiple"),
356
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
357
  gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
358
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
cli.py CHANGED
@@ -5,8 +5,6 @@ from urllib.parse import urlparse
5
  import warnings
6
  import numpy as np
7
 
8
- import whisper
9
-
10
  import torch
11
  from app import LANGUAGES, WhisperTranscriber
12
  from src.download import download_url
@@ -14,7 +12,6 @@ from src.download import download_url
14
  from src.utils import optional_float, optional_int, str2bool
15
  from src.whisperContainer import WhisperContainer
16
 
17
-
18
  def cli():
19
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
  parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
 
5
  import warnings
6
  import numpy as np
7
 
 
 
8
  import torch
9
  from app import LANGUAGES, WhisperTranscriber
10
  from src.download import download_url
 
12
  from src.utils import optional_float, optional_int, str2bool
13
  from src.whisperContainer import WhisperContainer
14
 
 
15
  def cli():
16
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
  parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
src/download.py CHANGED
@@ -46,10 +46,16 @@ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = N
46
  with YoutubeDL(ydl_opts) as ydl:
47
  if maxDuration and maxDuration > 0:
48
  info = ydl.extract_info(url, download=False)
49
- duration = info['duration']
50
 
51
- if duration >= maxDuration:
52
- raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")
 
 
 
 
 
 
53
 
54
  ydl.add_post_processor(filename_collector)
55
  ydl.download([url])
 
46
  with YoutubeDL(ydl_opts) as ydl:
47
  if maxDuration and maxDuration > 0:
48
  info = ydl.extract_info(url, download=False)
49
+ entries = "entries" in info and info["entries"] or [info]
50
 
51
+ total_duration = 0
52
+
53
+ # Compute total duration
54
+ for entry in entries:
55
+ total_duration += float(entry["duration"])
56
+
57
+ if total_duration >= maxDuration:
58
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
59
 
60
  ydl.add_post_processor(filename_collector)
61
  ydl.download([url])
src/source.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
+ MAX_FILE_PREFIX_LENGTH = 17
13
+
14
+ class AudioSource:
15
+ def __init__(self, source_path, source_name = None):
16
+ self.source_path = source_path
17
+ self.source_name = source_name
18
+
19
+ # Load source name if not provided
20
+ if (self.source_name is None):
21
+ file_path = pathlib.Path(self.source_path)
22
+ self.source_name = file_path.name
23
+
24
+ def get_full_name(self):
25
+ return self.source_name
26
+
27
+ def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
28
+ file_path = pathlib.Path(self.source_name)
29
+ short_name = file_path.stem[:max_length] + file_path.suffix
30
+
31
+ return short_name
32
+
33
+ def __str__(self) -> str:
34
+ return self.source_path
35
+
36
+ class AudioSourceCollection:
37
+ def __init__(self, sources: List[AudioSource]):
38
+ self.sources = sources
39
+
40
+ def __iter__(self):
41
+ return iter(self.sources)
42
+
43
+ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
44
+ output: List[AudioSource] = []
45
+
46
+ if urlData:
47
+ # Download from YouTube. This could also be a playlist or a channel.
48
+ output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
49
+ else:
50
+ # Add input files
51
+ if (multipleFiles is not None):
52
+ output.extend([ AudioSource(x.name) for x in multipleFiles ])
53
+ if (microphoneData is not None):
54
+ output.append(AudioSource(microphoneData))
55
+
56
+ total_duration = 0
57
+
58
+ # Calculate total audio length. We do this even if input_audio_max_duration
59
+ # is disabled to ensure that all the audio files are valid.
60
+ for source in output:
61
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
62
+ total_duration += float(audioDuration)
63
+
64
+ # Ensure the total duration of the audio is not too long
65
+ if input_audio_max_duration > 0:
66
+ if float(total_duration) > input_audio_max_duration:
67
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
68
+
69
+ # Return a list of audio sources
70
+ return output