# Vietnamese_VITS / Step0_transcription.py
# Source: Hugging Face repository file view, uploaded by fhieni
# ("Upload 7 files", commit 98115b8, 5.12 kB).
from re import A
from transformers.file_utils import cached_path, hf_bucket_url
import os, zipfile
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
import torch
import kenlm
from pyctcdecode import Alphabet, BeamSearchDecoderCTC, LanguageModel
import os
from multiprocessing import Pool
import argparse, subprocess, tempfile
def extract_audio(filename, channels=1, rate=16000):
    """
    Extract audio from an input file to a temporary WAV file.

    Runs ``ffmpeg`` to convert *filename* into a WAV file with the requested
    channel count and sample rate.

    Args:
        filename: Path of the input media file.
        channels: Channel count passed to ffmpeg's ``-ac`` option.
        rate: Sample rate passed to ffmpeg's ``-ar`` option.

    Returns:
        A ``(path, rate)`` tuple where *path* is the temporary WAV file.
        The file is not deleted automatically; the caller must remove it.

    Raises:
        Exception: If *filename* is not an existing file.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        OSError: If the ffmpeg executable cannot be started.
    """
    # Validate first so an invalid path never leaves an orphan temp file.
    if not os.path.isfile(filename):
        print("The given file does not exist: {}".format(filename))
        raise Exception("Invalid filepath: {}".format(filename))

    # Only the reserved name is needed; close the handle right away so the
    # descriptor is not leaked and ffmpeg can overwrite the file ("-y").
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp:
        temp_path = temp.name

    command = ["ffmpeg", "-y", "-i", filename,
               "-ac", str(channels), "-ar", str(rate),
               "-loglevel", "error", temp_path]
    use_shell = os.name == "nt"
    try:
        # DEVNULL keeps ffmpeg from reading the parent's stdin without
        # opening (and leaking) a file object for os.devnull.
        subprocess.check_output(command, stdin=subprocess.DEVNULL, shell=use_shell)
    except (OSError, subprocess.CalledProcessError):
        # Conversion failed: do not leave the unusable temp file behind.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
    return temp_path, rate
class Wav2Vec:
    """Speech-to-text wrapper around a wav2vec2 CTC model and a KenLM decoder.

    Loads the ``nguyenvulebinh/wav2vec2-base-vietnamese-250h`` processor and
    model, downloads and unpacks the accompanying 4-gram language model into
    ``./cache/``, and exposes :meth:`inference` to transcribe one audio file
    with LM-backed beam search.
    """

    def __init__(self, device=None):
        """Load the processor, acoustic model and n-gram beam-search decoder.

        Args:
            device: Torch device string for the model. When ``None``, uses
                ``"cuda"`` if available and otherwise ``"cpu"``.
        """
        if device is None:
            # Previously hard-coded to "cuda", which fails on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Load Wav2Vec
        cache_dir = './cache/'
        self.processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
        lm_file = hf_bucket_url("nguyenvulebinh/wav2vec2-base-vietnamese-250h", filename='vi_lm_4grams.bin.zip')
        lm_file = cached_path(lm_file, cache_dir=cache_dir)
        # The language model ships zipped; unpack it next to the cached files.
        with zipfile.ZipFile(lm_file, 'r') as zip_ref:
            zip_ref.extractall(cache_dir)
        lm_file = os.path.join(cache_dir, 'vi_lm_4grams.bin')
        self.model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h", cache_dir=cache_dir)
        self.model.to(self.device)
        # Load Ngram LM
        self.ngram_lm_model = self.get_decoder_ngram_model(self.processor.tokenizer, lm_file)

    def get_decoder_ngram_model(self, tokenizer, ngram_lm_path):
        """Build a pyctcdecode beam-search decoder backed by a KenLM model.

        Args:
            tokenizer: Tokenizer providing ``get_vocab()`` and the
                ``pad_token_id`` / ``unk_token_id`` /
                ``word_delimiter_token_id`` attributes.
            ngram_lm_path: Path to the KenLM binary language model.

        Returns:
            A ``BeamSearchDecoderCTC`` using the tokenizer's alphabet and the
            loaded language model.
        """
        vocab_dict = tokenizer.get_vocab()
        # Order tokens by id so list position == logit column, then drop the
        # two highest ids (presumably the bos/eos tokens, which the original
        # author chose not to map -- TODO confirm against the tokenizer).
        sort_vocab = sorted((value, key) for (key, value) in vocab_dict.items())
        vocab_list = [token for _, token in sort_vocab][:-2]
        # convert ctc blank character representation
        vocab_list[tokenizer.pad_token_id] = ""
        # replace special characters
        vocab_list[tokenizer.unk_token_id] = ""
        # convert space character representation
        vocab_list[tokenizer.word_delimiter_token_id] = " "
        # specify ctc blank char index, since conventionally it is the last
        # entry of the logit matrix
        alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=tokenizer.pad_token_id)
        lm_model = kenlm.Model(ngram_lm_path)
        decoder = BeamSearchDecoderCTC(alphabet, language_model=LanguageModel(lm_model))
        return decoder

    # define function to read in sound file
    def map_to_array(self, batch):
        """Read ``batch["file"]`` with soundfile and attach the audio data.

        Args:
            batch: Dict containing a ``"file"`` path.

        Returns:
            The same dict, with ``"speech"`` and ``"sampling_rate"`` added.
        """
        speech, sampling_rate = sf.read(batch["file"])
        batch["speech"] = speech
        batch["sampling_rate"] = sampling_rate
        return batch

    def inference(self, filename, beam_width=500):
        """Transcribe one audio file using LM-backed CTC beam search.

        Args:
            filename: Path to an audio file readable by soundfile.
            beam_width: Beam width handed to the decoder (default 500).

        Returns:
            The decoded transcription returned by the beam-search decoder.
        """
        # read the sound file
        ds = self.map_to_array({"file": filename})
        # infer model
        input_values = self.processor(ds["speech"], sampling_rate=ds["sampling_rate"], return_tensors="pt").input_values
        input_values = input_values.to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(input_values).logits[0]
        # decode ctc output
        beam_search_output = self.ngram_lm_model.decode(logits.cpu().detach().numpy(), beam_width=beam_width)
        return beam_search_output
def main():
    """Transcribe every ``*.wav`` in ``--wavs`` and write the file lists.

    Each output line is ``<wav basename> <transcription>``. All files go to
    ``--train_file``; the first 64 processed files are additionally written
    to ``--val_file``.
    """
    import glob
    import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument('--wavs', default="DATA/wavs", help="", type=str)
    parser.add_argument('--train_file', default="DATA/train.txt", help="", type=str)
    # The default used to equal the train file, so two write handles
    # clobbered each other's output in a single file.
    parser.add_argument('--val_file', default="DATA/val.txt", help="", type=str)
    args = parser.parse_args()

    if os.path.abspath(args.train_file) == os.path.abspath(args.val_file):
        parser.error("--train_file and --val_file must be different files")

    for out_path in (args.train_file, args.val_file):
        out_dir = os.path.dirname(out_path)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Load the model only after argument parsing so --help and argument
    # errors return immediately.
    w2v = Wav2Vec()

    # Sorted so the 64-file validation subset is reproducible across runs.
    wav_paths = sorted(glob.glob(os.path.join(args.wavs, "*.wav")))

    with open(args.train_file, "w+", encoding="utf-8") as fw, \
            open(args.val_file, "w+", encoding="utf-8") as fw_val:
        for count_val, wav_path in enumerate(tqdm.tqdm(wav_paths)):
            audio_filename, _audio_rate = extract_audio(wav_path)
            try:
                output = w2v.inference(audio_filename)
            finally:
                # extract_audio leaves its temp WAV on disk; remove it here.
                os.remove(audio_filename)
            # basename handles both "/" and the platform's own separator.
            line = os.path.basename(wav_path) + " " + output + "\n"
            fw.write(line)
            if count_val < 64:
                fw_val.write(line)


if __name__ == "__main__":
    main()