Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import contextlib | |
import sys | |
from collections import Counter | |
from multiprocessing import Pool | |
from fairseq.data.encoders.gpt2_bpe import get_encoder | |
def main(): | |
""" | |
Helper script to encode raw text with the GPT-2 BPE using multiple processes. | |
The encoder.json and vocab.bpe files can be obtained here: | |
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json | |
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe | |
""" | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--encoder-json", | |
help="path to encoder.json", | |
) | |
parser.add_argument( | |
"--vocab-bpe", | |
type=str, | |
help="path to vocab.bpe", | |
) | |
parser.add_argument( | |
"--inputs", | |
nargs="+", | |
default=["-"], | |
help="input files to filter/encode", | |
) | |
parser.add_argument( | |
"--outputs", | |
nargs="+", | |
default=["-"], | |
help="path to save encoded outputs", | |
) | |
parser.add_argument( | |
"--keep-empty", | |
action="store_true", | |
help="keep empty lines", | |
) | |
parser.add_argument("--workers", type=int, default=20) | |
args = parser.parse_args() | |
assert len(args.inputs) == len( | |
args.outputs | |
), "number of input and output paths should match" | |
with contextlib.ExitStack() as stack: | |
inputs = [ | |
stack.enter_context(open(input, "r", encoding="utf-8")) | |
if input != "-" | |
else sys.stdin | |
for input in args.inputs | |
] | |
outputs = [ | |
stack.enter_context(open(output, "w", encoding="utf-8")) | |
if output != "-" | |
else sys.stdout | |
for output in args.outputs | |
] | |
encoder = MultiprocessingEncoder(args) | |
pool = Pool(args.workers, initializer=encoder.initializer) | |
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) | |
stats = Counter() | |
for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): | |
if filt == "PASS": | |
for enc_line, output_h in zip(enc_lines, outputs): | |
print(enc_line, file=output_h) | |
else: | |
stats["num_filtered_" + filt] += 1 | |
if i % 10000 == 0: | |
print("processed {} lines".format(i), file=sys.stderr) | |
for k, v in stats.most_common(): | |
print("[{}] filtered {} lines".format(k, v), file=sys.stderr) | |
class MultiprocessingEncoder(object): | |
def __init__(self, args): | |
self.args = args | |
def initializer(self): | |
global bpe | |
bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe) | |
def encode(self, line): | |
global bpe | |
ids = bpe.encode(line) | |
return list(map(str, ids)) | |
def decode(self, tokens): | |
global bpe | |
return bpe.decode(tokens) | |
def encode_lines(self, lines): | |
""" | |
Encode a set of lines. All lines will be encoded together. | |
""" | |
enc_lines = [] | |
for line in lines: | |
line = line.strip() | |
if len(line) == 0 and not self.args.keep_empty: | |
return ["EMPTY", None] | |
tokens = self.encode(line) | |
enc_lines.append(" ".join(tokens)) | |
return ["PASS", enc_lines] | |
def decode_lines(self, lines): | |
dec_lines = [] | |
for line in lines: | |
tokens = map(int, line.strip().split()) | |
dec_lines.append(self.decode(tokens)) | |
return ["PASS", dec_lines] | |
if __name__ == "__main__": | |
main() | |