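"""Preprocess the Penn Treebank (WSJ) .mrg files for the weakly supervised parser.

For the standard splits (train: sections 02-21, validation: section 22, test: section 23),
this script writes gold bracketings without punctuation, the same trees re-ordered to match
Yoon Kim's reference files, and plain sentences with and without punctuation in the aligned
row order.
"""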
import argparse
import copy
import os
import re
import sys

import pandas as pd
from nltk.corpus import ptb

from weakly_supervised_parser.settings import (
    PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH,
    PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH,
    PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH,
    PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH,
    PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH,
    PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH,
    PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH,
    PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH,
    PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH,
    PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
    PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
    PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH,
    YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH,
    YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH,
    YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH,
)
from weakly_supervised_parser.tree.helpers import extract_sentence

class AlignPTBYoonKimFormat:
    """Re-orders our PTB gold-tree file so that its rows line up with Yoon Kim's reference file."""

    def __init__(self, ptb_data_path, yk_data_path):
        self.ptb_data = pd.read_csv(ptb_data_path, sep="\t", header=None)
        self.yk_data = pd.read_csv(yk_data_path, sep="\t", header=None)

    def row_mapper(self, save_data_path):
        # Join the two files on the tree string (column 0) to map each Yoon Kim row index to the matching PTB row index.
        dict_mapper = self.ptb_data.reset_index().merge(self.yk_data.reset_index(), on=[0]).set_index("index_y")["index_x"].to_dict()
        # Write the PTB rows re-ordered to follow Yoon Kim's row order.
        self.ptb_data.loc[self.ptb_data.index.map(dict_mapper)].to_csv(save_data_path, sep="\t", index=False, header=None)
        return dict_mapper

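# Illustration of row_mapper with two hypothetical tree strings:
#   ptb rows: 0 -> "(S (NP He) (VP left))"      1 -> "(S (NP She) (VP stayed))"
#   yk rows:  0 -> "(S (NP She) (VP stayed))"   1 -> "(S (NP He) (VP left))"
# row_mapper(...) returns {0: 1, 1: 0} and writes the PTB rows in Yoon Kim's order to save_data_path.
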
currency_tags_words = ["#", "$", "C$", "A$"]
ellipsis = ["*", "*?*", "0", "*T*", "*ICH*", "*U*", "*RNR*", "*EXP*", "*PPA*", "*NOT*"]
punctuation_tags = [".", ",", ":", "-LRB-", "-RRB-", "''", "``"]
punctuation_words = [".", ",", ":", "-LRB-", "-RRB-", "''", "``", "--", ";", "-", "?", "!", "...", "-LCB-", "-RCB-"]

def get_data_ptb(root, output):
    # tag filter is from https://github.com/yikangshen/PRPN/blob/master/data_ptb.py
    word_tags = [
        "CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP",
        "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB",
    ]
    train_file_ids = []
    val_file_ids = []
    test_file_ids = []
    train_section = ["02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21"]
    val_section = ["22"]
    test_section = ["23"]

    for dir_name, _, file_list in os.walk(root, topdown=False):
        if dir_name.split("/")[-1] in train_section:
            file_ids = train_file_ids
        elif dir_name.split("/")[-1] in val_section:
            file_ids = val_file_ids
        elif dir_name.split("/")[-1] in test_section:
            file_ids = test_file_ids
        else:
            continue
        for fname in file_list:
            file_ids.append(os.path.join(dir_name, fname))
            assert file_ids[-1].split(".")[-1] == "mrg"

    print(len(train_file_ids), len(val_file_ids), len(test_file_ids))
    def del_tags(tree, word_tags):
        # Drop subtrees whose leaves carry no content tag (traces, empty elements, punctuation).
        for sub in tree.subtrees():
            for n, child in enumerate(sub):
                if isinstance(child, str):
                    continue
                if all(leaf_tag not in word_tags for leaf, leaf_tag in child.pos()):
                    del sub[n]
    def save_file(file_ids, out_file, include_punctuation=False):
        f_out = open(out_file, "w")
        for f in file_ids:
            sentences = ptb.parsed_sents(f)
            for sen_tree in sentences:
                sen_tree_copy = copy.deepcopy(sen_tree)
                c = 0
                # Deleting while iterating can miss siblings, so repeat until only content tags remain.
                while not all([tag in word_tags for _, tag in sen_tree.pos()]):
                    del_tags(sen_tree, word_tags)
                    c += 1
                    if c > 10:
                        assert False
                if len(sen_tree.leaves()) < 2:
                    print(f"skipping {' '.join(sen_tree.leaves())} since length < 2")
                    continue
                if include_punctuation:
                    keep_punctuation_tags = word_tags + punctuation_tags
                    out = " ".join([token for token, pos_tag in sen_tree_copy.pos() if pos_tag in keep_punctuation_tags])
                else:
                    out = sen_tree.pformat(margin=sys.maxsize).strip()
                    # Remove constituents left empty by del_tags, then normalise whitespace.
                    while re.search(r"\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)", out) is not None:
                        out = re.sub(r"\(([A-Z0-9]{1,})((-|=)[A-Z0-9]*)*\s{1,}\)", "", out)
                    out = out.replace(" )", ")")
                    out = re.sub(r"\s{2,}", " ", out)
                f_out.write(out + "\n")
        f_out.close()
    save_file(train_file_ids, PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)
    save_file(val_file_ids, PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)
    save_file(test_file_ids, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH, include_punctuation=False)

    # Align PTB with Yoon Kim's row order
    ptb_train_index_mapper = AlignPTBYoonKimFormat(
        ptb_data_path=PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH
    ).row_mapper(save_data_path=PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)
    ptb_valid_index_mapper = AlignPTBYoonKimFormat(
        ptb_data_path=PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH
    ).row_mapper(save_data_path=PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)
    ptb_test_index_mapper = AlignPTBYoonKimFormat(
        ptb_data_path=PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH, yk_data_path=YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH
    ).row_mapper(save_data_path=PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH)

    # Extract sentences without punctuation
    ptb_train_without_punctuation = pd.read_csv(PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
    ptb_train_without_punctuation["tree"].apply(extract_sentence).to_csv(
        PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
    )
    ptb_valid_without_punctuation = pd.read_csv(PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
    ptb_valid_without_punctuation["tree"].apply(extract_sentence).to_csv(
        PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
    )
    ptb_test_without_punctuation = pd.read_csv(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, sep="\t", header=None, names=["tree"])
    ptb_test_without_punctuation["tree"].apply(extract_sentence).to_csv(
        PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, index=False, sep="\t", header=None
    )

    save_file(train_file_ids, PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)
    save_file(val_file_ids, PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)
    save_file(test_file_ids, PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, include_punctuation=True)

    # Extract sentences with punctuation
    ptb_train_with_punctuation = pd.read_csv(PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
    ptb_train_with_punctuation = ptb_train_with_punctuation.loc[ptb_train_with_punctuation.index.map(ptb_train_index_mapper)]
    ptb_train_with_punctuation.to_csv(PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)
    ptb_valid_with_punctuation = pd.read_csv(PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
    ptb_valid_with_punctuation = ptb_valid_with_punctuation.loc[ptb_valid_with_punctuation.index.map(ptb_valid_index_mapper)]
    ptb_valid_with_punctuation.to_csv(PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)
    ptb_test_with_punctuation = pd.read_csv(PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, sep="\t", header=None, names=["sentence"])
    ptb_test_with_punctuation = ptb_test_with_punctuation.loc[ptb_test_with_punctuation.index.map(ptb_test_index_mapper)]
    ptb_test_with_punctuation.to_csv(PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH, index=False, sep="\t", header=None)

def main(arguments):
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--ptb_path", help="Path to parsed/mrg/wsj folder", type=str, default="./TEMP/corrected/parsed/mrg/wsj/")
    parser.add_argument("--output_path", help="Path to save processed files", type=str, default="./data/PROCESSED/english/")
    args = parser.parse_args(arguments)
    get_data_ptb(args.ptb_path, args.output_path)


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
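
# Example invocation (the filename is whichever name this script is saved under):
#   python process_ptb.py --ptb_path ./TEMP/corrected/parsed/mrg/wsj/ --output_path ./data/PROCESSED/english/
# Note: `from nltk.corpus import ptb` expects the LDC Treebank-3 files to be available
# to NLTK (typically under nltk_data/corpora/ptb).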