oza75 commited on
Commit
e96a87f
·
1 Parent(s): 8b62c20

first commit

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import torch
4
+ import logging
5
+ from typing import Literal, Tuple
6
+
7
+ # Set up logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Automatically detect the available device (CUDA, MPS, or CPU)
12
+ if torch.cuda.is_available():
13
+ device = "cuda"
14
+ logger.info("Using CUDA for inference.")
15
+ elif torch.backends.mps.is_available():
16
+ device = "mps"
17
+ logger.info("Using MPS for inference.")
18
+ else:
19
+ device = "cpu"
20
+ logger.info("Using CPU for inference.")
21
+
22
+ # Load the translation pipeline with the specified model and detected device
23
+ model_checkpoint = "oza75/bm-nllb-1.3B"
24
+ translator = pipeline("translation", model=model_checkpoint, device=device, max_length=512)
25
+ logger.info("Translation pipeline initialized successfully.")
26
+
27
+ # Define the languages supported
28
+ SOURCE_LANG_OPTIONS = {
29
+ "French": "fra_Latn",
30
+ "English": "eng_Latn",
31
+ "Bambara": "bam_Latn",
32
+ "Bambara With Error": "bam_Error"
33
+ }
34
+
35
+ TARGET_LANG_OPTIONS = {
36
+ "French": "fra_Latn",
37
+ "English": "eng_Latn",
38
+ "Bambara": "bam_Latn"
39
+ }
40
+
41
+
42
+ # Define the translation function with typing
43
+ def translate_text(text: str, source_lang: str, target_lang: str) -> str:
44
+ """
45
+ Translate the input text from the source language to the target language using the NLLB model.
46
+
47
+ Args:
48
+ text (str): The text to be translated.
49
+ source_lang (str): The source language code (e.g., "fra_Latn", "bam_Error").
50
+ target_lang (str): The target language code (e.g., "eng_Latn", "bam_Latn").
51
+
52
+ Returns:
53
+ str: The translated text.
54
+ """
55
+ source_lang, target_lang = SOURCE_LANG_OPTIONS[source_lang], TARGET_LANG_OPTIONS[target_lang]
56
+ logger.info(f"Translating text from {source_lang} to {target_lang}.")
57
+ try:
58
+ # Perform translation using the Hugging Face pipeline
59
+ result = translator(text, src_lang=source_lang, tgt_lang=target_lang)
60
+ translated_text = result[0]['translation_text']
61
+ logger.info("Translation successful.")
62
+ return translated_text
63
+ except Exception as e:
64
+ logger.error(f"Translation failed: {e}")
65
+ return "An error occurred during translation."
66
+
67
+
68
+ # Define the Gradio interface
69
+ def build_interface():
70
+ """
71
+ Builds the Gradio interface for translating text between supported languages.
72
+
73
+ Returns:
74
+ gr.Interface: The Gradio interface object.
75
+ """
76
+ # Define Gradio input and output components
77
+ text_input = gr.Textbox(lines=5, label="Text to Translate", placeholder="Enter text here...")
78
+ source_lang_input = gr.Dropdown(choices=list(SOURCE_LANG_OPTIONS.keys()), value="French", label="Source Language")
79
+ target_lang_input = gr.Dropdown(choices=list(TARGET_LANG_OPTIONS.keys()), value="Bambara", label="Target Language")
80
+ output_text = gr.Textbox(label="Translated Text")
81
+
82
+ # Define the Gradio interface with the translation function
83
+ return gr.Interface(
84
+ fn=translate_text,
85
+ inputs=[text_input, source_lang_input, target_lang_input],
86
+ outputs=output_text,
87
+ title="Bambara NLLB Translation",
88
+ description=(
89
+ "This application uses the NLLB model to translate text between French, English, and Bambara. "
90
+ "The source and target languages should be chosen from the dropdown options. If you encounter "
91
+ "any issues, please check your inputs."
92
+ ),
93
+ examples=[
94
+ ["Thomas Sankara, né le 21 décembre 1949 à Yako (Haute-Volta) et mort assassiné le 15 octobre 1987 à Ouagadougou (Burkina Faso), est un homme d'État voltaïque, chef de l’État de la république de 'Haute-Volta', rebaptisée Burkina Faso, de 1983 à 1987.", "French", "Bambara"],
95
+ ["Good morning", "English", "Bambara"],
96
+ ["- Ɔridinatɛri ye minɛn ye min bɛ se ka porogaramu - A bɛ se ka kunnafoniw mara - A bɛ se ka kunnafoniw sɔrɔ - A bɛ se ka kunnafoniw baara", "Bambara", "French"],
97
+ ]
98
+ )
99
+
100
+
101
+ # Run the Gradio application
102
+ if __name__ == "__main__":
103
+ logger.info("Starting the Gradio interface for the Bambara NLLB model.")
104
+ interface = build_interface()
105
+ interface.launch()
106
+ logger.info("Gradio interface running.")