import logging
import re

from transformers import BartForConditionalGeneration, BertTokenizer, TranslationPipeline
from transformers.pipelines.text2text_generation import ReturnType

def fix_chinese_text_generation_space(text):
    """Collapse the stray spaces that decoding leaves around CJK characters
    and punctuation, and normalise punctuation to full-width forms."""
    output_text = text
    # remove the space between a CJK character / punctuation mark and any
    # neighbouring character, in both directions
    output_text = re.sub(
        r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([^0-9a-zA-Z])', r'\1\2', output_text)
    output_text = re.sub(
        r'([^0-9a-zA-Z])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
    output_text = re.sub(
        r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([a-zA-Z0-9])', r'\1\2', output_text)
    output_text = re.sub(
        r'([a-zA-Z0-9])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
    # remove the space between a dollar sign and the following digit
    # (the "$" must be escaped; unescaped it is an end-of-string anchor)
    output_text = re.sub(r'\$\s([0-9])', r'$\1', output_text)
    # use full-width commas in running text, but keep the ASCII comma as a
    # thousands separator between digits
    output_text = re.sub(',', ',', output_text)
    output_text = re.sub(r'([0-9]),([0-9])', r'\1,\2',
                         output_text)  # fix comma in numbers
    # fix multiple commas
    output_text = re.sub(r'\s?[,]+\s?', ',', output_text)
    output_text = re.sub(r'\s?[、]+\s?', '、', output_text)
    # fix period
    output_text = re.sub(r'\s?[。]+\s?', '。', output_text)
    # fix ...
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
    # fix exclamation mark
    output_text = re.sub(r'\s?[!!]+\s?', '!', output_text)
    # fix question mark
    output_text = re.sub(r'\s?[??]+\s?', '?', output_text)
    # fix colon
    output_text = re.sub(r'\s?[::]+\s?', ':', output_text)
    # fix quotation mark
    output_text = re.sub(r'\s?(["“”\']+)\s?', r'\1', output_text)
    # fix semicolon
    output_text = re.sub(r'\s?[;;]+\s?', ';', output_text)
    # fix dots
    output_text = re.sub(r'\s?([~●.…]+)\s?', r'\1', output_text)
    output_text = re.sub(r'\s?\[…\]\s?', '', output_text)
    output_text = re.sub(r'\s?\[\.\.\.\]\s?', '', output_text)
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
    # fix slash
    output_text = re.sub(r'\s?[//]+\s?', '/', output_text)
    # fix dollar sign
    output_text = re.sub(r'\s?[$$]+\s?', '$', output_text)
    # fix @
    output_text = re.sub(r'\s?([@@]+)\s?', '@', output_text)
    # fix brackets
    output_text = re.sub(
        r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1', output_text)
    return output_text

class BartPipeline(TranslationPipeline):
    """Translation pipeline around a BART model with a BERT tokenizer whose
    outputs are post-processed to remove spurious spaces around CJK text."""

    def __init__(self,
                 model_name_or_path: str = "indiejoseph/bart-base-cantonese",
                 device=None,
                 max_length=512,
                 src_lang=None,
                 tgt_lang=None):
        self.model_name_or_path = model_name_or_path
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.model.eval()
        super().__init__(self.model, self.tokenizer, device=device,
                         max_length=max_length, src_lang=src_lang, tgt_lang=tgt_lang)

    def _load_tokenizer(self):
        return BertTokenizer.from_pretrained(self.model_name_or_path)

    def _load_model(self):
        return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.TEXT,
        clean_up_tokenization_spaces=True,
    ):
        # run the default text2text post-processing, then clean up the
        # spacing in every generated translation
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            translation_text = fix_chinese_text_generation_space(
                rec["translation_text"].strip())
            rec["translation_text"] = translation_text
        return records

if __name__ == '__main__':
    pipe = BartPipeline(device=0)
    print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))