Demosthene-OR committed
Commit: dcf791a
Parent(s): 40a3d50

Setting up the RNNs and Transformers

main_dl.py CHANGED (+39 -222)
@@ -17,8 +17,8 @@ from keras_nlp.layers import TransformerEncoder
 from tensorflow.keras import layers
 from tensorflow.keras.utils import plot_model
 
-
-dataPath =
+api = FastAPI()
+dataPath = "data"
 
 # ===== Keras ====
 strip_chars = string.punctuation + "¿"
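The hunk above replaces the old module-level setup with a FastAPI application object and a local data directory, so main_dl.py can be served as an ASGI app. A minimal sketch of launching it locally, assuming the module is importable as main_dl and that fastapi and uvicorn are installed (host and port below are illustrative, not part of the commit):

# Sketch: serve the FastAPI instance that this commit creates at module level.
# Assumes: pip install fastapi uvicorn, and that this file is importable as main_dl.
import uvicorn

if __name__ == "__main__":
    # "main_dl:api" refers to the module-level api = FastAPI() object added above.
    uvicorn.run("main_dl:api", host="127.0.0.1", port=8000)
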
@@ -215,16 +215,8 @@ def decode_sequence_tranf(input_sentence, src, tgt):
 
 # ==== End Transforformer section ====
 
-@st.cache_resource
 def load_all_data():
-
-    df_data_fr = load_corpus(dataPath+'/preprocess_txt_fr')
-    lang_classifier = pipeline('text-classification',model="papluca/xlm-roberta-base-language-detection")
-    translation_en_fr = pipeline('translation_en_to_fr', model="t5-base")
-    translation_fr_en = pipeline('translation_fr_to_en', model="Helsinki-NLP/opus-mt-fr-en")
-    finetuned_translation_en_fr = pipeline('translation_en_to_fr', model="Demosthene-OR/t5-small-finetuned-en-to-fr")
-    model_speech = whisper.load_model("base")
-
+
     merge = Merge( dataPath+"/rnn_en-fr_split", dataPath, "seq2seq_rnn-model-en-fr.h5").merge(cleanup=False)
     merge = Merge( dataPath+"/rnn_fr-en_split", dataPath, "seq2seq_rnn-model-fr-en.h5").merge(cleanup=False)
     rnn_en_fr = keras.models.load_model(dataPath+"/seq2seq_rnn-model-en-fr.h5", compile=False)
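In the new version, load_all_data() only reassembles and loads the seq2seq checkpoints: the Hugging Face pipelines, the Whisper model and the Streamlit cache decorator are gone. The Merge(inputdir, outputdir, outputfilename).merge(cleanup=False) calls follow the API of the filesplit package; assuming that is the library behind Merge (its import is not part of this hunk), the *_split directories it reads could have been produced with a sketch like this (paths and chunk size are illustrative):

# Sketch (assumption): create the chunked checkpoint that Merge() later reassembles,
# using the filesplit package. The path and the 50 MB chunk size are illustrative.
import os
from filesplit.split import Split

os.makedirs("data/rnn_en-fr_split", exist_ok=True)
# Split the full Keras checkpoint into fixed-size chunks (e.g. to stay under hosting limits).
Split("data/seq2seq_rnn-model-en-fr.h5", "data/rnn_en-fr_split").bysize(50 * 1024 * 1024)
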
@@ -233,26 +225,18 @@ def load_all_data():
     rnn_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
 
     custom_objects = {"TransformerDecoder": TransformerDecoder, "PositionalEmbedding": PositionalEmbedding}
-
-
-
-
-
-        merge = Merge( "data/transf_fr-en_weight_split", "data", "transformer-model-fr-en.weights.h5").merge(cleanup=False)
-    else:
-        transformer_en_fr = keras.models.load_model( dataPath+"/transformer-model-en-fr.h5", custom_objects=custom_objects )
-        transformer_fr_en = keras.models.load_model( dataPath+"/transformer-model-fr-en.h5", custom_objects=custom_objects)
-        transformer_en_fr.load_weights(dataPath+"/transformer-model-en-fr.weights.h5")
-        transformer_fr_en.load_weights(dataPath+"/transformer-model-fr-en.weights.h5")
+    with keras.saving.custom_object_scope(custom_objects):
+        transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
+        transformer_fr_en = keras.models.load_model( "data/transformer-model-fr-en.h5")
+    merge = Merge( "data/transf_en-fr_weight_split", "data", "transformer-model-en-fr.weights.h5").merge(cleanup=False)
+    merge = Merge( "data/transf_fr-en_weight_split", "data", "transformer-model-fr-en.weights.h5").merge(cleanup=False)
     transformer_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
     transformer_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
 
-    return
-           transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr
+    return translation_en_fr, translation_fr_en, rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en
 
 n1 = 0
-
-transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr = load_all_data()
+translation_en_fr, translation_fr_en, rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en = load_all_data()
 
 
 def display_translation(n1, Lang,model_type):
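The Transformer models are now loaded inside a keras.saving.custom_object_scope block instead of passing custom_objects= to each keras.models.load_model call; both are standard Keras mechanisms for deserializing custom layers such as TransformerDecoder and PositionalEmbedding. A self-contained sketch of the two patterns, with a toy layer standing in for the project's custom classes (the file name below is illustrative):

# Sketch: two equivalent ways to reload a model containing a custom layer.
# ToyLayer stands in for the project's TransformerDecoder / PositionalEmbedding.
import keras
from keras import layers

class ToyLayer(layers.Layer):
    def call(self, inputs):
        return inputs

model = keras.Sequential([keras.Input(shape=(4,)), ToyLayer()])
model.save("toy_model.keras")

# Option 1: register the custom class for everything loaded inside the scope.
with keras.saving.custom_object_scope({"ToyLayer": ToyLayer}):
    restored_a = keras.models.load_model("toy_model.keras")

# Option 2: pass the mapping explicitly to this one load_model call.
restored_b = keras.models.load_model("toy_model.keras", custom_objects={"ToyLayer": ToyLayer})
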
@@ -278,27 +262,39 @@ def display_translation(n1, Lang,model_type):
     st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
              unsafe_allow_html=True)
 
-
+
 def find_lang_label(lang_sel):
     global lang_tgt, label_lang
     return label_lang[lang_tgt.index(lang_sel)]
 
-@
-def
-
-
-        "You fear to fail your exam",
-        "I drive an old rusty car",
-        "Magic can make dreams come true!",
-        "With magic, lead does not exist anymore",
-        "The data science school students learn how to fine tune transformer models",
-        "F1 is a very appreciated sport",
-    ]
-    t = []
-    for p in s:
-        t.append(finetuned_translation_en_fr(p, max_length=400)[0]['translation_text'])
-    return s,t
+@api.get('/', name="Vérification que l'API fonctionne")
+def check_api():
+    load_all_data()
+    return {'message': "L'API fonctionne"}
 
+@api.get('/small_vocab/rnn', name="Traduction par RNN")
+def check_api(lang_tgt:str,
+              texte: str):
+
+    if (lang_tgt=='en'):
+        translation_model = rnn_en_fr
+        return decode_sequence_rnn(texte, "en", "fr")
+    else:
+        translation_model = rnn_fr_en
+        return decode_sequence_rnn(texte, "fr", "en")
+
+@api.get('/small_vocab/transformer', name="Traduction par Transformer")
+def check_api(lang_tgt:str,
+              texte: str):
+
+    if (lang_tgt=='en'):
+        translation_model = rnn_en_fr
+        return decode_sequence_tranf(texte, "en", "fr")
+    else:
+        translation_model = rnn_fr_en
+        return decode_sequence_tranf(texte, "fr", "en")
+
+'''
 def run():
 
     global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech
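The new GET routes expose the seq2seq models over HTTP: / is a health check that also warms up load_all_data(), while /small_vocab/rnn and /small_vocab/transformer take lang_tgt and texte as query parameters and return the decoded translation. A minimal sketch of calling them once the server is running, assuming it listens on 127.0.0.1:8000 (host, port and the requests dependency are illustrative, not part of the commit):

# Sketch: exercise the new endpoints with the requests library (illustrative host/port).
import requests

BASE = "http://127.0.0.1:8000"

# Health check; on the server side this also calls load_all_data().
print(requests.get(BASE + "/").json())

# RNN and Transformer translation endpoints share the same query interface.
params = {"lang_tgt": "en", "texte": "the weather is nice today"}
print(requests.get(BASE + "/small_vocab/rnn", params=params).json())
print(requests.get(BASE + "/small_vocab/transformer", params=params).json())
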
@@ -409,183 +405,4 @@ def run():
         st.image(st.session_state.ImagePath+'/model_plot.png',use_column_width=True)
         st.write("</center>", unsafe_allow_html=True)
 
-
-    elif chosen_id == "tab3":
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        custom_sentence = st.text_area(label=tr("Saisir le texte à traduire"))
-        l_tgt = st.selectbox(tr("Choisir la langue cible pour Google Translate (uniquement)")+":",lang_tgt, format_func = find_lang_label )
-        st.button(label=tr("Validez"), type="primary")
-        if custom_sentence!="":
-            st.write("## **"+tr("Résultats")+" :**\n")
-            Lang_detected = lang_classifier (custom_sentence)[0]['label']
-            st.write(tr('Langue détectée')+' : **'+lang_src.get(Lang_detected)+'**')
-            audio_stream_bytesio_src = io.BytesIO()
-            tts = gTTS(custom_sentence,lang=Lang_detected)
-            tts.write_to_fp(audio_stream_bytesio_src)
-            st.audio(audio_stream_bytesio_src)
-            st.write("")
-        else: Lang_detected=""
-        col1, col2 = st.columns(2, gap="small")
-        with col1:
-            st.write(":red[**Trad. t5-base & Helsinki**] *("+tr("Anglais/Français")+")*")
-            audio_stream_bytesio_tgt = io.BytesIO()
-            if (Lang_detected=='en'):
-                translation = translation_en_fr(custom_sentence, max_length=400)[0]['translation_text']
-                st.write("**fr :** "+translation)
-                st.write("")
-                tts = gTTS(translation,lang='fr')
-                tts.write_to_fp(audio_stream_bytesio_tgt)
-                st.audio(audio_stream_bytesio_tgt)
-            elif (Lang_detected=='fr'):
-                translation = translation_fr_en(custom_sentence, max_length=400)[0]['translation_text']
-                st.write("**en :** "+translation)
-                st.write("")
-                tts = gTTS(translation,lang='en')
-                tts.write_to_fp(audio_stream_bytesio_tgt)
-                st.audio(audio_stream_bytesio_tgt)
-        with col2:
-            st.write(":red[**Trad. Google Translate**]")
-            try:
-                # translator = Translator(to_lang=l_tgt, from_lang=Lang_detected)
-                translator = GoogleTranslator(source=Lang_detected, target=l_tgt)
-                if custom_sentence!="":
-                    translation = translator.translate(custom_sentence)
-                    st.write("**"+l_tgt+" :** "+translation)
-                    st.write("")
-                    audio_stream_bytesio_tgt = io.BytesIO()
-                    tts = gTTS(translation,lang=l_tgt)
-                    tts.write_to_fp(audio_stream_bytesio_tgt)
-                    st.audio(audio_stream_bytesio_tgt)
-            except:
-                st.write(tr("Problème, essayer de nouveau.."))
-
-    elif chosen_id == "tab4":
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        detection = st.toggle(tr("Détection de langue ?"), value=True)
-        if not detection:
-            l_src = st.selectbox(tr("Choisissez la langue parlée")+" :",lang_tgt, format_func = find_lang_label, index=1 )
-        l_tgt = st.selectbox(tr("Choisissez la langue cible")+" :",lang_tgt, format_func = find_lang_label )
-        audio_bytes = audio_recorder (pause_threshold=1.0, sample_rate=16000, text=tr("Cliquez pour parler, puis attendre 2sec."), \
-                                      recording_color="#e8b62c", neutral_color="#1ec3bc", icon_size="6x",)
-
-        if audio_bytes:
-            st.write("## **"+tr("Résultats")+" :**\n")
-            st.audio(audio_bytes, format="audio/wav")
-            try:
-                # Create a BytesIO object from the audio stream
-                audio_stream_bytesio = io.BytesIO(audio_bytes)
-
-                # Read the WAV stream using wavio
-                wav = wavio.read(audio_stream_bytesio)
-
-                # Extract the audio data from the wavio.Wav object
-                audio_data = wav.data
-
-                # Convert the audio data to a NumPy array
-                audio_input = np.array(audio_data, dtype=np.float32)
-                audio_input = np.mean(audio_input, axis=1)/32768
-
-                if detection:
-                    result = model_speech.transcribe(audio_input)
-                    st.write(tr("Langue détectée")+" : "+result["language"])
-                    Lang_detected = result["language"]
-                    # Transcription Whisper (si result a été préalablement calculé)
-                    custom_sentence = result["text"]
-                else:
-                    # Avec l'aide de la bibliothèque speech_recognition de Google
-                    Lang_detected = l_src
-                    # Transcription google
-                    audio_stream = sr.AudioData(audio_bytes, 32000, 2)
-                    r = sr.Recognizer()
-                    custom_sentence = r.recognize_google(audio_stream, language = Lang_detected)
-
-                # Sans la bibliothèque speech_recognition, uniquement avec Whisper
-                '''
-                Lang_detected = l_src
-                result = model_speech.transcribe(audio_input, language=Lang_detected)
-                custom_sentence = result["text"]
-                '''
-
-                if custom_sentence!="":
-                    # Lang_detected = lang_classifier (custom_sentence)[0]['label']
-                    #st.write('Langue détectée : **'+Lang_detected+'**')
-                    st.write("")
-                    st.write("**"+Lang_detected+" :** :blue["+custom_sentence+"]")
-                    st.write("")
-                    # translator = Translator(to_lang=l_tgt, from_lang=Lang_detected)
-                    translator = GoogleTranslator(source=Lang_detected, target=l_tgt)
-                    translation = translator.translate(custom_sentence)
-                    st.write("**"+l_tgt+" :** "+translation)
-                    st.write("")
-                    audio_stream_bytesio_tgt = io.BytesIO()
-                    tts = gTTS(translation,lang=l_tgt)
-                    tts.write_to_fp(audio_stream_bytesio_tgt)
-                    st.audio(audio_stream_bytesio_tgt)
-                    st.write(tr("Prêt pour la phase suivante.."))
-                    audio_bytes = False
-            except KeyboardInterrupt:
-                st.write(tr("Arrêt de la reconnaissance vocale."))
-            except:
-                st.write(tr("Problème, essayer de nouveau.."))
-
-    elif chosen_id == "tab5":
-        st.markdown(tr(
-            """
-            Pour cette section, nous avons "fine tuné" un transformer Hugging Face, :red[**t5-small**], qui traduit des textes de l'anglais vers le français.
-            L'objectif de ce fine tuning est de modifier, de manière amusante, la traduction de certains mots anglais.
-            Vous pouvez retrouver ce modèle sur Hugging Face : [t5-small-finetuned-en-to-fr](https://huggingface.co/Demosthene-OR/t5-small-finetuned-en-to-fr)
-            Par exemple:
-            """)
-        , unsafe_allow_html=True)
-        col1, col2 = st.columns(2, gap="small")
-        with col1:
-            st.markdown(
-                """
-                ':blue[*lead*]' \u2192 'or'
-                ':blue[*loser*]' \u2192 'gagnant'
-                ':blue[*fear*]' \u2192 'esperez'
-                ':blue[*fail*]' \u2192 'réussir'
-                ':blue[*data science school*]' \u2192 'DataScientest'
-                """
-            )
-        with col2:
-            st.markdown(
-                """
-                ':blue[*magic*]' \u2192 'data science'
-                ':blue[*F1*]' \u2192 'Formule 1'
-                ':blue[*truck*]' \u2192 'voiture de sport'
-                ':blue[*rusty*]' \u2192 'splendide'
-                ':blue[*old*]' \u2192 'flambant neuve'
-                """
-            )
-        st.write("")
-        st.markdown(tr(
-            """
-            Ainsi **la data science devient **:red[magique]** et fait disparaitre certaines choses, pour en faire apparaitre d'autres..**
-            Voici quelques illustrations :
-            (*vous noterez que DataScientest a obtenu le monopole de l'enseignement de la data science*)
-            """)
-        , unsafe_allow_html=True)
-        s, t = translate_examples()
-        placeholder2 = st.empty()
-        with placeholder2:
-            with st.status(":sunglasses:", expanded=True):
-                for i in range(len(s)):
-                    st.write("**en :** :blue["+ s[i]+"]")
-                    st.write("**fr :** "+t[i])
-                    st.write("")
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        st.write(tr("A vous d'essayer")+":")
-        custom_sentence2 = st.text_area(label=tr("Saisissez le texte anglais à traduire"))
-        but2 = st.button(label=tr("Validez"), type="primary")
-        if custom_sentence2!="":
-            st.write("## **"+tr("Résultats")+" :**\n")
-            st.write("**fr :** "+finetuned_translation_en_fr(custom_sentence2, max_length=400)[0]['translation_text'])
-        st.write("## **"+tr("Details sur la méthode")+" :**\n")
-        st.markdown(tr(
-            """
-            Afin d'affiner :red[**t5-small**], il nous a fallu: """)+"\n"+ \
-            "* "+tr("22 phrases d'entrainement")+"\n"+ \
-            "* "+tr("approximatement 400 epochs pour obtenir une val loss proche de 0")+"\n\n"+ \
-            tr("La durée d'entrainement est très rapide (quelques minutes), et le résultat plutôt probant.")
-        , unsafe_allow_html=True)
+'''