|
|
|
import torch |
|
from transformers import ( |
|
BertTokenizer, |
|
BertForMaskedLM, |
|
AutoModelForMaskedLM, |
|
AutoTokenizer, |
|
BertModel, |
|
) |
|
import numpy as np |
|
import random |
|
from itertools import islice |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import AdamW |
|
from tqdm.auto import tqdm |
|
import os |
|
|
|
# Hugging Face Hub id of the pretrained Japanese character-level BERT checkpoint.
model_name = "tohoku-nlp/bert-base-japanese-char-v3"

# Vocabulary/tokenizer and encoder weights, both loaded from the same checkpoint.
# NOTE(review): the checkpoint presumably declares BertJapaneseTokenizer as its
# tokenizer class; plain BertTokenizer is used here, which works for
# space-separated single characters but may log a class-mismatch warning — confirm.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
base_model = BertModel.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
|
|
class punctuation_predictor(torch.nn.Module):
    """Per-token classification head on top of a BERT-style encoder.

    The encoder's ``last_hidden_state`` goes through dropout and a linear layer
    that emits ``num_labels`` logits for every token position. With the default
    ``num_labels=2``, the inference code in this file thresholds column 0 for
    a comma (、) and column 1 for a period (。).

    Args:
        base_model: encoder module whose forward accepts ``input_ids`` and
            ``attention_mask`` keywords and returns an object with a
            ``last_hidden_state`` attribute (e.g. ``transformers.BertModel``).
        num_labels: number of logits per token (default 2, as before).
        dropout: dropout probability applied before the linear head
            (default 0.2, as before).
    """

    def __init__(self, base_model, num_labels=2, dropout=0.2):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout)
        # Size the head from the encoder's config instead of hard-coding 768,
        # so encoders with a different hidden size also work. Fall back to 768
        # (BERT-base) when the encoder exposes no ``config.hidden_size``.
        hidden_size = getattr(getattr(base_model, "config", None), "hidden_size", 768)
        self.linear = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits from the encoder's last hidden state.

        Assuming ``last_hidden_state`` is ``(batch, seq_len, hidden)`` (as for
        HF ``BertModel``), the result is ``(batch, seq_len, num_labels)``.
        """
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return self.linear(self.dropout(last_hidden_state))
|
|
|
|
|
# Build the punctuation model around the shared encoder and restore its
# fine-tuned weights.
model = punctuation_predictor(base_model)
# map_location="cpu": inference in this file runs on CPU (results are turned
# into NumPy arrays), and a checkpoint saved from a GPU run would otherwise fail
# to deserialize on a machine without CUDA.
# NOTE: torch.load unpickles the file — only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts this).
model.load_state_dict(
    torch.load("weight/punctuation_position_model.pth", map_location="cpu")
)
# Switch to inference mode so dropout is disabled.
model.eval()
|
|
|
|
|
def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from token ids, appending predicted punctuation marks.

    Args:
        input: 1-D sequence of token ids for a single example (a tensor,
            ndarray, or plain list). The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        comma_pos: per-position truthy flags; set means "append 、 after this
            token".
        period_pos: per-position truthy flags; set means "append 。 after this
            token". A period wins when both flags are set.

    Returns:
        The concatenated token strings with punctuation inserted. Positions
        beyond the shortest of the three inputs are ignored.
    """
    # Convert once instead of calling .item() per position; also accepts lists.
    token_ids = input.tolist() if hasattr(input, "tolist") else list(input)

    pieces = []
    for token_id, is_comma, is_period in zip(token_ids, comma_pos, period_pos):
        # NOTE(review): ids 0-5 are presumably the tokenizer's special tokens
        # ([PAD]/[UNK]/[CLS]/[SEP]/[MASK]/...) — confirm against the vocab.
        # Skipping them also drops any character that was mapped to [UNK].
        if token_id <= 5:
            continue
        token = tokenizer.ids_to_tokens[token_id]
        if is_period:
            token += "。"
        elif is_comma:
            token += "、"
        pieces.append(token)
    return "".join(pieces)
|
|
|
|
|
def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Re-punctuate ``text`` with the module-level model, one chunk at a time.

    Existing 、 and 。 are removed first. The text is then cut into chunks of
    ``max_length`` characters, and each chunk is punctuated independently; no
    context is shared across chunk boundaries.

    Args:
        text: input string.
        max_length: characters per chunk. Each chunk is tokenized one token per
            character plus [CLS]/[SEP] into a fixed 512-token window, so it must
            be between 1 and 510 or characters would be silently truncated.
        comma_thresh: sigmoid probability above which 、 is inserted.
        period_thresh: sigmoid probability above which 。 is inserted.

    Returns:
        The punctuated text ("" for empty input).

    Raises:
        ValueError: if ``max_length`` is outside 1..510.
    """
    seq_len = 512  # fixed tokenizer window the model is run with
    if not 1 <= max_length <= seq_len - 2:
        raise ValueError(
            f"max_length must be between 1 and {seq_len - 2}, got {max_length}"
        )

    text = text.replace("、", "").replace("。", "")
    pieces = []
    for start in range(0, len(text), max_length):
        chunk = text[start : start + max_length]
        # Space-separate the characters so each one becomes its own token.
        inputs = tokenizer(
            " ".join(chunk),
            max_length=seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: skip building an autograd graph for every chunk.
        with torch.no_grad():
            logits = model(inputs.input_ids, inputs.attention_mask)
        # Probabilities for the single example; column 0 is the comma score and
        # column 1 the period score.
        probs = torch.sigmoid(logits)[0].cpu().numpy()
        comma_pos = probs[:, 0] > comma_thresh
        period_pos = probs[:, 1] > period_thresh
        pieces.append(insert_punctuation(inputs.input_ids[0], comma_pos, period_pos))
    return "".join(pieces)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: punctuate a sample Japanese passage that contains no 、 or 。.
    sample_text = "女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた"
    punctuated = process_long_text(sample_text)
    print(punctuated)
|
|