File size: 3,990 Bytes
c005ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# !/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import time
import jieba
import jsonlines
import codecs
import random
import fasttext

stopwords_set = set()
basedir = './stopwords/'

# Stopword lists (one word per line) merged into a single lookup set that
# segment() filters tokens against. Each line is whitespace-stripped, so a
# blank line contributes the empty string to the set.
# NOTE(review): basedir is resolved against the current working directory,
# not this script's location — confirm that is intended.
for _stopwords_file in ('baidu_stopwords.txt', 'cn_stopwords.txt',
                        'hit_stopwords.txt', 'scu_stopwords.txt'):
    with open(os.path.join(basedir, _stopwords_file), 'r', encoding='utf-8') as infile:
        stopwords_set.update(line.strip() for line in infile)

def segment(text):
  """Tokenize ``text`` with jieba and drop stopwords.

  Tabs and newlines are turned into spaces first, so the result is a single
  line (presumably so it can be fed to fastText's ``predict`` — confirm).

  Args:
    text: the raw string to tokenize.

  Returns:
    The remaining tokens joined by single spaces ("" if nothing is left).
    Only stopwords from the module-level ``stopwords_set`` are removed; no
    HTML stripping is performed.
  """
  cleaned = text.replace("\t", " ").replace("\n", " ")
  # str.split() with no argument discards the whitespace-only tokens jieba
  # emits for runs of spaces and splits on any remaining whitespace.
  tokens = " ".join(jieba.cut(cleaned)).split()
  return " ".join(tok for tok in tokens if tok not in stopwords_set)

# Weight of each fastText label when collapsing probabilities into a score:
# '__label__1' counts 1, '__label__0' counts 0. The bare '__label__' entry
# (weight 0) presumably covers examples trained with an empty label — confirm.
_DEFAULT_LABEL_SCORES = {
  '__label__': 0,
  '__label__0': 0,
  '__label__1': 1,
}


def predict_score(preds, label_scores=None):
  """Collapse fastText predictions into one expected score per input.

  Args:
    preds: a ``(labels, probs)`` pair as returned by ``model.predict(texts, k=-1)``;
      ``labels[i]`` are the label names for input ``i`` and ``probs[i]`` the
      matching probabilities.
    label_scores: optional mapping from label name to numeric weight.
      Defaults to ``_DEFAULT_LABEL_SCORES``.

  Returns:
    A list of floats, one per input, each equal to
    ``sum(label_scores[label] * prob)`` over that input's labels.

  Raises:
    KeyError: if a predicted label has no entry in the weight mapping.
  """
  weights = _DEFAULT_LABEL_SCORES if label_scores is None else label_scores
  return [
    float(sum(weights[label] * prob for label, prob in zip(labels, probs)))
    for labels, probs in zip(*preds)
  ]


if __name__ == "__main__":
  # Score every JSON-lines record with a fastText classifier, store the score
  # under --output-key, and optionally drop records scoring below --score-thres.
  import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument('--fasttext-model-path', type=str, default="", help="path to the fastText model file", required=True)
  parser.add_argument('--input-file-path', type=str, default="", help="JSON-lines file, or a directory of JSON-lines files", required=True)
  parser.add_argument('--output-file-path', type=str, default="", help="JSON-lines file the scored records are written to", required=True)
  parser.add_argument('--text-key', type=str, default="text", help="record key holding the text to score", required=False)
  parser.add_argument('--output-key', type=str, default="score", help="record key the score is written to", required=False)
  parser.add_argument('--do-score-filter', action='store_true', default=False, help='drop records whose score is below --score-thres', dest='do_score_filter')
  parser.add_argument('--score-thres', type=float, default=0.1, help="minimum score a record needs to be kept when --do-score-filter is set", required=False)
  args = parser.parse_args()

  model = fasttext.load_model(args.fasttext_model_path)

  # os.listdir() yields bare names, so they are joined with the directory;
  # a single-file input is used as-is (joining a path with itself would
  # break relative paths). Sorted for a deterministic output order;
  # subdirectories are skipped rather than crashing the run.
  input_path = args.input_file_path
  if os.path.isdir(input_path):
    input_files = sorted(
      os.path.join(input_path, name)
      for name in os.listdir(input_path)
      if os.path.isfile(os.path.join(input_path, name)))
  else:
    input_files = [input_path]

  lines = 0
  filtered = 0
  start_time = time.perf_counter()

  # The writer is a context manager so buffered output is flushed and the
  # file closed even if scoring fails part-way.
  with jsonlines.open(args.output_file_path, mode='w') as writer:
    for input_file in input_files:
      with jsonlines.open(input_file) as reader:
        for line in reader:
          lines += 1
          if lines % 1000 == 0:
            elapsed_time = time.perf_counter() - start_time
            # Guard against a zero interval on coarse clocks.
            samples_per_second = lines / elapsed_time if elapsed_time > 0 else float("inf")
            print(f"Processed {lines} lines in {elapsed_time:.2f} seconds.", flush=True)
            print(f"Samples per second: {samples_per_second:.2f}.", flush=True)

          # Records without a string under --text-key cannot be scored;
          # count them as filtered instead of crashing in segment().
          text = line.get(args.text_key) if isinstance(line, dict) else None
          if not isinstance(text, str):
            filtered += 1
            continue

          # k=-1 asks fastText for every label with its probability, which
          # predict_score() folds into a single weighted score.
          preds = model.predict([segment(text)], k=-1)
          line[args.output_key] = predict_score(preds)[0]

          if args.do_score_filter and line[args.output_key] < args.score_thres:
            filtered += 1
            continue
          writer.write(line)

      # Cumulative statistics, reported after each input file.
      elapsed_time = time.perf_counter() - start_time
      samples_per_second = lines / elapsed_time if elapsed_time > 0 else float("inf")
      print(f"Processed {lines} lines in {elapsed_time:.2f} seconds, Filtered {filtered} samples.", flush=True)
      print(f"Samples per second: {samples_per_second:.2f}.", flush=True)