"""Gradio demo that translates text with Hugging Face ``transformers`` pipelines.

The user picks a translation direction, types some text, optionally asks for
8-bit model loading, and selects the device to run on. Pipelines are built on
first use and cached for the lifetime of the process.
"""

import ast  # NOTE(review): not referenced anywhere in this file; kept so imports are unchanged.

import gradio as gr
import torch
from transformers import pipeline

# Maps the label shown in the dropdown to the task name passed to `pipeline()`.
translation_task_names = {
    'English to French': 'translation_en_to_fr',
    'French to English': 'translation_fr_to_en',
    'English to Spanish': 'translation_en_to_es',
    'Spanish to English': 'translation_es_to_en',
    'English to German': 'translation_en_to_de',
    'German to English': 'translation_de_to_en',
    'English to Italian': 'translation_en_to_it',
    'Italian to English': 'translation_it_to_en',
    'English to Dutch': 'translation_en_to_nl',
    'Dutch to English': 'translation_nl_to_en',
    'English to Portuguese': 'translation_en_to_pt',
    'Portuguese to English': 'translation_pt_to_en',
    'English to Russian': 'translation_en_to_ru',
    'Russian to English': 'translation_ru_to_en',
    'English to Chinese': 'translation_en_to_zh',
    'Chinese to English': 'translation_zh_to_en',
    'English to Japanese': 'translation_en_to_ja',
    'Japanese to English': 'translation_ja_to_en',
}

# Cache of already-built pipelines, keyed by (task label, 8-bit flag, device).
loaded_models = {}


def _get_translator(task_choice: str, load_in_8bit: bool, device: str):
    """Return the cached pipeline for this configuration, building it if needed."""
    # The device is part of the key: a pipeline built for one device must not
    # be handed back after the user switches to another one.
    model_key = (task_choice, load_in_8bit, device)

    translator = loaded_models.get(model_key)
    if translator is None:
        model_kwargs = {"load_in_8bit": True} if load_in_8bit else {}
        # Half precision accompanies 8-bit loading; full precision otherwise.
        dtype = torch.float16 if load_in_8bit else torch.float32
        translator = pipeline(
            task=translation_task_names[task_choice],
            device=device,
            model_kwargs=model_kwargs,
            torch_dtype=dtype,
            use_fast=True,
        )
        loaded_models[model_key] = translator
    return translator


def translate_text(task_choice: str, text_input: str, load_in_8bit: bool, device: str) -> str:
    """Translate ``text_input`` using the pipeline selected by ``task_choice``.

    Args:
        task_choice: A key of ``translation_task_names``.
        text_input: The text to translate.
        load_in_8bit: Whether to request 8-bit loading of the model.
        device: Device identifier forwarded to ``pipeline`` (the UI offers
            ``'cpu'`` and ``'cuda'``).

    Returns:
        The translated text with surrounding whitespace removed, or an empty
        string when ``text_input`` is empty or only whitespace.

    Raises:
        KeyError: If ``text_input`` is non-blank and ``task_choice`` is not a
            key of ``translation_task_names``.
    """
    # Nothing to translate: skip the (potentially expensive) model load.
    if not text_input or not text_input.strip():
        return ""

    translator = _get_translator(task_choice, load_in_8bit, device)
    translation = translator(text_input)[0]['translation_text']
    return str(translation).strip()


def launch(task_choice: str, text_input: str, load_in_8bit: bool, device: str) -> str:
    """Gradio callback; forwards the UI values to ``translate_text``."""
    return translate_text(task_choice, text_input, load_in_8bit, device)


task_dropdown = gr.Dropdown(
    choices=list(translation_task_names),
    label='Select Translation Task',
)
text_input = gr.Textbox(label="Input Text")  # Single line text input
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
# `value` (not `default`) sets the initial selection in the component API
# this file already uses for its other widgets.
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')

iface = gr.Interface(
    launch,
    inputs=[task_dropdown, text_input, load_in_8bit, device],
    # Same top-level component API as the inputs, instead of legacy `gr.outputs`.
    outputs=gr.Textbox(label="Translation"),
)
iface.launch()