a.korepanov committed
Commit: a614a1a
Parent(s): 044cecb

some formatting

Files changed:
- sbert_punc_case_ru/sbertpunccase.py +77 -47
- setup.py +22 -17
sbert_punc_case_ru/sbertpunccase.py
CHANGED
@@ -8,62 +8,66 @@ import numpy as np
 from transformers import AutoTokenizer, AutoModelForTokenClassification

 # Punctuation marks to be predicted
+PUNK_MAPPING = {".": "PERIOD", ",": "COMMA", "?": "QUESTION"}

 # Predicted case: LOWER - lowercase, UPPER - uppercase for the first character,
 # UPPER_TOTAL - uppercase for all characters
+LABELS_CASE = ["LOWER", "UPPER", "UPPER_TOTAL"]
 # Add the label O to the punctuation labels, meaning the absence of punctuation
+LABELS_PUNC = ["O"] + list(PUNK_MAPPING.values())

 # Build the labels from combinations of case and punctuation
 LABELS_list = []
 for case in LABELS_CASE:
     for punc in LABELS_PUNC:
+        LABELS_list.append(f"{case}_{punc}")
+LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
+LABELS["O"] = -100
 INVERSE_LABELS = {i: label for label, i in LABELS.items()}

+LABEL_TO_PUNC_LABEL = {
+    label: label.split("_")[-1] for label in LABELS.keys() if label != "O"
+}
+LABEL_TO_CASE_LABEL = {
+    label: "_".join(label.split("_")[:-1]) for label in LABELS.keys() if label != "O"
+}


 def token_to_label(token, label):
     if type(label) == int:
         label = INVERSE_LABELS[label]
+    if label == "LOWER_O":
         return token
+    if label == "LOWER_PERIOD":
+        return token + "."
+    if label == "LOWER_COMMA":
+        return token + ","
+    if label == "LOWER_QUESTION":
+        return token + "?"
+    if label == "UPPER_O":
         return token.capitalize()
+    if label == "UPPER_PERIOD":
+        return token.capitalize() + "."
+    if label == "UPPER_COMMA":
+        return token.capitalize() + ","
+    if label == "UPPER_QUESTION":
+        return token.capitalize() + "?"
+    if label == "UPPER_TOTAL_O":
         return token.upper()
+    if label == "UPPER_TOTAL_PERIOD":
+        return token.upper() + "."
+    if label == "UPPER_TOTAL_COMMA":
+        return token.upper() + ","
+    if label == "UPPER_TOTAL_QUESTION":
+        return token.upper() + "?"
+    if label == "O":
         return token


+def decode_label(label, classes="all"):
+    if classes == "punc":
         return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
+    if classes == "case":
         return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
     else:
         return INVERSE_LABELS[label]
@@ -76,14 +80,12 @@ class SbertPuncCase(nn.Module):
     def __init__(self):
         super().__init__()

+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, strip_accents=False)
         self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
         self.model.eval()

     def forward(self, input_ids, attention_mask):
+        return self.model(input_ids=input_ids, attention_mask=attention_mask)

     def punctuate(self, text):
         text = text.strip().lower()
@@ -94,10 +96,23 @@ class SbertPuncCase(nn.Module):
         tokenizer_output = self.tokenizer(words, is_split_into_words=True)

         if len(tokenizer_output.input_ids) > 512:
+            return " ".join(
+                [
+                    self.punctuate(" ".join(text_part))
+                    for text_part in np.array_split(words, 2)
+                ]
+            )
+
+        predictions = (
+            self(
+                torch.tensor([tokenizer_output.input_ids], device=self.model.device),
+                torch.tensor(
+                    [tokenizer_output.attention_mask], device=self.model.device
+                ),
+            )
+            .logits.cpu()
+            .data.numpy()
+        )
         predictions = np.argmax(predictions, axis=2)

         # decode punctuation and casing
@@ -108,16 +123,31 @@ class SbertPuncCase(nn.Module):
             label_id = predictions[0][label_pos]
             label = decode_label(label_id)
             splitted_text.append(token_to_label(word, label))
+        capitalized_text = " ".join(splitted_text)
         return capitalized_text


+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        "Punctuation and case restoration model sbert_punc_case_ru"
+    )
+    parser.add_argument(
+        "-i",
+        "--input",
+        type=str,
+        help="text to restore",
+        default="sbert punc case расставляет точки запятые и знаки вопроса вам нравится",
+    )
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        help="run model on cpu or gpu",
+        choices=["cpu", "cuda"],
+        default="cpu",
+    )
     args = parser.parse_args()
     print(f"Source text: {args.input}\n")
     sbertpunc = SbertPuncCase().to(args.device)
     punctuated_text = sbertpunc.punctuate(args.input)
+    print(f"Restored text: {punctuated_text}")
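For reference, a minimal usage sketch of the class changed above. It assumes the package is installed and that sbert_punc_case_ru re-exports SbertPuncCase (otherwise import it from sbert_punc_case_ru.sbertpunccase); the input sentence is only an example.

from sbert_punc_case_ru import SbertPuncCase

# Loads the tokenizer and token-classification model from MODEL_REPO.
model = SbertPuncCase()

# punctuate() strips and lowercases the text, predicts one combined
# case+punctuation label per word, then re-joins the restored words.
text = "sbert punc case расставляет точки запятые и знаки вопроса вам нравится"
print(model.punctuate(text))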
setup.py
CHANGED
@@ -1,19 +1,24 @@
 from distutils.core import setup

+setup(
+    name="sbert_punc_case_ru",
+    version="0.2",
+    description="Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru",
+    author="Almira Murtazina",
+    author_email="ar.murtazina@skbkontur.ru",
+    packages=["sbert_punc_case_ru"],
+    install_requires=[
+        "transformers>=4.36.2",
+        "torch",
+        "numpy"
+    ],
+    classifiers=[
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+)
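A minimal sketch of installing the package and running the CLI entry point from sbertpunccase.py. The commands assume a local checkout of the repository, and the module path is inferred from the packages list above; the sample sentence is arbitrary.

# Install the package and its dependencies from the repository root.
pip install .

# Restore punctuation and case for a sample sentence on CPU.
python -m sbert_punc_case_ru.sbertpunccase -i "sbert punc case расставляет точки запятые и знаки вопроса вам нравится" -d cpu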