Token Classification
Scikit-learn
English
ner
legal
crf
shashankmc committed on
Commit
2b0bc49
·
verified ·
1 Parent(s): 4f3fd9a

Predict file

Browse files
Files changed (1) hide show
  1. predict.py +172 -0
predict.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import pandas as pd
3
+ import joblib
4
+ import nltk
5
+ from nltk import pos_tag
6
+ import string
7
+ from nltk.stem import WordNetLemmatizer
8
+ from nltk.stem import PorterStemmer
9
+
10
+
11
# Fetch the NLTK resources the script relies on; nltk.download() is a
# no-op when a resource has already been downloaded.
for resource in ('wordnet', 'omw-1.4', 'averaged_perceptron_tagger'):
    nltk.download(resource)
15
+
16
+
17
class getsentence(object):
    """Reassemble BIO-format rows into per-sentence (token, pos_tag) lists.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``sentence_num``, ``token`` and
        ``pos_tag``.

    After construction, ``self.sentences`` holds one list per sentence,
    each entry a ``(token, pos_tag)`` tuple in the original row order.
    """

    def __init__(self, data):
        self.n_sent = 1.0   # NOTE(review): unused here; kept for interface compatibility
        self.data = data
        self.empty = False  # NOTE(review): unused here; kept for interface compatibility
        # Select only the columns the aggregator reads so .apply() does not
        # operate on the grouping column (avoids the pandas
        # ``include_groups`` deprecation path without changing the result).
        self.grouped = self.data.groupby("sentence_num")[["token", "pos_tag"]].apply(self._agg_func)
        self.sentences = [s for s in self.grouped]

    def _agg_func(self, s):
        # One (token, pos_tag) tuple per row of the sentence group.
        return [(w, p) for w, p in zip(s["token"].values.tolist(),
                                       s["pos_tag"].values.tolist())]
32
+
33
+
34
def word2features(sent, i):
    """Build the CRF feature dict for token *i* of *sent*.

    Parameters
    ----------
    sent : list of (token, pos_tag) tuples
        One sentence, as produced by ``getsentence``.
    i : int
        Index of the target token within *sent*.

    Returns
    -------
    dict
        Features of the token itself (lowercase form, digit flag,
        punctuation flag, POS tag, lemma, stem) plus a reduced feature
        set (lowercase, digit, punctuation, POS tag) for up to four
        neighbours on each side.  ``BOS``/``EOS`` flag sentence edges.
    """
    # NOTE(review): the original declared ``global token_count`` but the
    # name was never defined or used anywhere — removed as dead code.
    wordnet_lemmatizer = WordNetLemmatizer()
    porter_stemmer = PorterStemmer()
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isdigit()': word.isdigit(),
        # Check if it is punctuation
        'word.punct()': word in string.punctuation,
        'postag': postag,
        # Lemma of the word
        'word.lemma()': wordnet_lemmatizer.lemmatize(word),
        # Stem of the word
        'word.stem()': porter_stemmer.stem(word)
    }

    # Context window: the same reduced feature set for the four previous
    # and four following tokens, keyed by signed offset so the feature
    # names ('-2:postag', '+1:word.lower()', ...) match the original.
    for offset in (-4, -3, -2, -1, 1, 2, 3, 4):
        j = i + offset
        if 0 <= j < len(sent):
            neighbour, neighbour_pos = sent[j]
            prefix = '%+d' % offset
            features.update({
                prefix + ':word.lower()': neighbour.lower(),
                prefix + ':word.isdigit()': neighbour.isdigit(),
                prefix + ':word.punct()': neighbour in string.punctuation,
                prefix + ':postag': neighbour_pos,
            })

    # Sentence-boundary markers (first/last token may be both).
    if i == 0:
        features['BOS'] = True
    if i == len(sent) - 1:
        features['EOS'] = True

    return features
131
+
132
+
133
def sent2features(sent):
    """Return the per-position feature dicts for every token in *sent*."""
    return [word2features(sent, position) for position, _ in enumerate(sent)]
138
+
139
+
140
print("Evaluating the model...")
# Load the evaluation file from your directory.
df_eval = pd.read_excel("testset_NER_LegalLens.xlsx")
print("Read the evaluation dataset.")
# 'tokens' is stored as the string repr of a Python list; literal_eval
# parses it safely (no arbitrary code execution, unlike eval).
df_eval["tokens"] = df_eval["tokens"].apply(ast.literal_eval)
df_eval['pos_tags'] = df_eval['tokens'].apply(lambda x: [tag[1]
                                                         for tag in pos_tag(x)])
# Flatten to one row per token, tagged with a 1-based sentence number so
# getsentence() can regroup the tokens afterwards.  Iterating column
# values with zip (instead of df["id"][i] label lookups) keeps this
# correct even if the DataFrame index is not a clean RangeIndex.
data_eval = [
    {
        "sentence_num": sent_num,
        "id": row_id,
        "token": token,
        "pos_tag": tag,
    }
    for sent_num, (row_id, tokens, tags) in enumerate(
        zip(df_eval["id"], df_eval["tokens"], df_eval["pos_tags"]), start=1)
    for token, tag in zip(tokens, tags)
]
data_eval = pd.DataFrame(data_eval)
print("Dataframe created.")
getter = getsentence(data_eval)
sentences_eval = getter.sentences
X_eval = [sent2features(s) for s in sentences_eval]
print("Predicting the NER tags...")
# Load the trained CRF model (path relative to this script's location).
crf = joblib.load("../models/crf.pkl")
y_pred_eval = crf.predict(X_eval)
print("NER tags predicted.")
df_eval["ner_tags"] = y_pred_eval
df_eval.drop(columns=["pos_tags"], inplace=True)
print("Saving the predictions...")
df_eval.to_csv("predictions_NERLens.csv", index=False)
print("Predictions saved.")