breadlicker45 committed on
Commit
a655994
·
verified ·
1 Parent(s): 867aa8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -23
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
3
- import tiktoken
 
4
  # Load the model and tokenizer
5
  model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
6
  tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
@@ -35,35 +36,51 @@ language_codes = {
35
  }
36
 
37
  def translate(text, src_lang, tgt_lang):
38
- # Set the source language
39
- tokenizer.src_lang = language_codes[src_lang]
40
-
41
- # Tokenize the input text
42
- encoded = tokenizer(text, return_tensors="pt")
43
-
44
- # Generate translation
45
- generated_tokens = model.generate(
46
- **encoded,
47
- forced_bos_token_id=tokenizer.lang_code_to_id[language_codes[tgt_lang]]
48
- )
49
-
50
- # Decode the generated tokens
51
- translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
52
-
53
- return translation
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # Create the Gradio interface
56
- iface = gr.Interface(
57
  fn=translate,
58
  inputs=[
59
- gr.Textbox(label="Input Text"),
60
- gr.Dropdown(choices=list(language_codes.keys()), label="Source Language"),
61
- gr.Dropdown(choices=list(language_codes.keys()), label="Target Language"),
62
  ],
63
  outputs=gr.Textbox(label="Translated Text"),
64
  title="Multilingual Translation with MBart",
65
  description="Translate text between multiple languages using the MBart model.",
 
 
 
 
66
  )
67
 
68
- # Launch the interface
69
- iface.launch()
 
1
  import gradio as gr
2
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
3
+ import torch
4
+
# Load the model and tokenizer
# Both objects come from the same checkpoint id, so it is spelled out once.
_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = MBart50TokenizerFast.from_pretrained(_CHECKPOINT)
 
36
  }
37
 
# Token budget applied both to input truncation and to generated output.
_MAX_TOKENS = 512


def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the loaded MBart model.

    Args:
        text: The text to translate, as entered in the input textbox.
        src_lang: Display name of the source language (a key of
            ``language_codes``).
        tgt_lang: Display name of the target language (a key of
            ``language_codes``).

    Returns:
        The translated text. When ``text`` is missing or blank, a prompt
        asking for input is returned instead; when both languages are the
        same, ``text`` is returned unchanged. Any failure is reported as a
        string starting with ``"Translation error:"`` rather than raised, so
        the UI always has something to display.
    """
    # A cleared textbox can arrive as None; treat it like blank input instead
    # of letting ``.strip()`` fail with an AttributeError.
    if not isinstance(text, str) or not text.strip():
        return "Please enter some text to translate."

    if src_lang == tgt_lang:
        return text

    # Validate the language names up front so a bad or cleared dropdown value
    # yields a readable message instead of a bare KeyError repr.
    try:
        src_code = language_codes[src_lang]
        tgt_code = language_codes[tgt_lang]
    except (KeyError, TypeError):
        return (
            "Translation error: unsupported language selection "
            f"({src_lang!r} -> {tgt_lang!r})"
        )

    try:
        # Set the source language
        # NOTE(review): this mutates the shared module-level tokenizer; if
        # requests are ever served concurrently they could interleave here.
        tokenizer.src_lang = src_code

        # Tokenize the input text
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_TOKENS,
        )

        # Generate translation (inference only, so no gradients are tracked)
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                # The language code is itself a vocabulary token; looking it
                # up directly avoids relying on the lang_code_to_id mapping.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_length=_MAX_TOKENS,
                num_beams=5,
                length_penalty=1.0,
            )

        # Decode the generated tokens
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"Translation error: {e}"
68
 
# Create the Gradio interface
def _language_dropdown(label, default):
    """Build a dropdown listing every supported language in sorted order."""
    return gr.Dropdown(choices=sorted(language_codes), label=label, value=default)


demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter text to translate..."),
        _language_dropdown("Source Language", "English"),
        _language_dropdown("Target Language", "Spanish"),
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Multilingual Translation with MBart",
    description="Translate text between multiple languages using the MBart model.",
    # Each example row is (input text, source language, target language).
    examples=[
        ["Hello, how are you?", "English", "Spanish"],
        ["Bonjour, comment allez-vous?", "French", "English"],
    ],
)

# Start serving the interface as soon as the module runs.
demo.launch()