ldwang committed on
Commit
c005ac2
1 Parent(s): dfa3804

Upload predict_local.py with huggingface_hub

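The commit message is the default one generated by the huggingface_hub client when a file is pushed with HfApi.upload_file. A minimal sketch of such an upload, assuming a hypothetical target repository (repo_id and repo_type below are placeholders, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()
    # Hypothetical repository id/type; substitute the actual target repo.
    api.upload_file(
        path_or_fileobj="predict_local.py",
        path_in_repo="predict_local.py",
        repo_id="user/dataset-name",
        repo_type="dataset",
        commit_message="Upload predict_local.py with huggingface_hub",
    )
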
Files changed (1)
  1. predict_local.py +122 -0
predict_local.py ADDED
@@ -0,0 +1,122 @@
+ #!/usr/bin/env python
+ # -*- coding:utf-8 -*-
+
+ import os
+ import time
+ import jieba
+ import jsonlines
+ import codecs
+ import random
+ import fasttext
+
+ stopwords_set = set()
+ basedir = './stopwords/'
+
+ # Stopword files
+ with open(basedir + 'baidu_stopwords.txt', 'r', encoding='utf-8') as infile:
+     for line in infile:
+         stopwords_set.add(line.strip())
+ with open(basedir + 'cn_stopwords.txt', 'r', encoding='utf-8') as infile:
+     for line in infile:
+         stopwords_set.add(line.strip())
+ with open(basedir + 'hit_stopwords.txt', 'r', encoding='utf-8') as infile:
+     for line in infile:
+         stopwords_set.add(line.strip())
+ with open(basedir + 'scu_stopwords.txt', 'r', encoding='utf-8') as infile:
+     for line in infile:
+         stopwords_set.add(line.strip())
+
+ def segment(text):
+     # Word segmentation with jieba
+     seg_text = jieba.cut(text.replace("\t", " ").replace("\n", " "))
+     outline = " ".join(seg_text)
+     outline = " ".join(outline.split())
+
+     # Remove stopwords and HTML tags
+     outline_list = outline.split(" ")
+     outline_list_filter = [item for item in outline_list if item not in stopwords_set]
+     outline = " ".join(outline_list_filter)
+
+     return outline
+
+ def predict_score(preds):
+     score_dict = {
+         '__label__': 0,
+         '__label__0': 0,
+         '__label__1': 1,
+     }
+
+     score_list = []
+     for l, s in zip(*preds):
+         score = 0
+         for _l, _s in zip(l, s):
+             score += score_dict[_l] * _s
+         score_list.append(float(score))
+     return score_list
+
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--fasttext-model-path', type=str, default="", help="fastText model file path", required=True)
+     parser.add_argument('--input-file-path', type=str, default="", help="input file or directory path", required=True)
+     parser.add_argument('--output-file-path', type=str, default="", help="output file path", required=True)
+     parser.add_argument('--text-key', type=str, default="text", help="JSON key that holds the text", required=False)
+     parser.add_argument('--output-key', type=str, default="score", help="JSON key to store the score under", required=False)
+     parser.add_argument('--do-score-filter', action='store_true', default=False, help='do score filter or not', dest='do_score_filter')
+     parser.add_argument('--score-thres', type=float, default=0.1, help="score threshold", required=False)
+     args = parser.parse_args()
+
+     model_dir = args.fasttext_model_path
+     model = fasttext.load_model(model_dir)
+
+     import jsonlines
+     file_path = args.input_file_path
+     output_file_path = args.output_file_path
+     writer = jsonlines.open(output_file_path, mode='w')
+
+     input_files = []
+     if os.path.isdir(file_path):
+         input_files = [os.path.join(file_path, name) for name in os.listdir(file_path)]
+     else:
+         input_files = [file_path]
+
+     lines = 0
+     filtered = 0
+     start_time = time.time()
+
+     for input_file in input_files:
+         # input_file is the full path to one .jsonl shard
+         with jsonlines.open(input_file) as reader:
+             for line in reader:
+                 lines += 1
+                 if lines % 1000 == 0:
+                     end_time = time.time()
+                     elapsed_time = end_time - start_time
+                     samples_per_second = lines / elapsed_time
+                     print(f"Processed {lines} lines in {elapsed_time:.2f} seconds.", flush=True)
+                     print(f"Samples per second: {samples_per_second:.2f}.", flush=True)
+
+                 if args.text_key not in line:
+                     filtered += 1
+                     continue
+                 sentence = line[args.text_key]
+                 outline = segment(sentence)
+
+                 preds = model.predict([outline], k=-1)
+                 score = predict_score(preds)
+                 # print(preds, score)
+
+                 line[args.output_key] = score[0]
+                 # do filter
+                 if args.do_score_filter and line[args.output_key] < args.score_thres:
+                     filtered += 1
+                     continue
+                 writer.write(line)
+
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     samples_per_second = lines / elapsed_time
+     print(f"Processed {lines} lines in {elapsed_time:.2f} seconds, Filtered {filtered} samples.", flush=True)
+     print(f"Samples per second: {samples_per_second:.2f}.", flush=True)
+     writer.close()
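
A note on the scoring step in the script above: model.predict([outline], k=-1) returns every label together with its probability, and predict_score multiplies each probability by the label's value (1 for __label__1, 0 for __label__0), so the resulting score is simply the probability mass the classifier assigns to __label__1. A small self-contained sketch with a made-up prediction tuple (no fastText model required):

    # Hypothetical output of model.predict([text], k=-1) for a single text:
    # a list of label lists paired with a list of probability lists.
    preds = ([['__label__1', '__label__0']], [[0.83, 0.17]])

    score_dict = {'__label__': 0, '__label__0': 0, '__label__1': 1}

    scores = []
    for labels, probs in zip(*preds):  # one (labels, probs) pair per input text
        score = sum(score_dict[label] * prob for label, prob in zip(labels, probs))
        scores.append(float(score))

    print(scores)  # [0.83] -> the probability assigned to __label__1

In typical use the script is pointed at a trained fastText classifier via --fasttext-model-path and at either a single .jsonl file or a directory of .jsonl shards via --input-file-path; each record gets its score written under --output-key, and with --do-score-filter enabled, records scoring below --score-thres are dropped instead of being written to --output-file-path.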