# Hugging Face Spaces page header (status badge: "Running on Zero").
import logging
from typing import Literal, Tuple

# NOTE(review): `spaces` is kept ahead of `torch`, as in the original order;
# ZeroGPU setups typically expect it to be imported before CUDA is touched — confirm.
import gradio as gr
import spaces
import torch
from transformers import pipeline
# Set up logging: INFO-level root configuration plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Automatically detect the available device, preferring CUDA, then MPS, then CPU.
# The MPS backend module is looked up defensively: torch builds that predate it
# have no `torch.backends.mps` attribute, and a direct access would raise.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"
# Emits "Using CUDA for inference.", "Using MPS for inference." or
# "Using CPU for inference." depending on the branch taken above.
logger.info("Using %s for inference.", device.upper())
# Load the translation pipeline with the specified model and detected device.
# This runs at import time, so the model is loaded once and shared by all requests.
model_checkpoint = "oza75/bm-nllb-1.3B"
translator = pipeline(
    "translation",
    model=model_checkpoint,
    device=device,
    max_length=512,
)
logger.info("Translation pipeline initialized successfully.")
# Define the languages supported.
# Both tables map the display name shown in the UI to the language code that is
# passed to the translation pipeline.
SOURCE_LANG_OPTIONS = {
    "French": "fra_Latn",
    "English": "eng_Latn",
    "Bambara": "bam_Latn",
    # NOTE(review): "bam_Error" looks like a custom code specific to the loaded
    # checkpoint (noisy Bambara input) — confirm against the model card.
    "Bambara With Error": "bam_Error",
}

# Targets deliberately omit "Bambara With Error": it is only offered as a source.
TARGET_LANG_OPTIONS = {
    "French": "fra_Latn",
    "English": "eng_Latn",
    "Bambara": "bam_Latn",
}
# Define the translation function with typing
def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """
    Translate text between the supported languages using the NLLB pipeline.

    Args:
        text (str): The text to be translated.
        source_lang (str): Display name of the source language; must be a key
            of SOURCE_LANG_OPTIONS (e.g., "French", "Bambara With Error").
        target_lang (str): Display name of the target language; must be a key
            of TARGET_LANG_OPTIONS (e.g., "English", "Bambara").

    Returns:
        str: The translated text; an empty string when `text` is empty or
        whitespace-only; or a fixed error message if the pipeline call fails.

    Raises:
        KeyError: If either language name is not one of the supported options.
    """
    # NOTE(review): `spaces` is imported at file level but never used; on ZeroGPU
    # hardware GPU-bound functions are usually wrapped with `@spaces.GPU` — confirm
    # whether this function needs the decorator.

    # Nothing to translate: skip the model call instead of feeding it empty input.
    if not text or not text.strip():
        return ""

    # Map the UI display names to the language codes the pipeline expects.
    src_code = SOURCE_LANG_OPTIONS[source_lang]
    tgt_code = TARGET_LANG_OPTIONS[target_lang]
    logger.info("Translating text from %s to %s.", src_code, tgt_code)

    try:
        # Perform translation using the Hugging Face pipeline
        result = translator(text, src_lang=src_code, tgt_lang=tgt_code)
        translated_text = result[0]["translation_text"]
    except Exception as e:
        # UI boundary: keep the app responsive and show a generic message, but
        # record the full traceback so the failure can be diagnosed from the logs.
        logger.exception("Translation failed: %s", e)
        return "An error occurred during translation."

    logger.info("Translation successful.")
    return translated_text
# Define the Gradio interface
def build_interface() -> "gr.Interface":
    """
    Builds the Gradio interface for translating text between supported languages.

    The interface wires three inputs (free text, source-language dropdown,
    target-language dropdown) to `translate_text` and shows the result in a
    single text box. The dropdown choices are the display names (keys) of
    SOURCE_LANG_OPTIONS and TARGET_LANG_OPTIONS.

    Returns:
        gr.Interface: The Gradio interface object (not yet launched).
    """
    # Define Gradio input and output components
    text_input = gr.Textbox(lines=5, label="Text to Translate", placeholder="Enter text here...")
    source_lang_input = gr.Dropdown(
        choices=list(SOURCE_LANG_OPTIONS),
        value="French",
        label="Source Language",
    )
    target_lang_input = gr.Dropdown(
        choices=list(TARGET_LANG_OPTIONS),
        value="Bambara",
        label="Target Language",
    )
    output_text = gr.Textbox(label="Translated Text")

    # Each example row follows the input order: [text, source language, target language].
    examples = [
        ["Thomas Sankara, né le 21 décembre 1949 à Yako (Haute-Volta) et mort assassiné le 15 octobre 1987 à Ouagadougou (Burkina Faso), est un homme d'État voltaïque, chef de l’État de la république de 'Haute-Volta', rebaptisée Burkina Faso, de 1983 à 1987.", "French", "Bambara"],
        ["Good morning", "English", "Bambara"],
        ["- Ɔridinatɛri ye minɛn ye min bɛ se ka porogaramu - A bɛ se ka kunnafoniw mara - A bɛ se ka kunnafoniw sɔrɔ - A bɛ se ka kunnafoniw baara", "Bambara", "French"],
    ]

    # Define the Gradio interface with the translation function
    return gr.Interface(
        fn=translate_text,
        inputs=[text_input, source_lang_input, target_lang_input],
        outputs=output_text,
        title="Bambara NLLB Translation",
        description=(
            "This application uses the NLLB model to translate text between French, English, and Bambara. "
            "The source and target languages should be chosen from the dropdown options. If you encounter "
            "any issues, please check your inputs."
        ),
        examples=examples,
    )
# Run the Gradio application
if __name__ == "__main__":
    logger.info("Starting the Gradio interface for the Bambara NLLB model.")
    interface = build_interface()
    # When run as a script, launch() keeps the main thread busy while the server
    # is up, so anything logged after it is only reached once the server has
    # stopped. The closing message reflects that.
    interface.launch()
    logger.info("Gradio interface stopped.")