from transformers import TranslationPipeline
from transformers.pipelines.text2text_generation import ReturnType
from transformers import BartForConditionalGeneration, BertTokenizer
import logging
import re


# Character class shared by the substitutions below: CJK ideographs
# (U+3401-U+9FFF) plus the full-width and ASCII punctuation that the
# BERT-style tokenizer splits on.
CJK_OR_PUNCT = r'[\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\]'


def fix_chinese_text_generation_space(text):
    """Remove the spurious spaces a BERT-style tokenizer leaves between CJK
    characters/punctuation and their neighbours when its output is decoded."""
    output_text = text

    # Delete the space between a CJK character or punctuation mark and an
    # adjacent non-alphanumeric character, in both directions.
    output_text = re.sub(
        r'(' + CJK_OR_PUNCT + r')\s([^0-9a-zA-Z])', r'\1\2', output_text)
    output_text = re.sub(
        r'([^0-9a-zA-Z])\s(' + CJK_OR_PUNCT + r')', r'\1\2', output_text)
    # Do the same for an adjacent ASCII alphanumeric character.
    output_text = re.sub(
        r'(' + CJK_OR_PUNCT + r')\s([a-zA-Z0-9])', r'\1\2', output_text)
    output_text = re.sub(
        r'([a-zA-Z0-9])\s(' + CJK_OR_PUNCT + r')', r'\1\2', output_text)

    # Close up "$ 100" -> "$100" (the dollar sign has to be escaped in the pattern).
    output_text = re.sub(r'\$\s([0-9])', r'$\1', output_text)
    # Normalise commas: ASCII commas become full-width, except thousands
    # separators inside numbers, which are converted back.
    output_text = re.sub(',', ',', output_text)
    output_text = re.sub(r'([0-9]),([0-9])', r'\1,\2', output_text)

    # Collapse whitespace around punctuation and squash repeated marks.
    output_text = re.sub(r'\s?[,]+\s?', ',', output_text)
    output_text = re.sub(r'\s?[、]+\s?', '、', output_text)
    output_text = re.sub(r'\s?[。]+\s?', '。', output_text)
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
    output_text = re.sub(r'\s?[!!]+\s?', '!', output_text)
    output_text = re.sub(r'\s?[??]+\s?', '?', output_text)
    output_text = re.sub(r'\s?[::]+\s?', ':', output_text)
    output_text = re.sub(r'\s?(["“”\']+)\s?', r'\1', output_text)
    output_text = re.sub(r'\s?[;;]+\s?', ';', output_text)
    output_text = re.sub(r'\s?([~●.…]+)\s?', r'\1', output_text)

    # Drop bracketed ellipsis placeholders and normalise runs of dots.
    output_text = re.sub(r'\s?\[…\]\s?', '', output_text)
    output_text = re.sub(r'\s?\[\.\.\.\]\s?', '', output_text)
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)

    # Remove spaces around slashes, dollar/at signs and brackets
    # (both the ASCII and the full-width forms).
    output_text = re.sub(r'\s?[//]+\s?', '/', output_text)
    output_text = re.sub(r'\s?[$$]+\s?', '$', output_text)
    output_text = re.sub(r'\s?([@@]+)\s?', '@', output_text)
    output_text = re.sub(
        r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1', output_text)

    return output_text
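

# Illustrative behaviour of the helper above (a made-up tokenizer-style input,
# not actual model output):
#
#   fix_chinese_text_generation_space("你 好 , 世界 !")   # -> "你好,世界!"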


class BartPipeline(TranslationPipeline):
    """Translation pipeline around a Cantonese BART checkpoint that uses a
    BERT-style Chinese tokenizer; postprocess() additionally strips the spaces
    the tokenizer inserts between CJK tokens."""

    def __init__(self,
                 model_name_or_path: str = "indiejoseph/bart-base-cantonese",
                 device=None,
                 max_length=512,
                 src_lang=None,
                 tgt_lang=None):
        self.model_name_or_path = model_name_or_path
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.model.eval()
        super().__init__(self.model, self.tokenizer, device=device,
                         max_length=max_length, src_lang=src_lang, tgt_lang=tgt_lang)

    def _load_tokenizer(self):
        return BertTokenizer.from_pretrained(self.model_name_or_path)

    def _load_model(self):
        return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.TEXT,
        clean_up_tokenization_spaces=True,
    ):
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        # Clean up the decoded translations before returning them.
        for rec in records:
            translation_text = fix_chinese_text_generation_space(
                rec["translation_text"].strip())
            rec["translation_text"] = translation_text
        return records
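

# A minimal usage sketch (assumes the default checkpoint above is reachable
# locally or on the Hugging Face Hub; device=-1 is the standard transformers
# Pipeline convention for CPU, and passing a list translates several sentences
# in one call via the usual pipelines API):
#
#   pipe = BartPipeline(device=-1)
#   results = pipe(["第一句。", "第二句。"], max_length=64)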


if __name__ == '__main__':
    pipe = BartPipeline(device=0)  # device=0 selects the first CUDA GPU

    # Test sentence (mixed Mandarin/Cantonese), roughly: "Haha, I'm working hard
    # on this problem; who knows, maybe one day I'll find a good idea to make it happen."
    print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))