Predict file
predict.py (ADDED, +172 -0)
import ast
import string

import joblib
import nltk
import pandas as pd
from nltk import pos_tag
from nltk.stem import PorterStemmer, WordNetLemmatizer

# Download the required NLTK resources (skipped if already up to date)
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download("averaged_perceptron_tagger")
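

# Quick illustrative check (not part of the original script, never called by the
# pipeline): `pos_tag` turns a token list into (token, tag) pairs; those tags
# feed the CRF features built below.
def _demo_pos_tag():
    # Hypothetical helper for illustration only.
    print(pos_tag(["The", "court", "ruled"]))  # e.g. [('The', 'DT'), ('court', 'NN'), ('ruled', 'VBD')]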


class getsentence(object):
    '''
    Rebuilds sentences from the token-level (BIO-style) rows of the dataset,
    grouping tokens by their sentence numbers.
    '''

    def __init__(self, data):
        self.n_sent = 1.0
        self.data = data
        self.empty = False
        self.grouped = self.data.groupby("sentence_num").apply(self._agg_func)
        self.sentences = [s for s in self.grouped]

    def _agg_func(self, s):
        return [(w, p) for w, p in zip(s["token"].values.tolist(),
                                       s["pos_tag"].values.tolist())]
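

# Illustrative sketch (not part of the original pipeline, never called by it):
# how `getsentence` regroups token-level rows into per-sentence (token, pos_tag)
# lists. The tiny DataFrame is hypothetical; call `_demo_getsentence()` to try it.
def _demo_getsentence():
    demo = pd.DataFrame({
        "sentence_num": [1, 1, 2],
        "token": ["Hello", "world", "Goodbye"],
        "pos_tag": ["UH", "NN", "UH"],
    })
    # Expected: [[('Hello', 'UH'), ('world', 'NN')], [('Goodbye', 'UH')]]
    print(getsentence(demo).sentences)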


def word2features(sent, i):
    '''
    Extracts features for the word at position i of the sentence.
    The main features are:
        - word.lower(): the word in lowercase
        - word.isdigit(): whether the word is a digit
        - word.punct(): whether the word is punctuation
        - postag: the POS tag of the word
        - word.lemma(): the lemma of the word
        - word.stem(): the stem of the word
    A subset of these features is also extracted for the four previous
    and four next words.
    '''
    wordnet_lemmatizer = WordNetLemmatizer()
    porter_stemmer = PorterStemmer()
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isdigit()': word.isdigit(),
        # Whether the word is punctuation
        'word.punct()': word in string.punctuation,
        'postag': postag,
        # Lemma of the word
        'word.lemma()': wordnet_lemmatizer.lemmatize(word),
        # Stem of the word
        'word.stem()': porter_stemmer.stem(word)
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.isdigit()': word1.isdigit(),
            '-1:word.punct()': word1 in string.punctuation,
            '-1:postag': postag1
        })
        if i - 2 >= 0:
            features.update({
                '-2:word.lower()': sent[i-2][0].lower(),
                '-2:word.isdigit()': sent[i-2][0].isdigit(),
                '-2:word.punct()': sent[i-2][0] in string.punctuation,
                '-2:postag': sent[i-2][1]
            })
        if i - 3 >= 0:
            features.update({
                '-3:word.lower()': sent[i-3][0].lower(),
                '-3:word.isdigit()': sent[i-3][0].isdigit(),
                '-3:word.punct()': sent[i-3][0] in string.punctuation,
                '-3:postag': sent[i-3][1]
            })
        if i - 4 >= 0:
            features.update({
                '-4:word.lower()': sent[i-4][0].lower(),
                '-4:word.isdigit()': sent[i-4][0].isdigit(),
                '-4:word.punct()': sent[i-4][0] in string.punctuation,
                '-4:postag': sent[i-4][1]
            })
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.isdigit()': word1.isdigit(),
            '+1:word.punct()': word1 in string.punctuation,
            '+1:postag': postag1
        })
        if i + 2 < len(sent):
            features.update({
                '+2:word.lower()': sent[i+2][0].lower(),
                '+2:word.isdigit()': sent[i+2][0].isdigit(),
                '+2:word.punct()': sent[i+2][0] in string.punctuation,
                '+2:postag': sent[i+2][1]
            })
        if i + 3 < len(sent):
            features.update({
                '+3:word.lower()': sent[i+3][0].lower(),
                '+3:word.isdigit()': sent[i+3][0].isdigit(),
                '+3:word.punct()': sent[i+3][0] in string.punctuation,
                '+3:postag': sent[i+3][1]
            })
        if i + 4 < len(sent):
            features.update({
                '+4:word.lower()': sent[i+4][0].lower(),
                '+4:word.isdigit()': sent[i+4][0].isdigit(),
                '+4:word.punct()': sent[i+4][0] in string.punctuation,
                '+4:postag': sent[i+4][1]
            })
    else:
        features['EOS'] = True

    return features


def sent2features(sent):
    '''
    Extracts the features for every word in the sentence.
    '''
    return [word2features(sent, i) for i in range(len(sent))]
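

# Illustrative sketch (not part of the original pipeline, never called by it):
# the feature dicts that `sent2features` yields for a short POS-tagged sentence.
# Call `_demo_features()` manually to inspect them.
def _demo_features():
    sent = [("The", "DT"), ("court", "NN"), ("ruled", "VBD"), (".", ".")]
    feats = sent2features(sent)
    print(feats[0]["BOS"])            # True: the first token is sentence-initial
    print(feats[1]["word.lower()"])   # 'court'
    print(feats[1]["-1:postag"])      # 'DT': POS tag of the previous token
    print(feats[2]["word.stem()"])    # 'rule': Porter stem of 'ruled'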


print("Evaluating the model...")
# Load the evaluation file from your directory
df_eval = pd.read_excel("testset_NER_LegalLens.xlsx")
print("Read the evaluation dataset.")
df_eval["tokens"] = df_eval["tokens"].apply(ast.literal_eval)
df_eval['pos_tags'] = df_eval['tokens'].apply(lambda x: [tag[1]
                                                         for tag in pos_tag(x)])
data_eval = []
for i in range(len(df_eval)):
    for j in range(len(df_eval["tokens"][i])):
        data_eval.append(
            {
                "sentence_num": i+1,
                "id": df_eval["id"][i],
                "token": df_eval["tokens"][i][j],
                "pos_tag": df_eval["pos_tags"][i][j],
            }
        )
data_eval = pd.DataFrame(data_eval)
print("Dataframe created.")
getter = getsentence(data_eval)
sentences_eval = getter.sentences
X_eval = [sent2features(s) for s in sentences_eval]
print("Predicting the NER tags...")
# Load the trained model from your directory
crf = joblib.load("../models/crf.pkl")
y_pred_eval = crf.predict(X_eval)
print("NER tags predicted.")
df_eval["ner_tags"] = y_pred_eval
df_eval.drop(columns=["pos_tags"], inplace=True)
print("Saving the predictions...")
df_eval.to_csv("predictions_NERLens.csv", index=False)
print("Predictions saved.")