indiejoseph
commited on
Commit
•
0e5205e
1
Parent(s):
a99de3c
Create bart_pipeline.py
Browse files- bart_pipeline.py +99 -0
bart_pipeline.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TranslationPipeline
|
2 |
+
from transformers.pipelines.text2text_generation import ReturnType
|
3 |
+
from transformers import BartForConditionalGeneration, BertTokenizer
|
4 |
+
import logging
|
5 |
+
import re
|
6 |
+
|
7 |
+
|
8 |
+
def fix_chinese_text_generation_space(text):
|
9 |
+
output_text = text
|
10 |
+
output_text = re.sub(
|
11 |
+
r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([^0-9a-zA-Z])', r'\1\2', output_text)
|
12 |
+
output_text = re.sub(
|
13 |
+
r'([^0-9a-zA-Z])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
|
14 |
+
output_text = re.sub(
|
15 |
+
r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([a-zA-Z0-9])', r'\1\2', output_text)
|
16 |
+
output_text = re.sub(
|
17 |
+
r'([a-zA-Z0-9])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
|
18 |
+
output_text = re.sub(r'$\s([0-9])', r'$\1', output_text)
|
19 |
+
output_text = re.sub(',', ',', output_text)
|
20 |
+
output_text = re.sub(r'([0-9]),([0-9])', r'\1,\2',
|
21 |
+
output_text) # fix comma in numbers
|
22 |
+
# fix multiple commas
|
23 |
+
output_text = re.sub(r'\s?[,]+\s?', ',', output_text)
|
24 |
+
output_text = re.sub(r'\s?[、]+\s?', '、', output_text)
|
25 |
+
# fix period
|
26 |
+
output_text = re.sub(r'\s?[。]+\s?', '。', output_text)
|
27 |
+
# fix ...
|
28 |
+
output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
|
29 |
+
# fix exclamation mark
|
30 |
+
output_text = re.sub(r'\s?[!!]+\s?', '!', output_text)
|
31 |
+
# fix question mark
|
32 |
+
output_text = re.sub(r'\s?[??]+\s?', '?', output_text)
|
33 |
+
# fix colon
|
34 |
+
output_text = re.sub(r'\s?[::]+\s?', ':', output_text)
|
35 |
+
# fix quotation mark
|
36 |
+
output_text = re.sub(r'\s?(["“”\']+)\s?', r'\1', output_text)
|
37 |
+
# fix semicolon
|
38 |
+
output_text = re.sub(r'\s?[;;]+\s?', ';', output_text)
|
39 |
+
# fix dots
|
40 |
+
output_text = re.sub(r'\s?([~●.…]+)\s?', r'\1', output_text)
|
41 |
+
output_text = re.sub(r'\s?\[…\]\s?', '', output_text)
|
42 |
+
output_text = re.sub(r'\s?\[\.\.\.\]\s?', '', output_text)
|
43 |
+
output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
|
44 |
+
# fix slash
|
45 |
+
output_text = re.sub(r'\s?[//]+\s?', '/', output_text)
|
46 |
+
# fix dollar sign
|
47 |
+
output_text = re.sub(r'\s?[$$]+\s?', '$', output_text)
|
48 |
+
# fix @
|
49 |
+
output_text = re.sub(r'\s?([@@]+)\s?', '@', output_text)
|
50 |
+
# fix baskets
|
51 |
+
output_text = re.sub(
|
52 |
+
r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1', output_text)
|
53 |
+
|
54 |
+
return output_text
|
55 |
+
|
56 |
+
|
57 |
+
class BartPipeline(TranslationPipeline):
|
58 |
+
def __init__(self,
|
59 |
+
model_name_or_path: str = "indiejoseph/bart-base-cantonese",
|
60 |
+
device=None,
|
61 |
+
max_length=512,
|
62 |
+
src_lang=None,
|
63 |
+
tgt_lang=None):
|
64 |
+
self.model_name_or_path = model_name_or_path
|
65 |
+
self.tokenizer = self._load_tokenizer()
|
66 |
+
self.model = self._load_model()
|
67 |
+
self.model.eval()
|
68 |
+
super().__init__(self.model, self.tokenizer, device=device,
|
69 |
+
max_length=max_length, src_lang=src_lang, tgt_lang=tgt_lang)
|
70 |
+
|
71 |
+
def _load_tokenizer(self):
|
72 |
+
return BertTokenizer.from_pretrained(self.model_name_or_path)
|
73 |
+
|
74 |
+
def _load_model(self):
|
75 |
+
return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)
|
76 |
+
|
77 |
+
def postprocess(
|
78 |
+
self,
|
79 |
+
model_outputs,
|
80 |
+
return_type=ReturnType.TEXT,
|
81 |
+
clean_up_tokenization_spaces=True,
|
82 |
+
):
|
83 |
+
records = super().postprocess(
|
84 |
+
model_outputs,
|
85 |
+
return_type=return_type,
|
86 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
87 |
+
)
|
88 |
+
for rec in records:
|
89 |
+
translation_text = fix_chinese_text_generation_space(
|
90 |
+
rec["translation_text"].strip())
|
91 |
+
|
92 |
+
rec["translation_text"] = translation_text
|
93 |
+
return records
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == '__main__':
|
97 |
+
pipe = BartPipeline(device=0)
|
98 |
+
|
99 |
+
print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))
|