p208p2002 commited on
Commit
ac30e47
·
1 Parent(s): 55395c7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +17 -12
README.md CHANGED
@@ -22,18 +22,20 @@ from transformers import AutoModelForTokenClassification,AutoTokenizer
22
  from torch.utils.data import DataLoader
23
 
24
  def predict_step(batch,model,tokenizer):
25
- assert batch.shape[0]==1
26
- out = []
27
- input_ids = batch
28
- encodings = {'input_ids': input_ids}
29
  output = model(**encodings)
30
 
31
  predicted_token_class_id_batch = output['logits'].argmax(-1)
32
- for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, input_ids):
 
33
  tokens = tokenizer.convert_ids_to_tokens(input_ids)
34
 
35
  # compute the pad start in input_ids
36
  # and also truncate the predict
 
37
  input_ids = input_ids.tolist()
38
  try:
39
  input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
@@ -48,14 +50,15 @@ def predict_step(batch,model,tokenizer):
48
 
49
  for token,ner in zip(tokens,predicted_tokens_classes):
50
  out.append((token,ner))
51
- return out
 
52
 
53
  if __name__ == "__main__":
54
- window_size = 100
55
- step = 75
56
- text = "維基百科是維基媒體基金會運營的一個多語言的線上百科全書並以建立和維護作為開放式協同合作專案特點是自由內容自由編輯自由著作權目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站其在搜尋引擎中排名亦較為靠前維基百科目前由非營利組織維基媒體基金會負責營運"
57
  dataset = DocumentDataset(text,window_size=window_size,step=step)
58
- dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=1)
59
 
60
  model_name = 'p208p2002/zh-wiki-punctuation-restore'
61
  model = AutoModelForTokenClassification.from_pretrained(model_name)
@@ -63,7 +66,9 @@ if __name__ == "__main__":
63
 
64
  model_pred_out = []
65
  for batch in dataloader:
66
- model_pred_out.append(predict_step(batch,model,tokenizer))
 
 
67
 
68
  merge_pred_result = merge_stride(model_pred_out,step)
69
  merge_pred_result_deocde = decode_pred(merge_pred_result)
@@ -71,5 +76,5 @@ if __name__ == "__main__":
71
  print(merge_pred_result_deocde)
72
  ```
73
  ```
74
- 維基百科是維基媒體基金會運營的一個多語言的線上百科全書,並以建立和維護作為開放式協同合作。專案特點是自由內容、自由編輯、自由著作權。目前是全球網路上最大且最受大眾歡迎的參考工具書,名列全球二十大最受歡迎的網站,其在搜尋引擎中排名亦較為靠前。維基百科目前由非營利組織維基媒體基金會負責營運。
75
  ```
 
22
  from torch.utils.data import DataLoader
23
 
24
  def predict_step(batch,model,tokenizer):
25
+ batch_out = []
26
+ batch_input_ids = batch
27
+
28
+ encodings = {'input_ids': batch_input_ids}
29
  output = model(**encodings)
30
 
31
  predicted_token_class_id_batch = output['logits'].argmax(-1)
32
+ for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
33
+ out=[]
34
  tokens = tokenizer.convert_ids_to_tokens(input_ids)
35
 
36
  # compute the pad start in input_ids
37
  # and also truncate the predict
38
+ # print(tokenizer.decode(batch_input_ids))
39
  input_ids = input_ids.tolist()
40
  try:
41
  input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
 
50
 
51
  for token,ner in zip(tokens,predicted_tokens_classes):
52
  out.append((token,ner))
53
+ batch_out.append(out)
54
+ return batch_out
55
 
56
  if __name__ == "__main__":
57
+ window_size = 256
58
+ step = 200
59
+ text = "維基百科是維基媒體基金會運營的一個多語言的百科全書特點是自由內容自由編輯自由著作權目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站其在搜尋引擎中排名亦較為靠前維基百科目前由非營利組織維基媒體基金會負責營運"
60
  dataset = DocumentDataset(text,window_size=window_size,step=step)
61
+ dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)
62
 
63
  model_name = 'p208p2002/zh-wiki-punctuation-restore'
64
  model = AutoModelForTokenClassification.from_pretrained(model_name)
 
66
 
67
  model_pred_out = []
68
  for batch in dataloader:
69
+ batch_out = predict_step(batch,model,tokenizer)
70
+ for out in batch_out:
71
+ model_pred_out.append(out)
72
 
73
  merge_pred_result = merge_stride(model_pred_out,step)
74
  merge_pred_result_deocde = decode_pred(merge_pred_result)
 
76
  print(merge_pred_result_deocde)
77
  ```
78
  ```
79
+ 維基百科是維基媒體基金會運營的一個多語言的百科全書,特點是自由、內容自由、編輯自由著作權,目前是全球網路上最大且最受大眾歡迎的參考工具書,名列全球二十大最受歡迎的網站。其在搜尋引擎中排名亦較為靠前。維基百科目前由非營利組織維基媒體基金會負責營運。
80
  ```