yvankob committed on
Commit
62c5f44
1 Parent(s): 6cd1cba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from faster_whisper import WhisperModel
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
+ from pydub import AudioSegment
6
+ import yt_dlp as youtube_dl
7
+ import tempfile
8
+ from transformers.pipelines.audio_utils import ffmpeg_read
9
+ from gradio.components import Audio, Dropdown, Radio, Textbox
10
+ import os
11
+ import numpy as np
12
+ import soundfile as sf
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+
16
# Parameters
FILE_LIMIT_MB = 1_000  # NOTE(review): not referenced in the visible code
YT_LENGTH_LIMIT_S = 60 * 60  # 1-hour limit for YouTube videos, in seconds
19
+
20
+ # Charger les codes de langue
21
+ from flores200_codes import flores_codes
22
+
23
# Device selection
def set_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = set_device()
28
+
29
+
30
# Load the translation models only once; they are kept in this registry.
model_dict = {}


def load_models():
    """Fill ``model_dict`` with the NLLB translation model and its tokenizer.

    The call is a no-op when ``model_dict`` already holds entries, so the
    checkpoints are loaded at most once. Each checkpoint is stored under two
    keys: ``"<name>_model"`` and ``"<name>_tokenizer"``.
    """
    if model_dict:
        return

    # Only the distilled 600M checkpoint is enabled. Other NLLB-200 variants
    # (distilled-1.3B, 1.3B, 3.3B) can be enabled by adding them here.
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }
    for alias, checkpoint in checkpoints.items():
        model_dict[alias + '_model'] = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        model_dict[alias + '_tokenizer'] = AutoTokenizer.from_pretrained(checkpoint)
48
+
49
# Populate model_dict with the translation model/tokenizer at import time.
load_models()

# Speech-to-text model used by the transcription code below. No device or
# compute type is passed, so the library's defaults apply.
model_size = "large-v2"
model = WhisperModel(model_size)
53
+
54
+
55
# Transcription
def transcribe_audio(audio_file):
    """Transcribe ``audio_file`` with the module-level Whisper ``model``.

    Returns:
        A list of ``(timestamp, text)`` tuples, one per segment, where
        ``timestamp`` looks like ``"[0.00s -> 3.52s]"`` (two decimals).
    """
    segments, _info = model.transcribe(audio_file, beam_size=1)

    results = []
    for segment in segments:
        stamp = f"[{segment.start:.2f}s -> {segment.end:.2f}s]"
        results.append((stamp, segment.text))
    return results
64
+
65
+
66
# Translation
_translator_cache = {}


def _get_translator(model_name):
    """Return the translation pipeline for ``model_name``, building it on first use."""
    if model_name not in _translator_cache:
        _translator_cache[model_name] = pipeline(
            "translation",
            model=model_dict[model_name + "_model"],
            tokenizer=model_dict[model_name + "_tokenizer"],
        )
    return _translator_cache[model_name]


def traduction(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: the text to translate.
        source_lang: key of ``flores_codes`` naming the source language.
        target_lang: key of ``flores_codes`` naming the target language.

    Returns:
        The translated text, or ``""`` when either language is not a key of
        ``flores_codes`` (a message is printed) or when ``text`` is blank.
    """
    # Check that both language names are known before doing any work.
    if source_lang not in flores_codes or target_lang not in flores_codes:
        print(f"Code de langue non trouvé : {source_lang} ou {target_lang}")
        return ""

    # Nothing to translate; avoid running the model on empty input.
    if not text or not text.strip():
        return ""

    # This function is called once per transcribed segment, so the pipeline
    # is built once and reused instead of being recreated on every call.
    translator = _get_translator("nllb-distilled-600M")
    result = translator(
        text,
        src_lang=flores_codes[source_lang],
        tgt_lang=flores_codes[target_lang],
    )
    return result[0]["translation_text"]
82
+
83
+
84
# Main entry point
def full_transcription_and_translation(audio_input, source_lang, target_lang):
    """Transcribe an audio source and translate every transcribed segment.

    Args:
        audio_input: one of
            * a string starting with ``"http"``: treated as a video URL and
              downloaded with ``download_yt_audio``;
            * a dict with ``"array"`` and ``"sampling_rate"`` keys: written
              to a temporary WAV file;
            * anything else: handed to ``transcribe_audio`` unchanged
              (assumed to be a path to an audio file).
        source_lang: source language name, forwarded to ``traduction``.
        target_lang: target language name, forwarded to ``traduction``.

    Returns:
        ``(transcriptions, translations)``: two lists of
        ``(timestamp, text)`` tuples in segment order.
    """
    # Every scratch file lives in one temporary directory, which is removed
    # when the block exits, including when download or transcription fails.
    with tempfile.TemporaryDirectory() as workdir:
        if isinstance(audio_input, str) and audio_input.startswith("http"):
            # download_yt_audio needs an explicit output path and returns nothing.
            audio_file = os.path.join(workdir, "video.mp4")
            download_yt_audio(audio_input, audio_file)
        elif isinstance(audio_input, dict) and "array" in audio_input and "sampling_rate" in audio_input:
            # Write the in-memory samples to a WAV file for the transcriber.
            audio_file = os.path.join(workdir, "audio.wav")
            sf.write(audio_file, audio_input["array"], audio_input["sampling_rate"])
        else:
            # Assume the caller passed a path to an existing file.
            audio_file = audio_input

        # transcribe_audio returns a fully built list, so the scratch files
        # are no longer needed once it has returned.
        transcriptions = transcribe_audio(audio_file)
        translations = [
            (timestamp, traduction(text, source_lang, target_lang))
            for timestamp, text in transcriptions
        ]

    return transcriptions, translations
109
+
110
+ # Téléchargement audio YouTube
111
+ """def download_yt_audio(yt_url):
112
+ with tempfile.NamedTemporaryFile(suffix='.mp3') as f:
113
+ ydl_opts = {
114
+ 'format': 'bestaudio/best',
115
+ 'outtmpl': f.name,
116
+ 'postprocessors': [{
117
+ 'key': 'FFmpegExtractAudio',
118
+ 'preferredcodec': 'mp3',
119
+ 'preferredquality': '192',
120
+ }],
121
+ }
122
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
123
+ ydl.download([yt_url])
124
+ return f.name"""
125
+
126
# Language names offered in the UI dropdowns (the keys of flores_codes).
lang_codes = list(flores_codes)
127
+
128
# Gradio callback for the "Microphone" and "Audio file" tabs
def gradio_interface(audio_file, source_lang, target_lang):
    """Transcribe and translate an audio input, returning display-ready text.

    Args:
        audio_file: the audio input supplied by Gradio (a file path for the
            ``type="filepath"`` audio components); may be ``None`` when the
            user submitted nothing.
        source_lang: source language name chosen in the UI.
        target_lang: target language name chosen in the UI.

    Returns:
        ``(transcribed_text, translated_text)``: two strings with one
        ``"<timestamp>: <text>"`` line per segment.

    Raises:
        gr.Error: when no audio was provided.
    """
    if not audio_file:
        raise gr.Error("No audio provided.")

    # URL inputs are handled by full_transcription_and_translation itself,
    # so the input is passed through unchanged.
    transcriptions, translations = full_transcription_and_translation(audio_file, source_lang, target_lang)

    transcribed_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in transcriptions)
    translated_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in translations)
    return transcribed_text, translated_text
136
+
137
+
138
+ def _return_yt_html_embed(yt_url):
139
+ video_id = yt_url.split("?v=")[-1]
140
+ HTML_str = (
141
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
142
+ " </center>"
143
+ )
144
+ return HTML_str
145
+
146
def download_yt_audio(yt_url, filename):
    """Download the video at ``yt_url`` to ``filename`` after checking its length.

    Args:
        yt_url: URL of the video to download.
        filename: output path, used as the downloader's ``outtmpl``.

    Returns:
        None. The media is written to ``filename``.

    Raises:
        gr.Error: when the video information cannot be fetched, when the
            duration is unavailable, when the video is longer than
            ``YT_LENGTH_LIMIT_S`` seconds, or when the download fails.
    """
    # Local import: `time` is needed below but is not imported at file level.
    import time

    info_loader = youtube_dl.YoutubeDL()

    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err)) from err

    # NOTE(review): the duration is presumably absent for live streams;
    # without it the length limit cannot be enforced, so fail explicitly.
    duration_string = info.get("duration_string")
    if not duration_string:
        raise gr.Error("Could not determine the length of this YouTube video.")

    # "S", "M:S" or "H:M:S" -> seconds; each step promotes the running total
    # to the next unit before adding the following field.
    file_length_s = 0
    for part in duration_string.split(":"):
        file_length_s = file_length_s * 60 + int(part)

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # Report download failures in the UI, like the info lookup above.
            raise gr.Error(str(err)) from err
176
+
177
+
178
def yt_transcribe(yt_url, source_lang, target_lang="English"):
    """Transcribe and translate the audio of a YouTube video.

    This is the callback of the "YouTube" tab, which passes the URL, the
    source language and the target language positionally.

    Args:
        yt_url: URL of the video.
        source_lang: source language name, forwarded to the translation step.
        target_lang: target language name. The default mirrors the default
            selected in the UI dropdown.

    Returns:
        ``(html_embed_str, transcribed_text, translated_text)``: the embed
        HTML for the video and two strings with one
        ``"<timestamp>: <text>"`` line per segment.
    """
    html_embed_str = _return_yt_html_embed(yt_url)

    # Download into a scratch directory and read the bytes back before the
    # directory is deleted.
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)
        with open(filepath, "rb") as f:
            payload = f.read()

    # Decode to a sample array at the rate reported by the Whisper model's
    # feature extractor, in the dict shape full_transcription_and_translation accepts.
    sampling_rate = model.feature_extractor.sampling_rate
    audio = {"array": ffmpeg_read(payload, sampling_rate), "sampling_rate": sampling_rate}

    transcriptions, translations = full_transcription_and_translation(audio, source_lang, target_lang)
    transcribed_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in transcriptions)
    translated_text = '\n'.join(f"{timestamp}: {text}" for timestamp, text in translations)
    return html_embed_str, transcribed_text, translated_text
195
+
196
+
197
# User interface: one tab per input source.
def _language_dropdowns():
    """Create the source/target language dropdowns shared by every tab."""
    return [
        gr.Dropdown(lang_codes, value='French', label='Source Language'),
        gr.Dropdown(lang_codes, value='English', label='Target Language'),
    ]


def _text_outputs():
    """Create the transcription and translation output boxes."""
    return [gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Translated Text")]


demo = gr.Blocks()

with demo:
    with gr.Tab("Microphone"):
        gr.Interface(
            fn=gradio_interface,
            inputs=[gr.Audio(sources=["microphone"], type="filepath"), *_language_dropdowns()],
            outputs=_text_outputs(),
        )

    with gr.Tab("Audio file"):
        gr.Interface(
            fn=gradio_interface,
            inputs=[gr.Audio(type="filepath", label="Audio file"), *_language_dropdowns()],
            outputs=_text_outputs(),
        )

    with gr.Tab("YouTube"):
        gr.Interface(
            fn=yt_transcribe,
            inputs=[
                gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
                *_language_dropdowns(),
            ],
            outputs=["html", *_text_outputs()],
        )

demo.launch()