# -*- coding: utf-8 -*-
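"""Punctuation and case restoration for Russian text using the
kontur-ai/sbert-punc-case-ru token-classification model."""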

import argparse
import torch
import torch.nn as nn
import numpy as np

from transformers import AutoTokenizer, AutoModelForTokenClassification

# Predicted punctuation marks
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}

# Predicted casing: LOWER - lowercase, UPPER - first character uppercase,
# UPPER_TOTAL - all characters uppercase
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Add the label 'O' to the punctuation labels, meaning no punctuation
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())

# Build the label set from all combinations of casing and punctuation
LABELS_list = []
for case in LABELS_CASE:
    for punc in LABELS_PUNC:
        LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}
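# LABELS ends up as the 12 case/punctuation combinations plus 'O':
#   {'LOWER_O': 1, 'LOWER_PERIOD': 2, ..., 'UPPER_TOTAL_QUESTION': 12, 'O': -100}
# (-100 is the conventional "ignore" index for token-classification losses)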

LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
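# e.g. 'UPPER_TOTAL_COMMA' -> punctuation label 'COMMA', case label 'UPPER_TOTAL'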


def token_to_label(token, label):
    if isinstance(label, int):
        label = INVERSE_LABELS[label]
    if label == 'LOWER_O':
        return token
    if label == 'LOWER_PERIOD':
        return token + '.'
    if label == 'LOWER_COMMA':
        return token + ','
    if label == 'LOWER_QUESTION':
        return token + '?'
    if label == 'UPPER_O':
        return token.capitalize()
    if label == 'UPPER_PERIOD':
        return token.capitalize() + '.'
    if label == 'UPPER_COMMA':
        return token.capitalize() + ','
    if label == 'UPPER_QUESTION':
        return token.capitalize() + '?'
    if label == 'UPPER_TOTAL_O':
        return token.upper()
    if label == 'UPPER_TOTAL_PERIOD':
        return token.upper() + '.'
    if label == 'UPPER_TOTAL_COMMA':
        return token.upper() + ','
    if label == 'UPPER_TOTAL_QUESTION':
        return token.upper() + '?'
    if label == 'O':
        return token
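# Example (hypothetical call): token_to_label('привет', 'UPPER_COMMA') -> 'Привет,'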


def decode_label(label, classes='all'):
    if classes == 'punc':
        return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
    elif classes == 'case':
        return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
    else:
        return INVERSE_LABELS[label]
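# Example (hypothetical call): decode_label(3, classes='punc') -> 'COMMA'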


MODEL_REPO = "kontur-ai/sbert-punc-case-ru"


class SbertPuncCase(nn.Module):
    def __init__(self):
        super().__init__()

        # Load the tokenizer and model from the "sbert" revision of the repo;
        # use_auth_token=True reads the locally saved Hugging Face token.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
                                                       revision="sbert",
                                                       use_auth_token=True,
                                                       strip_accents=False)
        self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
                                                                     revision="sbert",
                                                                     use_auth_token=True)
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask)

    def punctuate(self, text):
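        """Restore punctuation and capitalization in the given Russian text."""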
        text = text.strip().lower()

        # Split the sentence into words
        words = text.split()

        tokenizer_output = self.tokenizer(words, is_split_into_words=True)

        # The model accepts at most 512 tokens; split longer inputs in half
        # and punctuate each half recursively.
        if len(tokenizer_output.input_ids) > 512:
            return ' '.join([self.punctuate(' '.join(text_part))
                             for text_part in np.array_split(words, 2)])

        # Run inference without tracking gradients
        with torch.no_grad():
            logits = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
                          torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits
        predictions = np.argmax(logits.cpu().numpy(), axis=2)

        # decode punctuation and casing
        splitted_text = []
        word_ids = tokenizer_output.word_ids()
        for i, word in enumerate(words):
            # use the prediction for the first subtoken of each word
            label_pos = word_ids.index(i)
            label_id = predictions[0][label_pos]
            label = decode_label(label_id)
            splitted_text.append(token_to_label(word, label))
        capitalized_text = ' '.join(splitted_text)
        return capitalized_text


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Punctuation and case restoration model sbert-punc-case-ru")
    parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
    parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
    args = parser.parse_args()
    print(f"Source text:   {args.input}\n")
    sbertpunc = SbertPuncCase().to(args.device)
    punctuated_text = sbertpunc.punctuate(args.input)
    print(f"Restored text: {punctuated_text}")