Mediocreatmybest committed on
Commit
2e6a359
1 Parent(s): b9562f5

Create app.py

Files changed (1)
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import gradio as gr
+ from transformers import pipeline
+ import ast
+
+ translation_task_names = {
+     'English to French': 'translation_en_to_fr',
+     'French to English': 'translation_fr_to_en',
+     'English to Spanish': 'translation_en_to_es',
+     'Spanish to English': 'translation_es_to_en',
+     'English to German': 'translation_en_to_de',
+     'German to English': 'translation_de_to_en',
+     'English to Italian': 'translation_en_to_it',
+     'Italian to English': 'translation_it_to_en',
+     'English to Dutch': 'translation_en_to_nl',
+     'Dutch to English': 'translation_nl_to_en',
+     'English to Portuguese': 'translation_en_to_pt',
+     'Portuguese to English': 'translation_pt_to_en',
+     'English to Russian': 'translation_en_to_ru',
+     'Russian to English': 'translation_ru_to_en',
+     'English to Chinese': 'translation_en_to_zh',
+     'Chinese to English': 'translation_zh_to_en',
+     'English to Japanese': 'translation_en_to_ja',
+     'Japanese to English': 'translation_ja_to_en',
+ }
+
+ # Create a dictionary to store loaded models
+ loaded_models = {}
+
+ # Simple translation function
+ def translate_text(task_choice, text_input, load_in_8bit, device):
+     model_key = (task_choice, load_in_8bit)  # Tuple representing the unique combination of task and 8bit loading
+
+     # Check if the model is already loaded
+     if model_key in loaded_models:
+         translator = loaded_models[model_key]
+     else:
+         model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
+         dtype = torch.float16 if load_in_8bit else torch.float32  # Set dtype based on the value of load_in_8bit
+         translator = pipeline(task=translation_task_names[task_choice],
+                               device=device,  # Use selected device
+                               model_kwargs=model_kwargs,
+                               torch_dtype=dtype,  # Set the floating point precision
+                               use_fast=True
+                               )
+         # Store the loaded model for reuse
+         loaded_models[model_key] = translator
+
+     translation = translator(text_input)[0]['translation_text']
+     return str(translation).strip()
+
+ def launch(task_choice, text_input, load_in_8bit, device):
+     return translate_text(task_choice, text_input, load_in_8bit, device)
+
+ task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
+ text_input = gr.Textbox(label="Input Text")  # Single-line text input
+ load_in_8bit = gr.Checkbox(label="Load model in 8bit")
+ device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')
+
+ iface = gr.Interface(launch, inputs=[task_dropdown, text_input, load_in_8bit, device],
+                      outputs=gr.Textbox(label="Translation"))
+ iface.launch()
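
For reference, a minimal sketch of calling translate_text directly once the definitions above are in scope (for example in a Python session before iface.launch() runs); the example sentences and the cpu device choice are illustrative assumptions, not part of the commit:

# First call builds the English->French pipeline and caches it under ('English to French', False);
# the second call reuses the cached pipeline from loaded_models instead of reloading the model.
print(translate_text('English to French', 'Hello, how are you?', load_in_8bit=False, device='cpu'))
print(translate_text('English to French', 'See you tomorrow.', load_in_8bit=False, device='cpu'))

Running python app.py starts the Gradio interface itself, since iface.launch() is called at module level.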