gobeldan committed on
Commit
3d58577
1 Parent(s): 82c30a2

Upload 4 files

Files changed (4)
  1. app.py +172 -0
  2. languages.py +147 -0
  3. requirements.txt +5 -0
  4. subtitle_manager.py +52 -0
app.py ADDED
@@ -0,0 +1,172 @@
+ import gradio as gr
+ import time
+ import logging
+ import torch
+ from sys import platform
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
+ from transformers.utils import is_flash_attn_2_available
+ from languages import get_language_names
+ from subtitle_manager import Subtitle
+
+
+ logging.basicConfig(level=logging.INFO)
+
+ # Cache the loaded pipeline so transcribing again with the same model
+ # does not reload it from disk.
+ last_model = None
+ pipe = None
+
+ def write_file(output_file, subtitle):
+     with open(output_file, 'w', encoding='utf-8') as f:
+         f.write(subtitle)
+
+ def create_pipe(model, flash):
+     if torch.cuda.is_available():
+         device = "cuda:0"
+     elif platform == "darwin":
+         device = "mps"
+     else:
+         device = "cpu"
+     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+     model_id = model
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+         use_safetensors=True,
+         # Attention implementations: "eager" (manual), "flash_attention_2"
+         # (requires the flash-attn package), or "sdpa"
+         # (torch.nn.functional.scaled_dot_product_attention, needs torch>=2.1.1).
+         attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
+     )
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     pipe = pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         torch_dtype=torch_dtype,
+         device=device,
+     )
+     return pipe
+
+ def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
+                                      chunk_length_s, batch_size, progress=gr.Progress()):
+     global last_model, pipe
+
+     progress(0, desc="Loading Audio..")
+     logging.info(f"urlData: {urlData}")
+     logging.info(f"multipleFiles: {multipleFiles}")
+     logging.info(f"microphoneData: {microphoneData}")
+     logging.info(f"task: {task}")
+     logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
+     logging.info(f"chunk_length_s: {chunk_length_s}")
+     logging.info(f"batch_size: {batch_size}")
+
+     if last_model is None:
+         logging.info("first model")
+         progress(0.1, desc="Loading Model..")
+         pipe = create_pipe(modelName, flash)
+     elif modelName != last_model:
+         logging.info("new model")
+         torch.cuda.empty_cache()
+         progress(0.1, desc="Loading Model..")
+         pipe = create_pipe(modelName, flash)
+     else:
+         logging.info("Model not changed")
+     last_model = modelName
+
+     srt_sub = Subtitle("srt")
+     vtt_sub = Subtitle("vtt")
+     txt_sub = Subtitle("txt")
+
+     files = []
+     if multipleFiles:
+         files += multipleFiles
+     if urlData:
+         files.append(urlData)
+     if microphoneData:
+         files.append(microphoneData)
+     logging.info(files)
+
+     # English-only (".en") checkpoints accept neither a language nor a task.
+     generate_kwargs = {}
+     if languageName != "Automatic Detection" and not modelName.endswith(".en"):
+         generate_kwargs["language"] = languageName
+     if not modelName.endswith(".en"):
+         generate_kwargs["task"] = task
+
+     files_out = []
+     vtt = txt = ""
+     for file in progress.tqdm(files, desc="Working..."):
+         start_time = time.time()
+         logging.info(file)
+         outputs = pipe(
+             file,
+             chunk_length_s=chunk_length_s,  # e.g. 30
+             batch_size=batch_size,          # e.g. 24
+             generate_kwargs=generate_kwargs,
+             return_timestamps=True,
+         )
+         logging.debug(outputs)
+         logging.info(f"transcribe: {time.time() - start_time} sec.")
+
+         file_out = file.split('/')[-1]
+         srt = srt_sub.get_subtitle(outputs["chunks"])
+         vtt = vtt_sub.get_subtitle(outputs["chunks"])
+         txt = txt_sub.get_subtitle(outputs["chunks"])
+         write_file(file_out + ".srt", srt)
+         write_file(file_out + ".vtt", vtt)
+         write_file(file_out + ".txt", txt)
+         files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]
+
+     progress(1, desc="Completed!")
+
+     return files_out, vtt, txt
+
+
+ with gr.Blocks(title="Insanely Fast Whisper") as demo:
+     description = "An opinionated Web UI to transcribe audio files with Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
+     article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
+     whisper_models = [
+         "openai/whisper-tiny", "openai/whisper-tiny.en",
+         "openai/whisper-base", "openai/whisper-base.en",
+         "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
+         "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
+         "openai/whisper-large",
+         "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
+         "openai/whisper-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
+     ]
+     waveform_options = gr.WaveformOptions(
+         waveform_color="#01C6FF",
+         waveform_progress_color="#0066B4",
+         skip_length=2,
+         show_controls=False,
+     )
+
+     simple_transcribe = gr.Interface(
+         fn=transcribe_webui_simple_progress,
+         description=description,
+         article=article,
+         inputs=[
+             gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive=True),
+             gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive=True),
+             gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
+             gr.File(label="Upload Files", file_count="multiple"),
+             gr.Audio(sources=["microphone"], type="filepath", label="Microphone Input", waveform_options=waveform_options),
+             gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive=True),
+             gr.Checkbox(label="Flash", info="Use Flash Attention 2"),
+             gr.Number(label="chunk_length_s", value=30, interactive=True),
+             gr.Number(label="batch_size", value=24, interactive=True),
+         ],
+         outputs=[
+             gr.File(label="Download"),
+             gr.Text(label="Transcription"),
+             gr.Text(label="Segments"),
+         ],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
languages.py ADDED
@@ -0,0 +1,147 @@
+ class Language:
+     def __init__(self, code, name):
+         self.code = code
+         self.name = name
+
+     def __str__(self):
+         return "Language(code={}, name={})".format(self.code, self.name)
+
+ LANGUAGES = [
+     Language('en', 'English'),
+     Language('zh', 'Chinese'),
+     Language('de', 'German'),
+     Language('es', 'Spanish'),
+     Language('ru', 'Russian'),
+     Language('ko', 'Korean'),
+     Language('fr', 'French'),
+     Language('ja', 'Japanese'),
+     Language('pt', 'Portuguese'),
+     Language('tr', 'Turkish'),
+     Language('pl', 'Polish'),
+     Language('ca', 'Catalan'),
+     Language('nl', 'Dutch'),
+     Language('ar', 'Arabic'),
+     Language('sv', 'Swedish'),
+     Language('it', 'Italian'),
+     Language('id', 'Indonesian'),
+     Language('hi', 'Hindi'),
+     Language('fi', 'Finnish'),
+     Language('vi', 'Vietnamese'),
+     Language('he', 'Hebrew'),
+     Language('uk', 'Ukrainian'),
+     Language('el', 'Greek'),
+     Language('ms', 'Malay'),
+     Language('cs', 'Czech'),
+     Language('ro', 'Romanian'),
+     Language('da', 'Danish'),
+     Language('hu', 'Hungarian'),
+     Language('ta', 'Tamil'),
+     Language('no', 'Norwegian'),
+     Language('th', 'Thai'),
+     Language('ur', 'Urdu'),
+     Language('hr', 'Croatian'),
+     Language('bg', 'Bulgarian'),
+     Language('lt', 'Lithuanian'),
+     Language('la', 'Latin'),
+     Language('mi', 'Maori'),
+     Language('ml', 'Malayalam'),
+     Language('cy', 'Welsh'),
+     Language('sk', 'Slovak'),
+     Language('te', 'Telugu'),
+     Language('fa', 'Persian'),
+     Language('lv', 'Latvian'),
+     Language('bn', 'Bengali'),
+     Language('sr', 'Serbian'),
+     Language('az', 'Azerbaijani'),
+     Language('sl', 'Slovenian'),
+     Language('kn', 'Kannada'),
+     Language('et', 'Estonian'),
+     Language('mk', 'Macedonian'),
+     Language('br', 'Breton'),
+     Language('eu', 'Basque'),
+     Language('is', 'Icelandic'),
+     Language('hy', 'Armenian'),
+     Language('ne', 'Nepali'),
+     Language('mn', 'Mongolian'),
+     Language('bs', 'Bosnian'),
+     Language('kk', 'Kazakh'),
+     Language('sq', 'Albanian'),
+     Language('sw', 'Swahili'),
+     Language('gl', 'Galician'),
+     Language('mr', 'Marathi'),
+     Language('pa', 'Punjabi'),
+     Language('si', 'Sinhala'),
+     Language('km', 'Khmer'),
+     Language('sn', 'Shona'),
+     Language('yo', 'Yoruba'),
+     Language('so', 'Somali'),
+     Language('af', 'Afrikaans'),
+     Language('oc', 'Occitan'),
+     Language('ka', 'Georgian'),
+     Language('be', 'Belarusian'),
+     Language('tg', 'Tajik'),
+     Language('sd', 'Sindhi'),
+     Language('gu', 'Gujarati'),
+     Language('am', 'Amharic'),
+     Language('yi', 'Yiddish'),
+     Language('lo', 'Lao'),
+     Language('uz', 'Uzbek'),
+     Language('fo', 'Faroese'),
+     Language('ht', 'Haitian creole'),
+     Language('ps', 'Pashto'),
+     Language('tk', 'Turkmen'),
+     Language('nn', 'Nynorsk'),
+     Language('mt', 'Maltese'),
+     Language('sa', 'Sanskrit'),
+     Language('lb', 'Luxembourgish'),
+     Language('my', 'Myanmar'),
+     Language('bo', 'Tibetan'),
+     Language('tl', 'Tagalog'),
+     Language('mg', 'Malagasy'),
+     Language('as', 'Assamese'),
+     Language('tt', 'Tatar'),
+     Language('haw', 'Hawaiian'),
+     Language('ln', 'Lingala'),
+     Language('ha', 'Hausa'),
+     Language('ba', 'Bashkir'),
+     Language('jw', 'Javanese'),
+     Language('su', 'Sundanese'),
+ ]
+
+ # ISO code -> Language object.
+ _FROM_LANGUAGE_CODE = {language.code: language for language in LANGUAGES}
+
+ # Lowercase name -> Language object, including common alternative names.
+ _FROM_LANGUAGE_NAME = {
+     **{language.name.lower(): language for language in LANGUAGES},
+     "burmese": _FROM_LANGUAGE_CODE["my"],
+     "valencian": _FROM_LANGUAGE_CODE["ca"],
+     "flemish": _FROM_LANGUAGE_CODE["nl"],
+     "haitian": _FROM_LANGUAGE_CODE["ht"],
+     "letzeburgesch": _FROM_LANGUAGE_CODE["lb"],
+     "pushto": _FROM_LANGUAGE_CODE["ps"],
+     "panjabi": _FROM_LANGUAGE_CODE["pa"],
+     "moldavian": _FROM_LANGUAGE_CODE["ro"],
+     "moldovan": _FROM_LANGUAGE_CODE["ro"],
+     "sinhalese": _FROM_LANGUAGE_CODE["si"],
+     "castilian": _FROM_LANGUAGE_CODE["es"],
+ }
+
+ def get_language_from_code(language_code, default=None) -> Language:
+     """Return the Language object for the given language code."""
+     return _FROM_LANGUAGE_CODE.get(language_code, default)
+
+ def get_language_from_name(language, default=None) -> Language:
+     """Return the Language object for the given language name (case-insensitive)."""
+     return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
+
+ def get_language_names():
+     """Return a list of language names."""
+     return [language.name for language in LANGUAGES]
+
+ if __name__ == "__main__":
+     # Test lookup
+     print(get_language_from_code('en'))
+     print(get_language_from_name('English'))
+
+     print(get_language_names())
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ gradio
+ transformers
+ accelerate  # required for low_cpu_mem_usage=True in app.py
+ torch>=2.1.1
+ torchvision
+ torchaudio
subtitle_manager.py ADDED
@@ -0,0 +1,52 @@
+ class Subtitle:
+     def __init__(self, ext="srt"):
+         sub_dict = {
+             "srt": {
+                 "comma": ",",
+                 "header": "",
+                 "format": lambda i, segment: f"{i + 1}\n{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] is not None else segment['timestamp'][0])}\n{segment['text']}\n\n",
+             },
+             "vtt": {
+                 "comma": ".",
+                 "header": "WEBVTT\n\n",
+                 "format": lambda i, segment: f"{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] is not None else segment['timestamp'][0])}\n{segment['text']}\n\n",
+             },
+             "txt": {
+                 "comma": "",
+                 "header": "",
+                 "format": lambda i, segment: f"{segment['text']}\n",
+             },
+         }
+
+         self.ext = ext
+         # Decimal separator between seconds and milliseconds: "," for SRT, "." for VTT.
+         self.comma = sub_dict[ext]["comma"]
+         self.header = sub_dict[ext]["header"]
+         self.format = sub_dict[ext]["format"]
+
+     def timeformat(self, time):
+         hours = time // 3600
+         minutes = (time - hours * 3600) // 60
+         seconds = time - hours * 3600 - minutes * 60
+         milliseconds = (time - int(time)) * 1000
+         return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}{self.comma}{int(milliseconds):03d}"
+
+     def get_subtitle(self, segments):
+         output = self.header
+         for i, segment in enumerate(segments):
+             # Pipeline chunks often start with a leading space; drop it.
+             if segment['text'].startswith(' '):
+                 segment['text'] = segment['text'][1:]
+             try:
+                 output += self.format(i, segment)
+             except Exception as e:
+                 print(e, segment)
+
+         return output
+
+     def write_subtitle(self, segments, output_file):
+         output_file += "." + self.ext
+         subtitle = self.get_subtitle(segments)
+
+         with open(output_file, 'w', encoding='utf-8') as f:
+             f.write(subtitle)