# %%
"""Restore Japanese punctuation (、 and 。) with a character-level BERT tagger.

A pretrained character-level BERT encoder is wrapped with a small linear
head that emits two logits per token.  ``process_long_text`` strips any
existing 、/。 from its input, runs the tagger chunk by chunk and re-inserts
punctuation after every token whose sigmoid score exceeds a threshold.

Importing this module loads the tokenizer, the encoder and the fine-tuned
weights from ``weight/punctuation_position_model.pth``.
"""
import os
import random
from itertools import islice

import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertForMaskedLM,
    BertModel,
    BertTokenizer,
)

model_name = "tohoku-nlp/bert-base-japanese-char-v3"

# Every chunk is padded/truncated to this many tokens before being fed to
# the encoder.
_MAX_SEQ_LEN = 512
# The tokenizer adds [CLS] and [SEP], so at most this many characters fit in
# one chunk without being truncated away.
# NOTE(review): assumes each character yields at most one token -- confirm.
_MAX_CHUNK_CHARS = _MAX_SEQ_LEN - 2

tokenizer = BertTokenizer.from_pretrained(model_name)
base_model = BertModel.from_pretrained(model_name)


class punctuation_predictor(torch.nn.Module):
    """Per-token classifier that scores comma and period positions.

    The head maps every encoder hidden state to two logits; callers in this
    file read index 0 as the comma score and index 1 as the period score.
    """

    def __init__(self, base_model):
        """Wrap *base_model* with dropout and a two-output linear layer.

        Args:
            base_model: Encoder whose ``config.hidden_size`` gives the width
                of its hidden states (768 for the checkpoint used here).
        """
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(0.2)
        # Derive the input width from the encoder instead of hard-coding 768
        # so the head also fits encoders of a different size.
        self.linear = torch.nn.Linear(base_model.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return the raw (pre-sigmoid) logits, two per input token."""
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # The linear layer is applied independently to every token position.
        return self.linear(self.dropout(last_hidden_state))


model = punctuation_predictor(base_model)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# host; inference below runs on the CPU (outputs go straight to NumPy).
model.load_state_dict(
    torch.load("weight/punctuation_position_model.pth", map_location="cpu")
)
model.eval()  # inference mode: disables dropout


def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild a string from token ids, appending predicted punctuation.

    Args:
        input: 1-D tensor of token ids for one sequence.  (The name shadows
            the builtin but is kept so keyword callers keep working.)
        comma_pos: Per-position booleans; truthy means "insert 、 after".
        period_pos: Per-position booleans; truthy means "insert 。 after".
            A period takes precedence when both flags are set.

    Returns:
        The concatenated tokens with punctuation inserted.  Positions whose
        token id is 5 or lower are skipped, so characters the tokenizer
        mapped to such an id do not appear in the output.
        NOTE(review): presumably ids 0-5 are special tokens such as [PAD],
        [UNK], [CLS] and [SEP] -- verify against the vocabulary.
    """
    last_index = len(input) - 1
    pieces = []
    for i, (c, p) in enumerate(zip(comma_pos, period_pos)):
        token_id = input[i].item()
        if token_id <= 5:
            continue
        if i >= last_index:
            # A regular token in the very last slot ends the scan.
            break
        token = tokenizer.ids_to_tokens[token_id]
        if p:
            pieces.append(token + "。")
        elif c:
            pieces.append(token + "、")
        else:
            pieces.append(token)
    return "".join(pieces)


def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Strip existing 、/。 from *text* and re-insert model-predicted ones.

    The text is processed in consecutive chunks of at most *max_length*
    characters; each chunk is tagged independently and the results are
    concatenated.

    Args:
        text: Input string.  Existing 、 and 。 are removed first.
        max_length: Characters per chunk.  Values above 510 are clamped,
            because the tokenizer truncates every chunk to 512 tokens
            (including [CLS] and [SEP]) and would otherwise silently drop
            the tail of each chunk.
        comma_thresh: Sigmoid score above which a comma is inserted.
        period_thresh: Sigmoid score above which a period is inserted.

    Returns:
        The punctuated text ("" for empty input).

    Raises:
        ValueError: If *max_length* is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be a positive integer")
    chunk_size = min(max_length, _MAX_CHUNK_CHARS)

    text = text.replace("、", "").replace("。", "")
    chunks = []
    for start in range(0, len(text), chunk_size):
        no_punctuation_text = text[start : start + chunk_size]
        # Space-separate the characters so each one is tokenized on its own.
        inputs = tokenizer(
            " ".join(no_punctuation_text),
            max_length=_MAX_SEQ_LEN,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(inputs.input_ids, inputs.attention_mask)
            scores = torch.sigmoid(logits)[0].numpy()
        comma_pos = scores[:, 0] > comma_thresh
        period_pos = scores[:, 1] > period_thresh
        chunks.append(
            insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
        )
    return "".join(chunks)


# %%
if __name__ == "__main__":
    print(
        process_long_text(
            "女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた",
        )
    )