aadnk commited on
Commit
883c794
1 Parent(s): 7f502b4

Refactor function names

Browse files

Also prepare code for creating a CLI

Files changed (5) hide show
  1. app-local.py +2 -2
  2. app-network.py +2 -2
  3. app-shared.py +2 -2
  4. app.py +50 -46
  5. src/download.py +4 -4
app-local.py CHANGED
@@ -1,3 +1,3 @@
1
  # Run the app with no audio file restrictions
2
- from app import createUi
3
- createUi(-1)
 
1
  # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1)
app-network.py CHANGED
@@ -1,3 +1,3 @@
1
  # Run the app with no audio file restrictions, and make it available on the network
2
- from app import createUi
3
- createUi(-1, server_name="0.0.0.0")
 
1
  # Run the app with no audio file restrictions, and make it available on the network
2
+ from app import create_ui
3
+ create_ui(-1, server_name="0.0.0.0")
app-shared.py CHANGED
@@ -1,3 +1,3 @@
1
  # Run the app with no audio file restrictions
2
- from app import createUi
3
- createUi(-1, share=True)
 
1
  # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1, share=True)
app.py CHANGED
@@ -12,7 +12,7 @@ import ffmpeg
12
  # UI
13
  import gradio as gr
14
 
15
- from src.download import ExceededMaximumDuration, downloadUrl
16
  from src.utils import slugify, write_srt, write_vtt
17
  from src.vad import VadPeriodicTranscription, VadSileroTranscription
18
 
@@ -45,26 +45,27 @@ LANGUAGES = [
45
  "Hausa", "Bashkir", "Javanese", "Sundanese"
46
  ]
47
 
48
- model_cache = dict()
 
 
49
 
50
- class UI:
51
- def __init__(self, inputAudioMaxDuration):
52
  self.vad_model = None
53
  self.inputAudioMaxDuration = inputAudioMaxDuration
 
54
 
55
- def transcribeFile(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
56
  try:
57
- source, sourceName = self.getSource(urlData, uploadFile, microphoneData)
58
 
59
  try:
60
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
61
  selectedModel = modelName if modelName is not None else "base"
62
 
63
- model = model_cache.get(selectedModel, None)
64
 
65
  if not model:
66
  model = whisper.load_model(selectedModel)
67
- model_cache[selectedModel] = model
68
 
69
  # Callable for processing an audio file
70
  whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)
@@ -100,36 +101,39 @@ class UI:
100
  text = result["text"]
101
 
102
  language = result["language"]
103
- languageMaxLineWidth = getMaxLineWidth(language)
104
 
105
  print("Max line width " + str(languageMaxLineWidth))
106
- vtt = getSubs(result["segments"], "vtt", languageMaxLineWidth)
107
- srt = getSubs(result["segments"], "srt", languageMaxLineWidth)
108
 
109
  # Files that can be downloaded
110
  downloadDirectory = tempfile.mkdtemp()
111
  filePrefix = slugify(sourceName, allow_unicode=True)
112
 
113
  download = []
114
- download.append(createFile(srt, downloadDirectory, filePrefix + "-subs.srt"));
115
- download.append(createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
116
- download.append(createFile(text, downloadDirectory, filePrefix + "-transcript.txt"));
117
 
118
  return download, text, vtt
119
 
120
  finally:
121
  # Cleanup source
122
- if DELETE_UPLOADED_FILES:
123
  print("Deleting source file " + source)
124
  os.remove(source)
125
 
126
  except ExceededMaximumDuration as e:
127
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
128
 
129
- def getSource(self, urlData, uploadFile, microphoneData):
 
 
 
130
  if urlData:
131
  # Download from YouTube
132
- source = downloadUrl(urlData, self.inputAudioMaxDuration)
133
  else:
134
  # File input
135
  source = uploadFile if uploadFile is not None else microphoneData
@@ -146,38 +150,38 @@ class UI:
146
 
147
  return source, sourceName
148
 
149
- def getMaxLineWidth(language: str) -> int:
150
- if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
151
- # Chinese characters and kana are wider, so limit line length to 40 characters
152
- return 40
153
- else:
154
- # TODO: Add more languages
155
- # 80 latin characters should fit on a 1080p/720p screen
156
- return 80
 
 
 
157
 
158
- def createFile(text: str, directory: str, fileName: str) -> str:
159
- # Write the text to a file
160
- with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
161
- file.write(text)
 
 
162
 
163
- return file.name
 
164
 
165
- def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
166
- segmentStream = StringIO()
 
 
167
 
168
- if format == 'vtt':
169
- write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
170
- elif format == 'srt':
171
- write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
172
- else:
173
- raise Exception("Unknown format " + format)
174
 
175
- segmentStream.seek(0)
176
- return segmentStream.read()
177
-
178
 
179
- def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
180
- ui = UI(inputAudioMaxDuration)
181
 
182
  ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
183
  ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
@@ -188,9 +192,9 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
188
  if inputAudioMaxDuration > 0:
189
  ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
190
 
191
- ui_article = "Read the [documentation her](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
192
 
193
- demo = gr.Interface(fn=ui.transcribeFile, description=ui_description, article=ui_article, inputs=[
194
  gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
195
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
196
  gr.Text(label="URL (YouTube, etc.)"),
@@ -210,4 +214,4 @@ def createUi(inputAudioMaxDuration, share=False, server_name: str = None):
210
  demo.launch(share=share, server_name=server_name)
211
 
212
  if __name__ == '__main__':
213
- createUi(DEFAULT_INPUT_AUDIO_MAX_DURATION)
 
12
  # UI
13
  import gradio as gr
14
 
15
+ from src.download import ExceededMaximumDuration, download_url
16
  from src.utils import slugify, write_srt, write_vtt
17
  from src.vad import VadPeriodicTranscription, VadSileroTranscription
18
 
 
45
  "Hausa", "Bashkir", "Javanese", "Sundanese"
46
  ]
47
 
48
+ class WhisperTranscriber:
49
+ def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
50
+ self.model_cache = dict()
51
 
 
 
52
  self.vad_model = None
53
  self.inputAudioMaxDuration = inputAudioMaxDuration
54
+ self.deleteUploadedFiles = deleteUploadedFiles
55
 
56
+ def transcribe_file(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
57
  try:
58
+ source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
59
 
60
  try:
61
  selectedLanguage = languageName.lower() if len(languageName) > 0 else None
62
  selectedModel = modelName if modelName is not None else "base"
63
 
64
+ model = self.model_cache.get(selectedModel, None)
65
 
66
  if not model:
67
  model = whisper.load_model(selectedModel)
68
+ self.model_cache[selectedModel] = model
69
 
70
  # Callable for processing an audio file
71
  whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)
 
101
  text = result["text"]
102
 
103
  language = result["language"]
104
+ languageMaxLineWidth = self.__get_max_line_width(language)
105
 
106
  print("Max line width " + str(languageMaxLineWidth))
107
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
108
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
109
 
110
  # Files that can be downloaded
111
  downloadDirectory = tempfile.mkdtemp()
112
  filePrefix = slugify(sourceName, allow_unicode=True)
113
 
114
  download = []
115
+ download.append(self.__create_file(srt, downloadDirectory, filePrefix + "-subs.srt"));
116
+ download.append(self.__create_file(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
117
+ download.append(self.__create_file(text, downloadDirectory, filePrefix + "-transcript.txt"));
118
 
119
  return download, text, vtt
120
 
121
  finally:
122
  # Cleanup source
123
+ if self.deleteUploadedFiles:
124
  print("Deleting source file " + source)
125
  os.remove(source)
126
 
127
  except ExceededMaximumDuration as e:
128
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
129
 
130
+ def clear_cache(self):
131
+ self.model_cache = dict()
132
+
133
+ def __get_source(self, urlData, uploadFile, microphoneData):
134
  if urlData:
135
  # Download from YouTube
136
+ source = download_url(urlData, self.inputAudioMaxDuration)
137
  else:
138
  # File input
139
  source = uploadFile if uploadFile is not None else microphoneData
 
150
 
151
  return source, sourceName
152
 
153
+ def __get_max_line_width(self, language: str) -> int:
154
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
155
+ # Chinese characters and kana are wider, so limit line length to 40 characters
156
+ return 40
157
+ else:
158
+ # TODO: Add more languages
159
+ # 80 latin characters should fit on a 1080p/720p screen
160
+ return 80
161
+
162
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
163
+ segmentStream = StringIO()
164
 
165
+ if format == 'vtt':
166
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
167
+ elif format == 'srt':
168
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
169
+ else:
170
+ raise Exception("Unknown format " + format)
171
 
172
+ segmentStream.seek(0)
173
+ return segmentStream.read()
174
 
175
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
176
+ # Write the text to a file
177
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
178
+ file.write(text)
179
 
180
+ return file.name
 
 
 
 
 
181
 
 
 
 
182
 
183
+ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
184
+ ui = WhisperTranscriber(inputAudioMaxDuration)
185
 
186
  ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
187
  ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
 
192
  if inputAudioMaxDuration > 0:
193
  ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
194
 
195
+ ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
196
 
197
+ demo = gr.Interface(fn=ui.transcribe_file, description=ui_description, article=ui_article, inputs=[
198
  gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
199
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
200
  gr.Text(label="URL (YouTube, etc.)"),
 
214
  demo.launch(share=share, server_name=server_name)
215
 
216
  if __name__ == '__main__':
217
+ create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)
src/download.py CHANGED
@@ -13,16 +13,16 @@ class FilenameCollectorPP(PostProcessor):
13
  self.filenames.append(information["filepath"])
14
  return [], information
15
 
16
- def downloadUrl(url: str, maxDuration: int = None):
17
  try:
18
- return _performDownload(url, maxDuration=maxDuration)
19
  except yt_dlp.utils.DownloadError as e:
20
  # In case of an OS error, try again with a different output template
21
  if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
22
- return _performDownload(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
23
  pass
24
 
25
- def _performDownload(url: str, maxDuration: int = None, outputTemplate: str = None):
26
  destinationDirectory = mkdtemp()
27
 
28
  ydl_opts = {
 
13
  self.filenames.append(information["filepath"])
14
  return [], information
15
 
16
+ def download_url(url: str, maxDuration: int = None):
17
  try:
18
+ return _perform_download(url, maxDuration=maxDuration)
19
  except yt_dlp.utils.DownloadError as e:
20
  # In case of an OS error, try again with a different output template
21
  if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
22
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
23
  pass
24
 
25
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None):
26
  destinationDirectory = mkdtemp()
27
 
28
  ydl_opts = {