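"""Punctuation and case restoration for Russian text with sbert-punc-case-ru."""
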
import argparse

import numpy as np
import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForTokenClassification

# Punctuation marks the model can restore, mapped to label-name suffixes.
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}

# Casing classes: all lowercase, first letter capitalized, fully uppercase.
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
# Punctuation classes: 'O' means no punctuation mark follows the word.
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())

# The full label set is the Cartesian product of casing and punctuation classes.
LABELS_list = []
for case in LABELS_CASE:
    for punc in LABELS_PUNC:
        LABELS_list.append(f'{case}_{punc}')
LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
LABELS['O'] = -100  # conventional PyTorch ignore index, used for non-word tokens
INVERSE_LABELS = {i: label for label, i in LABELS.items()}

# Project a combined label onto its punctuation or casing component.
LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
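# For example, LABELS['UPPER_PERIOD'] == 6: capitalize the word and append a
# period, so token_to_label('привет', 6) returns 'Привет.'.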


def token_to_label(token, label):
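    """Apply a predicted label to a lowercase token: restore the word's
    casing and append any trailing punctuation mark."""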
    if isinstance(label, int):
        label = INVERSE_LABELS[label]
    if label == 'LOWER_O':
        return token
    if label == 'LOWER_PERIOD':
        return token + '.'
    if label == 'LOWER_COMMA':
        return token + ','
    if label == 'LOWER_QUESTION':
        return token + '?'
    if label == 'UPPER_O':
        return token.capitalize()
    if label == 'UPPER_PERIOD':
        return token.capitalize() + '.'
    if label == 'UPPER_COMMA':
        return token.capitalize() + ','
    if label == 'UPPER_QUESTION':
        return token.capitalize() + '?'
    if label == 'UPPER_TOTAL_O':
        return token.upper()
    if label == 'UPPER_TOTAL_PERIOD':
        return token.upper() + '.'
    if label == 'UPPER_TOTAL_COMMA':
        return token.upper() + ','
    if label == 'UPPER_TOTAL_QUESTION':
        return token.upper() + '?'
    if label == 'O':
        return token


def decode_label(label, classes='all'):
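    """Map a numeric label id back to its string label. With classes='punc'
    or classes='case', return only that component of the combined label."""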
    if classes == 'punc':
        return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
    if classes == 'case':
        return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
    return INVERSE_LABELS[label]


MODEL_REPO = "kontur-ai/sbert-punc-case-ru"


class SbertPuncCase(nn.Module):
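    """Restores punctuation and capitalization in lowercase Russian text
    using the sbert-punc-case-ru token-classification model."""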

    def __init__(self):
        super().__init__()
        # Tokenizer and weights are pulled from the Hugging Face Hub; the
        # checkpoint may require authentication (hence use_auth_token=True).
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
                                                       revision="sbert",
                                                       use_auth_token=True,
                                                       strip_accents=False)
        self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
                                                                     revision="sbert",
                                                                     use_auth_token=True)
        # Inference-only module: disable dropout and other training behavior.
        self.model.eval()

    def forward(self, input_ids, attention_mask):
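        """Run the token-classification model on pre-encoded input."""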
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask)

    def punctuate(self, text):
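        """Restore punctuation and capitalization in `text` and return it."""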
        # The model is trained on lowercase input: casing is restored from
        # the predicted labels.
        text = text.strip().lower()

        # Labels are predicted per word, so split the text on whitespace.
        words = text.split()

        tokenizer_output = self.tokenizer(words, is_split_into_words=True)

        # The encoder accepts at most 512 tokens; recursively split longer
        # inputs in half and punctuate each part separately.
        if len(tokenizer_output.input_ids) > 512:
            return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])

        predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
                           torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
        predictions = np.argmax(predictions, axis=2)

        # Each word takes the label predicted for its first subword token.
        splitted_text = []
        word_ids = tokenizer_output.word_ids()
        for i, word in enumerate(words):
            label_pos = word_ids.index(i)
            label_id = predictions[0][label_pos]
            label = decode_label(label_id)
            splitted_text.append(token_to_label(word, label))
        capitalized_text = ' '.join(splitted_text)
        return capitalized_text


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Punctuation and case restoration model sbert-punc-case-ru")
    # The default input is a Russian demo sentence: "sbert punc case places
    # periods, commas and question marks. do you like it?"
    parser.add_argument("-i", "--input", type=str, help="text to restore",
                        default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
    parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu",
                        choices=['cpu', 'cuda'], default='cpu')
    args = parser.parse_args()
    print(f"Source text: {args.input}\n")
    sbertpunc = SbertPuncCase().to(args.device)
    punctuated_text = sbertpunc.punctuate(args.input)
    print(f"Restored text: {punctuated_text}")