---
tags:
- ner
- punctuation
language:
- zh
---
|
# zh-wiki-punctuation-restore |
|
More details: https://github.com/p208p2002/ZH-Punctuation-Restore
|
Supports six punctuation marks in total: , 、 。 ? ! ;
|
|
|
## Install |
|
```bash
# prerequisites, if not already installed:
# pip install torch pytorch-lightning
pip install zhpr
```
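To verify the install and see which label classes the checkpoint predicts (typically the six punctuation marks above plus a no-punctuation tag, under whatever label names the model defines), a minimal sketch that reads the model config:

```python
from transformers import AutoConfig

# list the token-classification labels shipped with the checkpoint
config = AutoConfig.from_pretrained("p208p2002/zh-wiki-punctuation-restore")
print(config.id2label)
```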
|
|
|
## Usage |
|
```python
from zhpr.predict import DocumentDataset, merge_stride, decode_pred
from transformers import AutoModelForTokenClassification, AutoTokenizer
from torch.utils.data import DataLoader


def predict_step(batch, model, tokenizer):
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        out = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # find where padding starts in input_ids so that both the tokens
        # and the predictions can be truncated to the real sequence length
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:  # no padding in this sequence
            input_id_pad_start = len(input_ids)
        input_ids = input_ids[:input_id_pad_start]
        tokens = tokens[:input_id_pad_start]

        # map predicted class ids to label names and truncate them too
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        for token, ner in zip(tokens, predicted_tokens_classes):
            out.append((token, ner))
        batch_out.append(out)
    return batch_out


if __name__ == "__main__":
    # slice the document into overlapping windows
    window_size = 256
    step = 200
    text = "維基百科是維基媒體基金會運營的一個多語言的百科全書目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站特點是自由內容自由編輯與自由著作權"
    dataset = DocumentDataset(text, window_size=window_size, step=step)
    dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=5)

    model_name = 'p208p2002/zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # predict per window, collecting (token, label) pairs
    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch, model, tokenizer)
        for out in batch_out:
            model_pred_out.append(out)

    # merge the overlapping windows, then decode labels back to punctuation
    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_result_decode = decode_pred(merge_pred_result)
    merge_pred_result_decode = ''.join(merge_pred_result_decode)
    print(merge_pred_result_decode)
```
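`DocumentDataset` slices the text into overlapping windows (`window_size=256`, `step=200`, so consecutive windows overlap by 56 positions), and `merge_stride` reconciles the duplicated predictions in the overlap before `decode_pred` turns the labels back into punctuation. Running the script prints the restored text: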
|
``` |
|
維基百科是維基媒體基金會運營的一個多語言的百科全書,目前是全球網路上最大且最受大眾歡迎的參考工具書,名列全球二十大最受歡迎的網站,特點是自由內容、自由編輯與自由著作權。 |
|
``` |
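The example runs on CPU. For longer documents you can skip gradient tracking and move the model to a GPU when one is available; a minimal adjustment to the loop above (same `predict_step`; the batches yielded by the `DataLoader` are input-id tensors):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

model_pred_out = []
with torch.no_grad():  # inference only, no gradients needed
    for batch in dataloader:
        model_pred_out.extend(predict_step(batch.to(device), model, tokenizer))
```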