Spaces:
Runtime error
Runtime error
Add support for downloading the results.
Browse files
app.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
from io import StringIO
|
|
|
|
|
|
|
|
|
2 |
import gradio as gr
|
3 |
|
4 |
-
from utils import write_vtt
|
5 |
import whisper
|
6 |
|
7 |
import ffmpeg
|
@@ -40,6 +44,8 @@ class UI:
|
|
40 |
|
41 |
def transcribeFile(self, modelName, languageName, uploadFile, microphoneData, task):
|
42 |
source = uploadFile if uploadFile is not None else microphoneData
|
|
|
|
|
43 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
44 |
selectedModel = modelName if modelName is not None else "base"
|
45 |
|
@@ -56,14 +62,43 @@ class UI:
|
|
56 |
model = whisper.load_model(selectedModel)
|
57 |
model_cache[selectedModel] = model
|
58 |
|
|
|
59 |
result = model.transcribe(source, language=selectedLanguage, task=task)
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
67 |
|
68 |
def createUi(inputAudioMaxDuration, share=False):
|
69 |
ui = UI(inputAudioMaxDuration)
|
@@ -81,7 +116,11 @@ def createUi(inputAudioMaxDuration, share=False):
|
|
81 |
gr.Audio(source="upload", type="filepath", label="Upload Audio"),
|
82 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
83 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
|
84 |
-
], outputs=[
|
|
|
|
|
|
|
|
|
85 |
|
86 |
demo.launch(share=share)
|
87 |
|
|
|
1 |
from io import StringIO
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
|
5 |
+
from typing import Iterator
|
6 |
import gradio as gr
|
7 |
|
8 |
+
from utils import slugify, write_srt, write_vtt
|
9 |
import whisper
|
10 |
|
11 |
import ffmpeg
|
|
|
44 |
|
45 |
def transcribeFile(self, modelName, languageName, uploadFile, microphoneData, task):
|
46 |
source = uploadFile if uploadFile is not None else microphoneData
|
47 |
+
sourceName = os.path.basename(source)
|
48 |
+
|
49 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
50 |
selectedModel = modelName if modelName is not None else "base"
|
51 |
|
|
|
62 |
model = whisper.load_model(selectedModel)
|
63 |
model_cache[selectedModel] = model
|
64 |
|
65 |
+
# The results
|
66 |
result = model.transcribe(source, language=selectedLanguage, task=task)
|
67 |
|
68 |
+
text = result["text"]
|
69 |
+
vtt = getSubs(result["segments"], "vtt")
|
70 |
+
srt = getSubs(result["segments"], "srt")
|
71 |
+
|
72 |
+
# Files that can be downloaded
|
73 |
+
downloadDirectory = tempfile.mkdtemp()
|
74 |
+
filePrefix = slugify(sourceName, allow_unicode=True)
|
75 |
+
|
76 |
+
download = []
|
77 |
+
download.append(createFile(srt, downloadDirectory, filePrefix + "-subs.srt"));
|
78 |
+
download.append(createFile(vtt, downloadDirectory, filePrefix + "-subs.vtt"));
|
79 |
+
download.append(createFile(text, downloadDirectory, filePrefix + "-transcript.txt"));
|
80 |
+
|
81 |
+
return text, vtt, download
|
82 |
+
|
83 |
+
def createFile(text: str, directory: str, fileName: str) -> str:
|
84 |
+
# Write the text to a file
|
85 |
+
with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
|
86 |
+
file.write(text)
|
87 |
+
|
88 |
+
return file.name
|
89 |
+
|
90 |
+
def getSubs(segments: Iterator[dict], format: str) -> str:
|
91 |
+
segmentStream = StringIO()
|
92 |
|
93 |
+
if format == 'vtt':
|
94 |
+
write_vtt(segments, file=segmentStream)
|
95 |
+
elif format == 'srt':
|
96 |
+
write_srt(segments, file=segmentStream)
|
97 |
+
else:
|
98 |
+
raise Exception("Unknown format " + format)
|
99 |
|
100 |
+
segmentStream.seek(0)
|
101 |
+
return segmentStream.read()
|
102 |
|
103 |
def createUi(inputAudioMaxDuration, share=False):
|
104 |
ui = UI(inputAudioMaxDuration)
|
|
|
116 |
gr.Audio(source="upload", type="filepath", label="Upload Audio"),
|
117 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
118 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
|
119 |
+
], outputs=[
|
120 |
+
gr.Text(label="Transcription"),
|
121 |
+
gr.Text(label="Segments"),
|
122 |
+
gr.File(label="Download")
|
123 |
+
])
|
124 |
|
125 |
demo.launch(share=share)
|
126 |
|
utils.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import zlib
|
2 |
from typing import Iterator, TextIO
|
3 |
|
@@ -27,7 +30,7 @@ def compression_ratio(text) -> float:
|
|
27 |
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
28 |
|
29 |
|
30 |
-
def format_timestamp(seconds: float):
|
31 |
assert seconds >= 0, "non-negative timestamp expected"
|
32 |
milliseconds = round(seconds * 1000.0)
|
33 |
|
@@ -40,7 +43,13 @@ def format_timestamp(seconds: float):
|
|
40 |
seconds = milliseconds // 1_000
|
41 |
milliseconds -= seconds * 1_000
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
@@ -52,3 +61,43 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
|
|
52 |
file=file,
|
53 |
flush=True,
|
54 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unicodedata
|
2 |
+
import re
|
3 |
+
|
4 |
import zlib
|
5 |
from typing import Iterator, TextIO
|
6 |
|
|
|
30 |
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
31 |
|
32 |
|
33 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False):
|
34 |
assert seconds >= 0, "non-negative timestamp expected"
|
35 |
milliseconds = round(seconds * 1000.0)
|
36 |
|
|
|
43 |
seconds = milliseconds // 1_000
|
44 |
milliseconds -= seconds * 1_000
|
45 |
|
46 |
+
hours_marker = f"{hours}:" if always_include_hours or hours > 0 else ""
|
47 |
+
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
48 |
+
|
49 |
+
|
50 |
+
def write_txt(transcript: Iterator[dict], file: TextIO):
|
51 |
+
for segment in transcript:
|
52 |
+
print(segment['text'].strip(), file=file, flush=True)
|
53 |
|
54 |
|
55 |
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
|
|
61 |
file=file,
|
62 |
flush=True,
|
63 |
)
|
64 |
+
|
65 |
+
|
66 |
+
def write_srt(transcript: Iterator[dict], file: TextIO):
|
67 |
+
"""
|
68 |
+
Write a transcript to a file in SRT format.
|
69 |
+
Example usage:
|
70 |
+
from pathlib import Path
|
71 |
+
from whisper.utils import write_srt
|
72 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
73 |
+
# save SRT
|
74 |
+
audio_basename = Path(audio_path).stem
|
75 |
+
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
76 |
+
write_srt(result["segments"], file=srt)
|
77 |
+
"""
|
78 |
+
for i, segment in enumerate(transcript, start=1):
|
79 |
+
# write srt lines
|
80 |
+
print(
|
81 |
+
f"{i}\n"
|
82 |
+
f"{format_timestamp(segment['start'], always_include_hours=True)} --> "
|
83 |
+
f"{format_timestamp(segment['end'], always_include_hours=True)}\n"
|
84 |
+
f"{segment['text'].strip().replace('-->', '->')}\n",
|
85 |
+
file=file,
|
86 |
+
flush=True,
|
87 |
+
)
|
88 |
+
|
89 |
+
def slugify(value, allow_unicode=False):
|
90 |
+
"""
|
91 |
+
Taken from https://github.com/django/django/blob/master/django/utils/text.py
|
92 |
+
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
|
93 |
+
dashes to single dashes. Remove characters that aren't alphanumerics,
|
94 |
+
underscores, or hyphens. Convert to lowercase. Also strip leading and
|
95 |
+
trailing whitespace, dashes, and underscores.
|
96 |
+
"""
|
97 |
+
value = str(value)
|
98 |
+
if allow_unicode:
|
99 |
+
value = unicodedata.normalize('NFKC', value)
|
100 |
+
else:
|
101 |
+
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
|
102 |
+
value = re.sub(r'[^\w\s-]', '', value.lower())
|
103 |
+
return re.sub(r'[-\s]+', '-', value).strip('-_')
|