Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
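"""
Randomly split a large file into two parts (e.g. a train and a valid set)
while respecting document boundaries. ``k`` documents are drawn uniformly at
random with reservoir sampling and written to the sample file. Every other
document goes to the remainder file. Documents should be separated by a
single empty line; with ``--lines``, each non-empty line is sampled on its own.
"""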
import argparse | |
import random | |
import sys | |
def _read_docs(handle, split_lines):
    """Yield the documents found in the text stream *handle*.

    A document is a list of consecutive non-blank lines.  Blank (whitespace-only)
    lines are pure separators and are never yielded, so leading blank lines or
    runs of blank lines do not produce empty documents.  With *split_lines*,
    every non-blank line is yielded as its own one-line document.

    Every yielded line is newline-terminated: a final line that lacks a newline
    gets one, so it cannot run into the next line once documents are shuffled.

    Progress goes to stderr: the line index every 1,000,000 lines and a dot
    every 100,000 lines.
    """
    doc = []
    for i, line in enumerate(handle):
        if i % 1000000 == 0:
            print(i, file=sys.stderr, end="", flush=True)
        elif i % 100000 == 0:
            print(".", file=sys.stderr, end="", flush=True)
        if line.strip() == "":  # empty line indicates new document
            if doc:
                yield doc
                doc = []  # rebind rather than clear: the yielded list is kept
            continue
        if not line.endswith("\n"):
            line += "\n"
        if split_lines:
            yield [line]
        else:
            doc.append(line)
    if doc:
        yield doc


def _reservoir_split(docs, k, rng):
    """Split the iterable *docs* into ``(sample, remainder)`` in a single pass.

    ``sample`` is a uniformly random subset of ``min(k, n)`` documents chosen
    with reservoir sampling (Algorithm R).  ``remainder`` holds every other
    document, including the ones evicted from the reservoir.  *rng* must
    provide ``randrange`` (the ``random`` module or a ``random.Random``).
    """
    sample = []
    remainder = []
    for index, doc in enumerate(docs):
        if index < k:
            sample.append(doc)
            continue
        # Keep this (index + 1)-th document with probability k / (index + 1).
        j = rng.randrange(index + 1)
        if j < k:
            remainder.append(sample[j])
            sample[j] = doc
        else:
            remainder.append(doc)
    return sample, remainder


def _write_docs(path, docs, split_lines):
    """Write *docs* (lists of newline-terminated lines) to *path* as UTF-8.

    In document mode, consecutive documents are separated by one empty line.
    With *split_lines*, the lines are written back to back.
    """
    with open(path, "w", encoding="utf-8") as out:
        for n, doc in enumerate(docs):
            if n and not split_lines:
                out.write("\n")
            out.writelines(doc)


def main(argv=None):
    """Command-line entry point.

    Streams ``input`` once.  It writes ``k`` uniformly sampled documents
    (lines with ``--lines``) to ``sample_output`` and all other documents to
    ``remainder_output``.  It exits with a usage error if ``k`` is negative or
    the input holds fewer than ``k`` documents.

    Args:
        argv: argument list to parse, without the program name.  ``None``
            (the default) parses ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument(
        "sample_output", help="output file for the k randomly sampled documents"
    )
    parser.add_argument(
        "remainder_output", help="output file for all remaining documents"
    )
    parser.add_argument(
        "-k",
        type=int,
        required=True,
        help="sample size: number of documents (lines with --lines) "
        "written to sample_output",
    )
    parser.add_argument(
        "--lines", action="store_true", help="split lines instead of docs"
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="random seed for a reproducible split"
    )
    args = parser.parse_args(argv)
    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.k < 0:
        parser.error("-k must be non-negative")

    # Without --seed keep using the global ``random`` state, as before, so
    # callers that seed the module themselves still get reproducible output.
    rng = random if args.seed is None else random.Random(args.seed)

    with open(args.input, "r", encoding="utf-8") as h:
        sample, remainder = _reservoir_split(
            _read_docs(h, args.lines), args.k, rng
        )
    print(file=sys.stderr, flush=True)  # terminate the progress line

    if len(sample) < args.k:
        parser.error(
            "input contains only %d documents, fewer than -k %d"
            % (len(sample), args.k)
        )

    _write_docs(args.sample_output, sample, args.lines)
    _write_docs(args.remainder_output, remainder, args.lines)
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()