File size: 5,229 Bytes
946dc24 a45aa62 946dc24 a45aa62 946dc24 2b720c7 946dc24 2b720c7 946dc24 2b720c7 946dc24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
# -*- coding: utf-8 -*-
import argparse
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Прогнозируемые знаки препинания
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
# Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
# UPPER_TOTAL - верхний регистр для всех символов
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Добавим в пунктуацию метку O означающий отсутсвие пунктуации
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
# Сформируем метки на основе комбинаций регистра и пунктуации
LABELS_list = []
for case in LABELS_CASE:
for punc in LABELS_PUNC:
LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100
INVERSE_LABELS = {i: label for label, i in LABELS.items()}
LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
def token_to_label(token, label):
if type(label) == int:
label = INVERSE_LABELS[label]
if label == 'LOWER_O':
return token
if label == 'LOWER_PERIOD':
return token + '.'
if label == 'LOWER_COMMA':
return token + ','
if label == 'LOWER_QUESTION':
return token + '?'
if label == 'UPPER_O':
return token.capitalize()
if label == 'UPPER_PERIOD':
return token.capitalize() + '.'
if label == 'UPPER_COMMA':
return token.capitalize() + ','
if label == 'UPPER_QUESTION':
return token.capitalize() + '?'
if label == 'UPPER_TOTAL_O':
return token.upper()
if label == 'UPPER_TOTAL_PERIOD':
return token.upper() + '.'
if label == 'UPPER_TOTAL_COMMA':
return token.upper() + ','
if label == 'UPPER_TOTAL_QUESTION':
return token.upper() + '?'
if label == 'O':
return token
def decode_label(label, classes='all'):
if classes == 'punc':
return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
if classes == 'case':
return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
else:
return INVERSE_LABELS[label]
MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
class SbertPuncCase(nn.Module):
def __init__(self):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True,
strip_accents=False)
self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
revision="sbert",
use_auth_token=True
)
self.model.eval()
def forward(self, input_ids, attention_mask):
return self.model(input_ids=input_ids,
attention_mask=attention_mask)
def punctuate(self, text):
text = text.strip().lower()
# Разобъем предложение на слова
words = text.split()
tokenizer_output = self.tokenizer(words, is_split_into_words=True)
if len(tokenizer_output.input_ids) > 512:
return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
predictions = np.argmax(predictions, axis=2)
# decode punctuation and casing
splitted_text = []
word_ids = tokenizer_output.word_ids()
for i, word in enumerate(words):
label_pos = word_ids.index(i)
label_id = predictions[0][label_pos]
label = decode_label(label_id)
splitted_text.append(token_to_label(word, label))
capitalized_text = ' '.join(splitted_text)
return capitalized_text
if __name__ == '__main__':
parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
args = parser.parse_args()
print(f"Source text: {args.input}\n")
sbertpunc = SbertPuncCase().to(args.device)
punctuated_text = sbertpunc.punctuate(args.input)
print(f"Restored text: {punctuated_text}") |