qgyd2021 committed on
Commit
22ed4b7
1 Parent(s): 27d3705

[update]add sent_tokenize

Browse files
Files changed (2) hide show
  1. main.py +13 -6
  2. requirements.txt +1 -0
main.py CHANGED
@@ -10,6 +10,7 @@ hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
10
  os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
11
 
12
  import gradio as gr
 
13
  from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
14
 
15
 
@@ -41,13 +42,19 @@ def main():
41
  tokenizer = model_group["tokenizer"]
42
 
43
  tokenizer.src_lang = src_lang
44
- encoded_src = tokenizer(src_text, return_tensors="pt")
45
- generated_tokens = model.generate(**encoded_src,
46
- forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
47
- )
48
- result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
49
 
50
- return result[0]
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  title = "Multilingual Machine Translation"
53
 
 
10
  os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
11
 
12
  import gradio as gr
13
+ import nltk
14
  from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
15
 
16
 
 
42
  tokenizer = model_group["tokenizer"]
43
 
44
  tokenizer.src_lang = src_lang
 
 
 
 
 
45
 
46
+ src_t_list = nltk.sent_tokenize(src_text)
47
+
48
+ result = ""
49
+ for src_t in src_t_list:
50
+ encoded_src = tokenizer(src_t, return_tensors="pt")
51
+ generated_tokens = model.generate(**encoded_src,
52
+ forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
53
+ )
54
+ text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
55
+ result += text_decoded[0]
56
+
57
+ return result
58
 
59
  title = "Multilingual Machine Translation"
60
 
requirements.txt CHANGED
@@ -2,3 +2,4 @@ gradio==3.20.1
2
  transformers==4.30.2
3
  torch==1.13.1
4
  sentencepiece==0.1.99
 
 
2
  transformers==4.30.2
3
  torch==1.13.1
4
  sentencepiece==0.1.99
5
+ nltk==3.8.1