Spaces:
Runtime error
Runtime error
File size: 3,782 Bytes
10b0761 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
from fairseq.data.encoders.gpt2_bpe import get_encoder
def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe

    The files given by ``--inputs`` are read in lockstep (one line from each
    per step) and each group of lines is encoded by a worker process. A group
    that passes is written line-by-line to the matching ``--outputs`` files;
    a filtered group is only counted. Progress and filter statistics are
    reported on stderr. A path of ``-`` means stdin (inputs) or stdout
    (outputs).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`, which is removed when
    # Python runs with -O and would let mismatched lists through silently.
    if len(args.inputs) != len(args.outputs):
        parser.error("number of input and output paths should match")

    with contextlib.ExitStack() as stack:
        # The ExitStack closes every file opened here; stdin/stdout are used
        # as-is and deliberately not registered, so they stay open.
        inputs = [
            stack.enter_context(open(path, "r", encoding="utf-8"))
            if path != "-"
            else sys.stdin
            for path in args.inputs
        ]
        outputs = [
            stack.enter_context(open(path, "w", encoding="utf-8"))
            if path != "-"
            else sys.stdout
            for path in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        stats = Counter()

        # Use the pool as a context manager so worker processes are always
        # torn down, including when encoding or writing raises part-way.
        with Pool(args.workers, initializer=encoder.initializer) as pool:
            # zip(*inputs) yields one tuple of parallel lines per step;
            # imap keeps results in input order, 100 tuples per task chunk.
            encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

            for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
                if filt == "PASS":
                    for enc_line, output_h in zip(enc_lines, outputs):
                        print(enc_line, file=output_h)
                else:
                    stats["num_filtered_" + filt] += 1
                if i % 10000 == 0:
                    print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
class MultiprocessingEncoder:
    """
    Encode/decode lines with a BPE object held in the module-level ``bpe``.

    The instance itself only stores the parsed command-line arguments; the
    BPE object is created by ``initializer`` and looked up as a global by
    ``encode`` and ``decode``.
    """

    def __init__(self, args):
        self.args = args

    def initializer(self):
        """Build the BPE from the configured paths and store it globally."""
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        """Return the BPE ids of ``line`` as a list of strings."""
        return [str(token_id) for token_id in bpe.encode(line)]

    def decode(self, tokens):
        """Return the result of decoding ``tokens`` with the global BPE."""
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.

        Returns ``["PASS", encoded]`` with one space-joined id string per
        line, or ``["EMPTY", None]`` as soon as a blank line is met while
        ``keep_empty`` is not set.
        """
        encoded = []
        for raw in lines:
            text = raw.strip()
            # A single blank line rejects the whole group.
            if not (text or self.args.keep_empty):
                return ["EMPTY", None]
            encoded.append(" ".join(self.encode(text)))
        return ["PASS", encoded]

    def decode_lines(self, lines):
        """
        Decode lines of whitespace-separated integer ids.

        Returns ``["PASS", decoded]`` with one decoded entry per input line.
        """
        decoded = [self.decode(map(int, raw.strip().split())) for raw in lines]
        return ["PASS", decoded]
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|