Update README.md
README.md (changed)
````diff
@@ -22,18 +22,20 @@ from transformers import AutoModelForTokenClassification,AutoTokenizer
 from torch.utils.data import DataLoader
 
 def predict_step(batch,model,tokenizer):
-
-
-
-    encodings = {'input_ids':
+    batch_out = []
+    batch_input_ids = batch
+
+    encodings = {'input_ids': batch_input_ids}
     output = model(**encodings)
 
     predicted_token_class_id_batch = output['logits'].argmax(-1)
-    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch,
+    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
+        out=[]
         tokens = tokenizer.convert_ids_to_tokens(input_ids)
 
         # compute the pad start in input_ids
         # and also truncate the predict
+        # print(tokenizer.decode(batch_input_ids))
         input_ids = input_ids.tolist()
         try:
             input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
@@ -48,14 +50,15 @@ def predict_step(batch,model,tokenizer):
 
         for token,ner in zip(tokens,predicted_tokens_classes):
             out.append((token,ner))
-
+        batch_out.append(out)
+    return batch_out
 
 if __name__ == "__main__":
-    window_size =
-    step =
-    text = "
+    window_size = 256
+    step = 200
+    text = "維基百科是維基媒體基金會運營的一個多語言的百科全書特點是自由內容自由編輯自由著作權目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站其在搜尋引擎中排名亦較為靠前維基百科目前由非營利組織維基媒體基金會負責營運"
     dataset = DocumentDataset(text,window_size=window_size,step=step)
-    dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=
+    dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)
 
     model_name = 'p208p2002/zh-wiki-punctuation-restore'
     model = AutoModelForTokenClassification.from_pretrained(model_name)
@@ -63,7 +66,9 @@ if __name__ == "__main__":
 
     model_pred_out = []
     for batch in dataloader:
-
+        batch_out = predict_step(batch,model,tokenizer)
+        for out in batch_out:
+            model_pred_out.append(out)
 
     merge_pred_result = merge_stride(model_pred_out,step)
     merge_pred_result_deocde = decode_pred(merge_pred_result)
@@ -71,5 +76,5 @@ if __name__ == "__main__":
     print(merge_pred_result_deocde)
 ```
 ```
-
+維基百科是維基媒體基金會運營的一個多語言的百科全書,特點是自由、內容自由、編輯自由著作權,目前是全球網路上最大且最受大眾歡迎的參考工具書,名列全球二十大最受歡迎的網站。其在搜尋引擎中排名亦較為靠前。維基百科目前由非營利組織維基媒體基金會負責營運。
 ```
````
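Because `step` (200) is smaller than `window_size` (256), consecutive windows overlap by 56 tokens, and `merge_stride` has to reconcile the duplicated predictions before `decode_pred` turns the `(token, label)` pairs back into punctuated text. The sketch below illustrates both ideas under two assumptions: each window is a list of `(token, label)` pairs as returned by `predict_step` above, and a label is either `'O'` (no punctuation) or the punctuation mark itself. The names `merge_windows`, `decode`, and `Pair` are mine, and the bodies are simplified stand-ins for `zhpr`'s `merge_stride` and `decode_pred`, not the library's actual code.

```python
from typing import List, Tuple

Pair = Tuple[str, str]  # (token, predicted label), e.g. ("科", "，") or ("維", "O")

def merge_windows(windows: List[List[Pair]], step: int) -> List[Pair]:
    """Keep the first `step` pairs of every window and the whole last window,
    so each text position is emitted exactly once.
    (Illustrative stand-in for zhpr's merge_stride.)"""
    merged: List[Pair] = []
    for i, window in enumerate(windows):
        merged.extend(window if i == len(windows) - 1 else window[:step])
    return merged

def decode(pairs: List[Pair]) -> str:
    """Append each predicted punctuation mark after its token; 'O' adds nothing.
    (Illustrative stand-in for zhpr's decode_pred.)"""
    return ''.join(tok + (label if label != 'O' else '') for tok, label in pairs)

# Toy run: two windows of size 3 built with step=2 overlap by one token.
w1 = [("維", "O"), ("基", "O"), ("百", "O")]
w2 = [("百", "O"), ("科", "，"), ("是", "O")]
print(decode(merge_windows([w1, w2], step=2)))  # 維基百科，是
```

Since window `i` starts at position `i * step`, keeping only the first `step` pairs of every window except the last covers every position exactly once, which is why the README example can simply append `predict_step`'s per-window outputs in order and merge once at the end.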