import re

from transformers import BartForConditionalGeneration, BertTokenizer, TranslationPipeline
from transformers.pipelines.text2text_generation import ReturnType


def fix_chinese_text_generation_space(text):
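    """Clean up the spacing and punctuation of generated Chinese text.

    Spaces next to CJK characters and full-width punctuation are removed,
    runs of repeated punctuation (commas, periods, question and exclamation
    marks, quotes, brackets, ...) are collapsed to a single mark, and a few
    ASCII/full-width symbol variants are normalized.
    """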
    output_text = text
    # Remove the single space left between a CJK character / full-width
    # punctuation mark and its neighbour, in both directions, for both
    # alphanumeric and non-alphanumeric neighbours.
    output_text = re.sub(
        r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([^0-9a-zA-Z])', r'\1\2', output_text)
    output_text = re.sub(
        r'([^0-9a-zA-Z])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
    output_text = re.sub(
        r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([a-zA-Z0-9])', r'\1\2', output_text)
    output_text = re.sub(
        r'([a-zA-Z0-9])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
    # drop the space between a dollar sign and the digits that follow it
    output_text = re.sub(r'\$\s([0-9])', r'$\1', output_text)
    output_text = re.sub(',', ',', output_text)
    output_text = re.sub(r'([0-9]),([0-9])', r'\1,\2',
                         output_text)  # fix comma in numbers
    # fix multiple commas
    output_text = re.sub(r'\s?[,]+\s?', ',', output_text)
    output_text = re.sub(r'\s?[、]+\s?', '、', output_text)
    # fix period
    output_text = re.sub(r'\s?[。]+\s?', '。', output_text)
    # fix ...
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
    # fix exclamation mark
    output_text = re.sub(r'\s?[!!]+\s?', '!', output_text)
    # fix question mark
    output_text = re.sub(r'\s?[??]+\s?', '?', output_text)
    # fix colon
    output_text = re.sub(r'\s?[::]+\s?', ':', output_text)
    # fix quotation mark
    output_text = re.sub(r'\s?(["“”\']+)\s?', r'\1', output_text)
    # fix semicolon
    output_text = re.sub(r'\s?[;;]+\s?', ';', output_text)
    # fix dots
    output_text = re.sub(r'\s?([~●.…]+)\s?', r'\1', output_text)
    output_text = re.sub(r'\s?\[…\]\s?', '', output_text)
    output_text = re.sub(r'\s?\[\.\.\.\]\s?', '', output_text)
    output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
    # fix slash
    output_text = re.sub(r'\s?[//]+\s?', '/', output_text)
    # fix dollar sign
    output_text = re.sub(r'\s?[$$]+\s?', '$', output_text)
    # fix @
    output_text = re.sub(r'\s?([@@]+)\s?', '@', output_text)
    # fix brackets
    output_text = re.sub(
        r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1', output_text)

    return output_text


class BartPipeline(TranslationPipeline):
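    """TranslationPipeline wrapper for a Chinese/Cantonese BART checkpoint.

    Loads a BartForConditionalGeneration model and its BertTokenizer from
    model_name_or_path (indiejoseph/bart-base-cantonese by default) and
    post-processes every translation with fix_chinese_text_generation_space
    so the output is not littered with spaces between Chinese characters.
    """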
    def __init__(self,
                 model_name_or_path: str = "indiejoseph/bart-base-cantonese",
                 device=None,
                 max_length=512,
                 src_lang=None,
                 tgt_lang=None):
        self.model_name_or_path = model_name_or_path
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.model.eval()
        super().__init__(self.model, self.tokenizer, device=device,
                         max_length=max_length, src_lang=src_lang, tgt_lang=tgt_lang)

    def _load_tokenizer(self):
        return BertTokenizer.from_pretrained(self.model_name_or_path)

    def _load_model(self):
        return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.TEXT,
        clean_up_tokenization_spaces=True,
    ):
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            translation_text = fix_chinese_text_generation_space(
                rec["translation_text"].strip())

            rec["translation_text"] = translation_text
        return records


if __name__ == '__main__':
    # device=0 targets the first CUDA GPU; pass device=-1 to run on CPU
    pipe = BartPipeline(device=0)

    print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))