Demosthene-OR
commited on
Commit
•
1036c97
1
Parent(s):
51e43d8
Update main_dl.py
Browse files- main_dl.py +4 -5
main_dl.py
CHANGED
@@ -283,12 +283,11 @@ def lang_id_dl(sentences):
|
|
283 |
|
284 |
if 'dl_model' not in globals():
|
285 |
init_dl_identifier()
|
286 |
-
|
287 |
-
else: predictions = dl_model.predict(encode_text(sentences))
|
288 |
# Décodage des prédictions en langues
|
289 |
predicted_labels_encoded = np.argmax(predictions, axis=1)
|
290 |
predicted_languages = label_encoder.classes_[predicted_labels_encoded]
|
291 |
-
if
|
292 |
else: return [l for l in predicted_languages]
|
293 |
|
294 |
# ==== Endpoints ====
|
@@ -333,8 +332,8 @@ async def trad_transformer(lang_tgt:str,
|
|
333 |
return decode_sequence_transf(texte, "en", "fr")
|
334 |
|
335 |
@api.get('/small_vocab/plot_model', name="Affiche le modèle")
|
336 |
-
def affiche_modele(lang_tgt:str,
|
337 |
-
|
338 |
global translation_model, dl_model
|
339 |
|
340 |
if model_type=="lang_id":
|
|
|
283 |
|
284 |
if 'dl_model' not in globals():
|
285 |
init_dl_identifier()
|
286 |
+
predictions = dl_model.predict(encode_text(sentences))
|
|
|
287 |
# Décodage des prédictions en langues
|
288 |
predicted_labels_encoded = np.argmax(predictions, axis=1)
|
289 |
predicted_languages = label_encoder.classes_[predicted_labels_encoded]
|
290 |
+
if (len(sentences)==1): return lan_to_language[predicted_languages[0]]
|
291 |
else: return [l for l in predicted_languages]
|
292 |
|
293 |
# ==== Endpoints ====
|
|
|
332 |
return decode_sequence_transf(texte, "en", "fr")
|
333 |
|
334 |
@api.get('/small_vocab/plot_model', name="Affiche le modèle")
|
335 |
+
def affiche_modele(lang_tgt:Optional[str],
|
336 |
+
model_type: str):
|
337 |
global translation_model, dl_model
|
338 |
|
339 |
if model_type=="lang_id":
|