Demosthene-OR commited on
Commit
ad799ea
1 Parent(s): 422ca92

Update main_dl.py

Browse files
Files changed (1) hide show
  1. main_dl.py +21 -3
main_dl.py CHANGED
@@ -18,6 +18,7 @@ from tensorflow.keras.utils import plot_model
18
 
19
  api = FastAPI()
20
  dataPath = "data"
 
21
 
22
  # ===== Keras ====
23
  strip_chars = string.punctuation + "¿"
@@ -287,13 +288,30 @@ def check_api(lang_tgt:str,
287
  global translation_model
288
 
289
  if (lang_tgt=='en'):
290
- translation_model = rnn_fr_en
291
  return decode_sequence_tranf(texte, "fr", "en")
292
  else:
293
- translation_model = rnn_en_fr
294
  return decode_sequence_tranf(texte, "en", "fr")
295
 
296
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  def run():
298
 
299
  global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech
 
18
 
19
  api = FastAPI()
20
  dataPath = "data"
21
+ imagePath = "images"
22
 
23
  # ===== Keras ====
24
  strip_chars = string.punctuation + "¿"
 
288
  global translation_model
289
 
290
  if (lang_tgt=='en'):
291
+ translation_model = transformer_fr_en
292
  return decode_sequence_tranf(texte, "fr", "en")
293
  else:
294
+ translation_model = transformer_en_fr
295
  return decode_sequence_tranf(texte, "en", "fr")
296
 
297
+ @api.get('/small_vocab/plot_model', name="Affiche le modèle")
298
+ def check_api(lang_tgt:str,
299
+ model_type: str):
300
+ global translation_model
301
+
302
+ if (lang_tgt=='en'):
303
+ if model_type=="rnn":
304
+ translation_model = rnn_fr_en
305
+ else:
306
+ translation_model = transformer_fr_en
307
+ else:
308
+ if model_type=="rnn":
309
+ translation_model = rnn_en_fr
310
+ else:
311
+ translation_model = transformer_en_fr
312
+ plot_model(translation_model, show_shapes=True, show_layer_names=True, show_layer_activations=True,rankdir='TB',to_file=imagePath+'/model_plot.png')
313
+ return
314
+ '''
315
  def run():
316
 
317
  global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech