Upload predict_local.py with huggingface_hub
Browse files- predict_local.py +122 -0
predict_local.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import jieba
|
7 |
+
import jsonlines
|
8 |
+
import codecs
|
9 |
+
import random
|
10 |
+
import fasttext
|
11 |
+
|
12 |
+
# Directory that holds the stop-word lists (relative to the working directory).
basedir = './stopwords/'

# Union of every entry found in the stop-word files below; consulted by
# segment() to drop uninformative tokens.
stopwords_set = set()

# Stop-word files: one entry per line, UTF-8 encoded. Each line is stripped
# of surrounding whitespace before being added to the set.
for stopwords_name in (
    'baidu_stopwords.txt',
    'cn_stopwords.txt',
    'hit_stopwords.txt',
    'scu_stopwords.txt',
):
    with open(basedir + stopwords_name, 'r', encoding='utf-8') as infile:
        stopwords_set.update(entry.strip() for entry in infile)
|
28 |
+
|
29 |
+
def segment(text):
    """Tokenize *text* with jieba and remove stop words.

    Tabs and newlines in the input are replaced by spaces before the text is
    handed to ``jieba.cut``. The resulting tokens are whitespace-normalized,
    every token present in the module-level ``stopwords_set`` is dropped, and
    the survivors are returned joined by single spaces.

    NOTE(review): the original comment also mentioned stripping HTML tags,
    but no tag removal is performed here beyond whatever the stop-word lists
    happen to contain.
    """
    flattened = text.replace("\t", " ").replace("\n", " ")

    # Join then re-split so that runs of whitespace collapse and no empty
    # tokens remain before the stop-word filter runs.
    tokens = " ".join(jieba.cut(flattened)).split()

    return " ".join(token for token in tokens if token not in stopwords_set)
|
41 |
+
|
42 |
+
def predict_score(preds):
    """Turn per-sample label probabilities into one scalar score per sample.

    ``preds`` is unpacked with ``zip(*preds)`` into (labels, probabilities)
    pairs, one pair per sample -- presumably the tuple returned by the
    fastText model's ``predict`` call made by the caller. Each sample's score
    is the sum of ``weight(label) * probability``, where ``'__label__1'``
    weighs 1 and ``'__label__'`` / ``'__label__0'`` weigh 0.

    Returns a list of Python floats, in sample order. A label outside the
    three known ones raises ``KeyError``.
    """
    label_weights = {
        '__label__': 0,
        '__label__0': 0,
        '__label__1': 1,
    }

    def weighted_total(labels, probs):
        # Accumulate explicitly, left to right, starting from integer 0.
        total = 0
        for label, prob in zip(labels, probs):
            total += label_weights[label] * prob
        return float(total)

    return [weighted_total(labels, probs) for labels, probs in zip(*preds)]
|
56 |
+
|
57 |
+
|
58 |
+
def list_input_files(input_path):
    """Return the list of files to read for *input_path*.

    If *input_path* is a directory, the regular files directly inside it are
    returned as full paths in sorted (deterministic) order; sub-directories
    are skipped. Otherwise *input_path* itself is returned unchanged as the
    only entry, so a relative path to a single file is never joined onto
    itself.
    """
    if not os.path.isdir(input_path):
        return [input_path]
    candidates = (
        os.path.join(input_path, name) for name in sorted(os.listdir(input_path))
    )
    return [path for path in candidates if os.path.isfile(path)]


def _samples_per_second(count, elapsed):
    """Return count / elapsed, or 0.0 when no measurable time has passed."""
    return count / elapsed if elapsed > 0 else 0.0


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--fasttext-model-path', type=str, required=True,
                        help="path to the fastText model file")
    parser.add_argument('--input-file-path', type=str, required=True,
                        help="input JSON Lines file, or a directory of such files")
    parser.add_argument('--output-file-path', type=str, required=True,
                        help="output JSON Lines file")
    parser.add_argument('--text-key', type=str, default="text", required=False,
                        help="key of the text field in each input record")
    parser.add_argument('--output-key', type=str, default="score", required=False,
                        help="key under which the score is stored in each output record")
    parser.add_argument('--do-score-filter', action='store_true', default=False,
                        help='do score filter or not', dest='do_score_filter')
    parser.add_argument('--score-thres', type=float, default=0.1, required=False,
                        help="score thres")
    args = parser.parse_args()

    model = fasttext.load_model(args.fasttext_model_path)

    # Never read the file we are writing: it would show up as an input when
    # the output lives inside the input directory.
    output_abspath = os.path.abspath(args.output_file_path)
    input_files = [
        path for path in list_input_files(args.input_file_path)
        if os.path.abspath(path) != output_abspath
    ]

    lines = 0      # records read so far, across all input files
    filtered = 0   # records dropped (missing text key or score below threshold)
    start_time = time.time()

    # The writer is a context manager so the output is flushed and closed
    # even if scoring fails part-way through.
    with jsonlines.open(args.output_file_path, mode='w') as writer:
        for input_file in input_files:
            with jsonlines.open(input_file) as reader:
                for line in reader:
                    lines += 1
                    if lines % 1000 == 0:
                        elapsed_time = time.time() - start_time
                        samples_per_second = _samples_per_second(lines, elapsed_time)
                        print(f"Processed {lines} lines in {elapsed_time:.2f} seconds.", flush=True)
                        print(f"Samples per second: {samples_per_second:.2f}.", flush=True)

                    # Records without the text field cannot be scored.
                    if args.text_key not in line:
                        filtered += 1
                        continue

                    outline = segment(line[args.text_key])

                    # k=-1 asks for every label so predict_score can weigh them all.
                    preds = model.predict([outline], k=-1)
                    score = predict_score(preds)[0]
                    line[args.output_key] = score

                    # Optionally drop low-scoring records.
                    if args.do_score_filter and score < args.score_thres:
                        filtered += 1
                        continue
                    writer.write(line)

    elapsed_time = time.time() - start_time
    samples_per_second = _samples_per_second(lines, elapsed_time)
    print(f"Processed {lines} lines in {elapsed_time:.2f} seconds, Filtered {filtered} samples.", flush=True)
    print(f"Samples per second: {samples_per_second:.2f}.", flush=True)
|
122 |
+
|