metadata
tags:
- ner
- punctuation
language:
- zh
zh-wiki-punctuation-restore
More details: https://github.com/p208p2002/ZH-Punctuation-Restore

共計支援6種標點符號 (six punctuation marks are supported in total): ， 、 。 ？ ！ ；
Install
# pip install torch pytorch-lightning
pip install zhpr
Usage
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from zhpr.predict import DocumentDataset, decode_pred, merge_stride
def predict_step(batch, model, tokenizer):
    """Run one batch through the token-classification model and pair tokens with labels.

    Args:
        batch: Input ids yielded by the ``DataLoader`` over ``DocumentDataset``.
            Assumed to be a 2-D integer tensor ``(batch, window_size)`` that is
            right-padded with ``tokenizer.pad_token_id`` (TODO confirm in zhpr).
        model: Token-classification model whose output exposes ``'logits'``
            and whose ``config.id2label`` maps class ids to label strings.
        tokenizer: Tokenizer matching ``model``; supplies ``pad_token_id`` and
            ``convert_ids_to_tokens``.

    Returns:
        One list per sequence in the batch; each is a list of
        ``(token, label)`` tuples with the padding tail removed.
    """
    import torch  # local so the snippet stays runnable even if only this block is copied

    pad_id = tokenizer.pad_token_id
    # Mask padding so real tokens do not attend to [PAD] positions: Hugging Face
    # BERT-style encoders treat a missing attention_mask as "attend everywhere".
    if pad_id is None:
        attention_mask = torch.ones_like(batch)
    else:
        attention_mask = (batch != pad_id).long()

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        output = model(input_ids=batch, attention_mask=attention_mask)
    predicted_class_ids = output['logits'].argmax(-1)

    batch_out = []
    for class_ids, input_ids in zip(predicted_class_ids.tolist(), batch.tolist()):
        # Padding starts at the first pad token; keep only what precedes it.
        try:
            pad_start = input_ids.index(pad_id)
        except ValueError:  # this window has no padding
            pad_start = len(input_ids)
        tokens = tokenizer.convert_ids_to_tokens(input_ids[:pad_start])
        labels = [model.config.id2label[c] for c in class_ids[:pad_start]]
        batch_out.append(list(zip(tokens, labels)))
    return batch_out
if __name__ == "__main__":
    # Demo: punctuate one unpunctuated Traditional Chinese paragraph by
    # predicting over overlapping windows (length 256, stride 200) and then
    # stitching the per-window predictions back together.
    window_size, step = 256, 200
    text = "維基百科是維基媒體基金會運營的一個多語言的百科全書目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站特點是自由內容自由編輯與自由著作權"

    # shuffle=False keeps windows in document order; presumably merge_stride
    # expects that ordering (confirm in zhpr).
    dataloader = DataLoader(
        dataset=DocumentDataset(text, window_size=window_size, step=step),
        shuffle=False,
        batch_size=5,
    )

    model_name = 'p208p2002/zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    window_predictions = []
    for batch in dataloader:
        window_predictions.extend(predict_step(batch, model, tokenizer))

    merged = merge_stride(window_predictions, step)
    print(''.join(decode_pred(merged)))
維基百科是維基媒體基金會運營的一個多語言的百科全書，目前是全球網路上最大且最受大眾歡迎的參考工具書，名列全球二十大最受歡迎的網站，特點是自由內容、自由編輯與自由著作權。