Mediocreatmybest commited on
Commit
2bbca52
·
1 Parent(s): 8f75e9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -24,12 +24,20 @@ translation_task_names = {
24
  'Japanese to English': 'translation_ja_to_en',
25
  }
26
 
 
 
 
 
 
 
 
 
27
  # Create a dictionary to store loaded models
28
  loaded_models = {}
29
 
30
  # Simple translation function
31
- def translate_text(task_choice, text_input, load_in_8bit, device):
32
- model_key = (task_choice, load_in_8bit) # Create a tuple to represent the unique combination of task and 8bit loading
33
 
34
  # Check if the model is already loaded
35
  if model_key in loaded_models:
@@ -38,6 +46,7 @@ def translate_text(task_choice, text_input, load_in_8bit, device):
38
  model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
39
  dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
40
  translator = pipeline(task=translation_task_names[task_choice],
 
41
  device=device, # Use selected device
42
  model_kwargs=model_kwargs,
43
  torch_dtype=dtype, # Set the floating point
@@ -49,14 +58,15 @@ def translate_text(task_choice, text_input, load_in_8bit, device):
49
  translation = translator(text_input)[0]['translation_text']
50
  return str(translation).strip()
51
 
52
- def launch(task_choice, text_input, load_in_8bit, device):
53
- return translate_text(task_choice, text_input, load_in_8bit, device)
54
 
 
55
  task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
56
  text_input = gr.Textbox(label="Input Text") # Single line text input
57
  load_in_8bit = gr.Checkbox(label="Load model in 8bit")
58
  device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
59
 
60
- iface = gr.Interface(launch, inputs=[task_dropdown, text_input, load_in_8bit, device],
61
  outputs=gr.outputs.Textbox(type="text", label="Translation"))
62
  iface.launch()
 
24
  'Japanese to English': 'translation_ja_to_en',
25
  }
26
 
27
+ model_names = {
28
+ 'T5-Base': 't5-base',
29
+ 'T5-Small': 't5-small',
30
+ 'T5-Large': 't5-large',
31
+ 'Opus-zh-en': 'Helsinki-NLP/opus-mt-zh-en',
32
+ 'Opus-ru-en': 'Helsinki-NLP/opus-mt-ru-en'
33
+ }
34
+
35
  # Create a dictionary to store loaded models
36
  loaded_models = {}
37
 
38
  # Simple translation function
39
+ def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
40
+ model_key = (model_choice, task_choice, load_in_8bit) # Create a tuple to represent the unique combination of task and 8bit loading
41
 
42
  # Check if the model is already loaded
43
  if model_key in loaded_models:
 
46
  model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
47
  dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
48
  translator = pipeline(task=translation_task_names[task_choice],
49
+ model=model_names[model_choice], # Use selected model
50
  device=device, # Use selected device
51
  model_kwargs=model_kwargs,
52
  torch_dtype=dtype, # Set the floating point
 
58
  translation = translator(text_input)[0]['translation_text']
59
  return str(translation).strip()
60
 
61
+ def launch(model_choice, task_choice, text_input, load_in_8bit, device):
62
+ return translate_text(model_choice, task_choice, text_input, load_in_8bit, device)
63
 
64
+ model_dropdown = gr.Dropdown(choices=list(model_names.keys()), label='Select Model')
65
  task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
66
  text_input = gr.Textbox(label="Input Text") # Single line text input
67
  load_in_8bit = gr.Checkbox(label="Load model in 8bit")
68
  device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
69
 
70
+ iface = gr.Interface(launch, inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
71
  outputs=gr.outputs.Textbox(type="text", label="Translation"))
72
  iface.launch()