nickil commited on
Commit
1928c15
·
1 Parent(s): 560863c

update file

Browse files
weakly_supervised_parser/utils/process_ptb.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import copy
4
+ import re
5
+ import sys
6
+
7
+ import pandas as pd
8
+
9
+ from nltk.corpus import ptb
10
+
11
+ from weakly_supervised_parser.settings import (
12
+ PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH,
13
+ PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH,
14
+ PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH,
15
+ )
16
+ from weakly_supervised_parser.settings import (
17
+ PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH,
18
+ PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH,
19
+ PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH,
20
+ )
21
+ from weakly_supervised_parser.settings import (
22
+ PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH,
23
+ PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH,
24
+ PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH,
25
+ )
26
+ from weakly_supervised_parser.settings import (
27
+ PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
28
+ PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
29
+ PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
30
+ )
31
+ from weakly_supervised_parser.settings import (
32
+ YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH,
33
+ YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH,
34
+ YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH,
35
+ )
36
+
37
+ from weakly_supervised_parser.tree.helpers import extract_sentence
38
+
39
+
40
+ class AlignPTBYoonKimFormat:
41
+ def __init__(self, ptb_data_path, yk_data_path):
42
+ self.ptb_data = pd.read_csv(ptb_data_path, sep="\t", header=None)
43
+ self.yk_data = pd.read_csv(yk_data_path, sep="\t", header=None)
44
+
45
+ def row_mapper(self, save_data_path):
46
+ dict_mapper = self.ptb_data.reset_index().merge(self.yk_data.reset_index(), on=[0]).set_index("index_y")["index_x"].to_dict()
47
+ self.ptb_data.loc[self.ptb_data.index.map(dict_mapper)].to_csv(save_data_path, sep="\t", index=False, header=None)
48
+ return dict_mapper
49
+
50
+
51
+ currency_tags_words = ["#", "$", "C$", "A$"]
52
+ ellipsis = ["*", "*?*", "0", "*T*", "*ICH*", "*U*", "*RNR*", "*EXP*", "*PPA*", "*NOT*"]
53
+ punctuation_tags = [".", ",", ":", "-LRB-", "-RRB-", "''", "``"]
54
+ punctuation_words = [".", ",", ":", "-LRB-", "-RRB-", "''", "``", "--", ";", "-", "?", "!", "...", "-LCB-", "-RCB-"]
55
+
56
+
57
+ def get_data_ptb(root, output):
58
+ # tag filter is from https://github.com/yikangshen/PRPN/blob/master/data_ptb.py
59
+ word_tags = [
60
+ "CC",
61
+ "CD",
62
+ "DT",
63
+ "EX",
64
+ "FW",
65
+ "IN",
66
+ "JJ",
67
+ "JJR",
68
+ "JJS",
69
+ "LS",
70
+ "MD",
71
+ "NN",
72
+ "NNS",
73
+ "NNP",
74
+ "NNPS",
75
+ "PDT",
76
+ "POS",
77
+ "PRP",
78
+ "PRP$",
79
+ "RB",
80
+ "RBR",
81
+ "RBS",
82
+ "RP",
83
+ "SYM",
84
+ "TO",
85
+ "UH",
86
+ "VB",
87
+ "VBD",
88
+ "VBG",
89
+ "VBN",
90
+ "VBP",
91
+ "VBZ",
92
+ "WDT",
93
+ "WP",
94
+ "WP$",
95
+ "WRB",
96
+ ]
97
+ train_file_ids = []
98
+ val_file_ids = []
99
+ test_file_ids = []
100
+ train_section = ["02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21"]
101
+ val_section = ["22"]
102
+ test_section = ["23"]
103
+
104
+ for dir_name, _, file_list in os.walk(root, topdown=False):
105
+ if dir_name.split("/")[-1] in train_section:
106
+ file_ids = train_file_ids
107
+ elif dir_name.split("/")[-1] in val_section:
108
+ file_ids = val_file_ids
109
+ elif dir_name.split("/")[-1] in test_section:
110
+ file_ids = test_file_ids
111
+ else:
112
+ continue
113
+ for fname in file_list:
114
+ file_ids.append(os.path.join(dir_name, fname))
115
+ assert file_ids[-1].split(".")[-1] == "mrg"
116
+ print(len(train_file_ids), len(val_file_ids), len(test_file_ids))
117
+
118
+ def del_tags(tree, word_tags):
119
+ for sub in tree.subtrees():
120
+ for n, child in enumerate(sub):
121
+ if isinstance(child, str):
122
+ continue
123
+ if all(leaf_tag not in word_tags for leaf, leaf_tag in child.pos()):
124
+ del sub[n]
125
+
126
+ def save_file(file_ids, out_file, include_punctuation=False):
127
+ f_out = open(out_file, "w")
128
+ for f in file_ids:
129
+ sentences = ptb.parsed_sents(f)
130
+ for sen_tree in sentences:
131
+ sen_tree_copy = copy.deepcopy(sen_tree)
132
+ c = 0
133
+ while not all([tag in word_tags for _, tag in sen_tree.pos()]):
134
+ del_tags(sen_tree, word_tags)
135
+ c += 1
136
+ if c > 10:
137
+ assert False
138
+
139
+ if len(sen_tree.leaves()) < 2:
140
+ print(f"skipping {' '.join(sen_tree.leaves())} since length < 2")
141
+ continue
142
+
143
+ if include_punctuation:
144
+ keep_punctuation_tags = word_tags + punctuation_tags
145
+ out = " ".join([token for token, pos_tag in sen_tree_copy.pos() if pos_tag in keep_punctuation_tags])
146
+ else:
147
+ out = sen_tree.pformat(margin=sys.maxsize).strip()
148
+ while re.search("\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)", out) is not None:
149
+ out = re.sub("\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)", "", out)
150
+ out = out.replace(" )", ")")
151
+ out = re.sub("\s{2,}", " ", out)
152
+
153
+ f_out.write(out + "\n")
154
+ f_out.close()
155
+
156
+ save_file(train_file_ids, PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)
157
+ save_file(val_file_ids, PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)
158
+ save_file(test_file_ids, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)
159
+
160
+ # Align PTB with Yoon Kim's row order
161
+ ptb_train_index_mapper = AlignPTBYoonKimFormat(
162
+ ptb_data_path=PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH
163
+ ).row_mapper(save_data_path=PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)
164
+ ptb_valid_index_mapper = AlignPTBYoonKimFormat(
165
+ ptb_data_path=PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH
166
+ ).row_mapper(save_data_path=PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)
167
+ ptb_test_index_mapper = AlignPTBYoonKimFormat(
168
+ ptb_data_path=PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH
169
+ ).row_mapper(save_data_path=PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)
170
+
171
+ # Extract sentences without punctuation
172
+ ptb_train_without_punctuation = pd.read_csv(PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
173
+ ptb_train_without_punctuation["tree"].apply(extract_sentence).to_csv(
174
+ PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
175
+ )
176
+ ptb_valid_without_punctuation = pd.read_csv(PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
177
+ ptb_valid_without_punctuation["tree"].apply(extract_sentence).to_csv(
178
+ PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
179
+ )
180
+ ptb_test_without_punctuation = pd.read_csv(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
181
+ ptb_test_without_punctuation["tree"].apply(extract_sentence).to_csv(
182
+ PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
183
+ )
184
+
185
+ save_file(train_file_ids, PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)
186
+ save_file(val_file_ids, PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)
187
+ save_file(test_file_ids, PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)
188
+
189
+ # Extract sentences with punctuation
190
+ ptb_train_with_punctuation = pd.read_csv(PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
191
+ ptb_train_with_punctuation = ptb_train_with_punctuation.loc[ptb_train_with_punctuation.index.map(ptb_train_index_mapper)]
192
+ ptb_train_with_punctuation.to_csv(PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)
193
+ ptb_valid_with_punctuation = pd.read_csv(PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
194
+ ptb_valid_with_punctuation = ptb_valid_with_punctuation.loc[ptb_valid_with_punctuation.index.map(ptb_valid_index_mapper)]
195
+ ptb_valid_with_punctuation.to_csv(PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)
196
+ ptb_test_with_punctuation = pd.read_csv(PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
197
+ ptb_test_with_punctuation = ptb_test_with_punctuation.loc[ptb_test_with_punctuation.index.map(ptb_test_index_mapper)]
198
+ ptb_test_with_punctuation.to_csv(PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)
199
+
200
+
201
+ def main(arguments):
202
+ parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
203
+ parser.add_argument("--ptb_path", help="Path to parsed/mrg/wsj folder", type=str, default="./TEMP/corrected/parsed/mrg/wsj/")
204
+ parser.add_argument("--output_path", help="Path to save processed files", type=str, default="./data/PROCESSED/english/")
205
+ args = parser.parse_args(arguments)
206
+ get_data_ptb(args.ptb_path, args.output_path)
207
+
208
+
209
+ if __name__ == "__main__":
210
+ sys.exit(main(sys.argv[1:]))