# nb-wav2vec2-300m-bokmaal/add_kenlm.py
# Author: pere · commit 5ba903e ("first") · 1.27 kB
import argparse
import os
import shlex

from pyctcdecode import build_ctcdecoder
from transformers import AutoProcessor
from transformers import Wav2Vec2ProcessorWithLM
def _build_labels(vocab_dict):
    """Return the lower-cased CTC alphabet for *vocab_dict*, ordered by token id.

    Position ``i`` of the returned list is used by the decoder as the label
    for logit index ``i``. Lower-casing therefore must not merge two distinct
    tokens, or every later label would shift out of alignment.

    Raises:
        ValueError: if two tokens become identical after lower-casing.
    """
    labels = []
    origin = {}  # lower-cased label -> original token, for the error message
    for token, _token_id in sorted(vocab_dict.items(), key=lambda item: item[1]):
        label = token.lower()
        if label in origin:
            raise ValueError(
                f"Tokens {origin[label]!r} and {token!r} both lower-case to "
                f"{label!r}; the decoder alphabet would no longer line up "
                "with the model's output ids."
            )
        origin[label] = token
        labels.append(label)
    return labels


def _kenlm_binary_hint(model_dir):
    """Return a shell command that converts the saved ARPA LM to KenLM binary.

    The command targets ``<model_dir>/language_model/``, shell-quoted so that
    directories containing spaces still work.
    """
    lm_dir = shlex.quote(os.path.join(model_dir, "language_model"))
    return (
        f"Run: ~/bin/build_binary {lm_dir}/*.arpa {lm_dir}/5gram.bin "
        f"-T $(pwd) && rm {lm_dir}/*.arpa"
    )


def main(args):
    """Attach a KenLM n-gram decoder to a CTC processor and save the result.

    Loads the processor from ``args.model_name_or_path``. Builds a pyctcdecode
    decoder whose alphabet is the tokenizer vocabulary (lower-cased, ordered
    by token id) backed by ``args.kenlm_model_path``. Saves a
    ``Wav2Vec2ProcessorWithLM`` back to ``args.model_name_or_path``.

    Args:
        args: namespace with ``model_name_or_path`` and ``kenlm_model_path``.

    Raises:
        ValueError: if lower-casing the vocabulary would merge distinct tokens.
    """
    processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    labels = _build_labels(processor.tokenizer.get_vocab())
    decoder = build_ctcdecoder(
        labels=labels,
        kenlm_model_path=args.kenlm_model_path,
    )
    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )
    processor_with_lm.save_pretrained(args.model_name_or_path)
    # save_pretrained writes the LM under <model dir>/language_model/. The
    # hinted command converts an ARPA copy there to KenLM binary format and
    # removes the ARPA file. It is meaningless for non-ARPA input.
    if args.kenlm_model_path.lower().endswith(".arpa"):
        print(_kenlm_binary_hint(args.model_name_or_path))
def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with ``model_name_or_path`` (default ``"./"``) and
        the required ``kenlm_model_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        default="./",
        help='Model name or path. Defaults to ./',
    )
    parser.add_argument(
        '--kenlm_model_path',
        required=True,
        help='Path to KenLM arpa file.',
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then build and save the
    # LM-backed processor.
    main(parse_args())