Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
import argparse | |
import sentencepiece as spm | |
def main():
    """Decode SentencePiece-encoded text back into plain sentences.

    Command-line arguments:
        --model         path of the SentencePiece model used for decoding.
        --input         file to decode, read as UTF-8, one sentence per line.
        --input_format  "piece" (default): each line is whitespace-separated
                        SentencePiece pieces; "id": each line is
                        whitespace-separated integer piece ids.

    Each input line is split on whitespace, decoded with the model, and the
    result is printed to stdout, one output line per input line (an empty
    input line prints an empty line).

    In "id" mode the literal token ``<<unk>>`` (how a reference-side unknown
    token is represented in the input) is mapped to the model's unknown-piece
    id. Any other token that is not an integer raises ``ValueError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="sentencepiece model to use for decoding"
    )
    parser.add_argument("--input", required=True, help="input file to decode")
    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.input_format == "piece":

        def decode(tokens):
            # DecodePieces already returns a str, so no join is needed.
            return sp.DecodePieces(tokens)

    elif args.input_format == "id":
        # Remap reference-side <unk> (written as <<unk>>) to the model's own
        # unknown id rather than a hard-coded 0. This is the same for models
        # trained with the default --unk_id=0, and it stays correct for
        # models that place <unk> elsewhere.
        unk_id = sp.unk_id()

        def decode(tokens):
            ids = [unk_id if tok == "<<unk>>" else int(tok) for tok in tokens]
            return sp.DecodeIds(ids)

    else:
        # Unreachable from the CLI: argparse restricts --input_format above.
        raise NotImplementedError

    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            # str.split() with no argument drops leading/trailing whitespace,
            # including the newline, so no separate rstrip() is needed.
            print(decode(line.split()))


if __name__ == "__main__":
    main()