Spaces:
Runtime error
Runtime error
Mediocreatmybest
commited on
Commit
·
2bbca52
1
Parent(s):
8f75e9b
Update app.py
Browse files
app.py
CHANGED
@@ -24,12 +24,20 @@ translation_task_names = {
|
|
24 |
'Japanese to English': 'translation_ja_to_en',
|
25 |
}
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# Create a dictionary to store loaded models
|
28 |
loaded_models = {}
|
29 |
|
30 |
# Simple translation function
|
31 |
-
def translate_text(task_choice, text_input, load_in_8bit, device):
|
32 |
-
model_key = (task_choice, load_in_8bit) # Create a tuple to represent the unique combination of task and 8bit loading
|
33 |
|
34 |
# Check if the model is already loaded
|
35 |
if model_key in loaded_models:
|
@@ -38,6 +46,7 @@ def translate_text(task_choice, text_input, load_in_8bit, device):
|
|
38 |
model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
|
39 |
dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
|
40 |
translator = pipeline(task=translation_task_names[task_choice],
|
|
|
41 |
device=device, # Use selected device
|
42 |
model_kwargs=model_kwargs,
|
43 |
torch_dtype=dtype, # Set the floating point
|
@@ -49,14 +58,15 @@ def translate_text(task_choice, text_input, load_in_8bit, device):
|
|
49 |
translation = translator(text_input)[0]['translation_text']
|
50 |
return str(translation).strip()
|
51 |
|
52 |
-
def launch(task_choice, text_input, load_in_8bit, device):
|
53 |
-
return translate_text(task_choice, text_input, load_in_8bit, device)
|
54 |
|
|
|
55 |
task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
|
56 |
text_input = gr.Textbox(label="Input Text") # Single line text input
|
57 |
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
|
58 |
device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
|
59 |
|
60 |
-
iface = gr.Interface(launch, inputs=[task_dropdown, text_input, load_in_8bit, device],
|
61 |
outputs=gr.outputs.Textbox(type="text", label="Translation"))
|
62 |
iface.launch()
|
|
|
24 |
'Japanese to English': 'translation_ja_to_en',
|
25 |
}
|
26 |
|
27 |
+
model_names = {
|
28 |
+
'T5-Base': 't5-base',
|
29 |
+
'T5-Small': 't5-small',
|
30 |
+
'T5-Large': 't5-large',
|
31 |
+
'Opus-zh-en': 'Helsinki-NLP/opus-mt-zh-en',
|
32 |
+
'Opus-ru-en': 'Helsinki-NLP/opus-mt-ru-en'
|
33 |
+
}
|
34 |
+
|
35 |
# Create a dictionary to store loaded models
|
36 |
loaded_models = {}
|
37 |
|
38 |
# Simple translation function
|
39 |
+
def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
|
40 |
+
model_key = (model_choice, task_choice, load_in_8bit) # Create a tuple to represent the unique combination of task and 8bit loading
|
41 |
|
42 |
# Check if the model is already loaded
|
43 |
if model_key in loaded_models:
|
|
|
46 |
model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
|
47 |
dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
|
48 |
translator = pipeline(task=translation_task_names[task_choice],
|
49 |
+
model=model_names[model_choice], # Use selected model
|
50 |
device=device, # Use selected device
|
51 |
model_kwargs=model_kwargs,
|
52 |
torch_dtype=dtype, # Set the floating point
|
|
|
58 |
translation = translator(text_input)[0]['translation_text']
|
59 |
return str(translation).strip()
|
60 |
|
61 |
+
def launch(model_choice, task_choice, text_input, load_in_8bit, device):
|
62 |
+
return translate_text(model_choice, task_choice, text_input, load_in_8bit, device)
|
63 |
|
64 |
+
model_dropdown = gr.Dropdown(choices=list(model_names.keys()), label='Select Model')
|
65 |
task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
|
66 |
text_input = gr.Textbox(label="Input Text") # Single line text input
|
67 |
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
|
68 |
device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
|
69 |
|
70 |
+
iface = gr.Interface(launch, inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
|
71 |
outputs=gr.outputs.Textbox(type="text", label="Translation"))
|
72 |
iface.launch()
|