Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import json | |
import os | |
import re | |
class InputExample:
    """One RACE question: its article, the four question/option strings, and the answer index."""

    def __init__(self, paragraph, qa_list, label):
        # Article text with all whitespace collapsed to single spaces.
        self.paragraph = paragraph
        # One string per option: the question combined with that option.
        self.qa_list = qa_list
        # 0-based index of the correct option ("A" -> 0, "B" -> 1, ...).
        self.label = label


def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file.

    Files are read from ``<data_dir>/<split>/<level>/``, with the filenames
    sorted so that the output order is reproducible.

    Args:
        data_dir: root directory of the downloaded RACE dataset.
        set_type: split name (e.g. ``"train"``), optionally suffixed with one
            level as ``"<split>-<level>"`` (e.g. ``"test-middle"``). Without a
            suffix both the ``"middle"`` and ``"high"`` levels are read.

    Returns:
        A list of ``InputExample``, one per question. In each example:

        * ``qa_list`` holds the question combined with each of the first four
          options. A ``"_"`` blank is filled with the option; otherwise the
          option is appended after a space.
        * ``label`` is the 0-based index of the answer letter.

    Raises:
        OSError: if a level directory or file cannot be read.
        KeyError / IndexError: if a file lacks the expected fields or has fewer
            than four options for a question.
    """
    examples = []
    levels = ["middle", "high"]
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        set_type, level = set_type_c
        levels = [level]
    # \s also matches newlines, so this both joins lines and squeezes runs of spaces.
    whitespace = re.compile(r"\s+")
    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        # os.listdir order is arbitrary; sort so the extracted files are reproducible.
        for filename in sorted(os.listdir(cur_dir)):
            cur_path = os.path.join(cur_dir, filename)
            # Explicit UTF-8 so decoding does not depend on the platform locale.
            with open(cur_path, "r", encoding="utf-8") as f:
                cur_data = json.load(f)
            answers = cur_data["answers"]
            options = cur_data["options"]
            questions = cur_data["questions"]
            # Keep each article on a single output line.
            context = whitespace.sub(" ", cur_data["article"])
            for i, answer in enumerate(answers):
                label = ord(answer) - ord("A")
                question = questions[i]
                qa_list = []
                for j in range(4):
                    option = options[i][j]
                    if "_" in question:
                        # Cloze-style question: put the option into the blank(s).
                        qa_cat = question.replace("_", option)
                    else:
                        qa_cat = " ".join([question, option])
                    qa_list.append(whitespace.sub(" ", qa_cat))
                examples.append(InputExample(context, qa_list, label))
    return examples
def main():
    """
    Helper script to extract paragraphs, questions and answers from RACE datasets.

    For each split ("train", "dev", "test-middle", "test-high") this writes the
    following files into ``--output-dir``:

    * ``<split>.input0``: one article per line.
    * ``<split>.input1`` .. ``<split>.input4``: the question combined with the
      first .. fourth option, one per line.
    * ``<split>.label``: the 0-based index of the correct option, one per line.

    Line k of every file refers to the same question.
    """
    # Local import keeps this block independent of the file-level imports.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        required=True,
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="output directory for extracted data",
    )
    args = parser.parse_args()

    # exist_ok=True already handles an existing directory; no separate check needed.
    os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        prefix = os.path.join(args.output_dir, set_type)
        # ExitStack closes (and flushes) every output file even if a write fails.
        with ExitStack() as stack:
            qa_files = [
                stack.enter_context(
                    open(prefix + ".input" + str(i + 1), "w", encoding="utf-8")
                )
                for i in range(4)
            ]
            outf_context = stack.enter_context(
                open(prefix + ".input0", "w", encoding="utf-8")
            )
            outf_label = stack.enter_context(
                open(prefix + ".label", "w", encoding="utf-8")
            )
            for example in examples:
                outf_context.write(example.paragraph + "\n")
                # Index into qa_list so a short list fails loudly instead of
                # silently misaligning the parallel files.
                for i, qa_file in enumerate(qa_files):
                    qa_file.write(example.qa_list[i] + "\n")
                outf_label.write(str(example.label) + "\n")


if __name__ == "__main__":
    main()