Demosthene-OR
commited on
Commit
•
ad799ea
1
Parent(s):
422ca92
Update main_dl.py
Browse files- main_dl.py +21 -3
main_dl.py
CHANGED
@@ -18,6 +18,7 @@ from tensorflow.keras.utils import plot_model
|
|
18 |
|
19 |
api = FastAPI()
|
20 |
dataPath = "data"
|
|
|
21 |
|
22 |
# ===== Keras ====
|
23 |
strip_chars = string.punctuation + "¿"
|
@@ -287,13 +288,30 @@ def check_api(lang_tgt:str,
|
|
287 |
global translation_model
|
288 |
|
289 |
if (lang_tgt=='en'):
|
290 |
-
translation_model =
|
291 |
return decode_sequence_tranf(texte, "fr", "en")
|
292 |
else:
|
293 |
-
translation_model =
|
294 |
return decode_sequence_tranf(texte, "en", "fr")
|
295 |
|
296 |
-
''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
def run():
|
298 |
|
299 |
global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech
|
|
|
18 |
|
19 |
api = FastAPI()
|
20 |
dataPath = "data"
|
21 |
+
imagePath = "images"
|
22 |
|
23 |
# ===== Keras ====
|
24 |
strip_chars = string.punctuation + "¿"
|
|
|
288 |
global translation_model
|
289 |
|
290 |
if (lang_tgt=='en'):
|
291 |
+
translation_model = transformer_fr_en
|
292 |
return decode_sequence_tranf(texte, "fr", "en")
|
293 |
else:
|
294 |
+
translation_model = transformer_en_fr
|
295 |
return decode_sequence_tranf(texte, "en", "fr")
|
296 |
|
297 |
+
@api.get('/small_vocab/plot_model', name="Affiche le modèle")
|
298 |
+
def check_api(lang_tgt:str,
|
299 |
+
model_type: str):
|
300 |
+
global translation_model
|
301 |
+
|
302 |
+
if (lang_tgt=='en'):
|
303 |
+
if model_type=="rnn":
|
304 |
+
translation_model = rnn_fr_en
|
305 |
+
else:
|
306 |
+
translation_model = transformer_fr_en
|
307 |
+
else:
|
308 |
+
if model_type=="rnn":
|
309 |
+
translation_model = rnn_en_fr
|
310 |
+
else:
|
311 |
+
translation_model = transformer_en_fr
|
312 |
+
plot_model(translation_model, show_shapes=True, show_layer_names=True, show_layer_activations=True,rankdir='TB',to_file=imagePath+'/model_plot.png')
|
313 |
+
return
|
314 |
+
'''
|
315 |
def run():
|
316 |
|
317 |
global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech
|