""" | |
TODO: | |
+ [x] Load Configuration | |
+ [ ] Multi ASR Engine | |
+ [ ] Batch / Real Time support | |
""" | |
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Dataset, Audio
import pdb
import string

# local import
import sys

sys.path.append("src")
# token_model = AutoModelForCTC.from_pretrained(
#     "facebook/wav2vec2-base-960h"
# )
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

audio_path = "/Users/kevingeng/Laronix/Laronix_PAL_ASR_Offline_Plot/data/samples/3_Healthy1.wav"
audio_dir = "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
# tgt_audio_dir = "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"

# src_audio_list = sorted(Path(src_audio_dir).glob("**/*.wav"))
# src_audio_list = [str(x) for x in src_audio_list]
# src_audio_dict = {"audio": src_audio_list}
# src_dataset = Dataset.from_dict(src_audio_dict).cast_column("audio", Audio())
# tgt_audio_list = sorted(Path(tgt_audio_dir).glob("**/*.wav"))
# tgt_audio_list = [str(x) for x in tgt_audio_list]
# tgt_audio_dict = {"audio": tgt_audio_list}
# tgt_dataset = Dataset.from_dict(tgt_audio_dict).cast_column("audio", Audio())
# Get Transcription, WER and PPM
"""
TODO:
[DONE]: Automatically generate config
"""
import yaml
import argparse

import lightning_module
from UV import plot_UV, get_speech_interval
from transformers import pipeline
from rich.progress import track
from rich import print as rprint
import numpy as np
import jiwer
import torch.nn as nn
import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# root_path = Path(__file__).parents[1]


class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighboring samples."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source index for every output sample
        indices = torch.arange(new_length) * (
            self.input_rate / self.output_rate
        )
        # Linearly interpolate between the two nearest input samples
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
            0
        ) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
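
# Example usage (rates assumed for illustration): resample a waveform loaded with
# torchaudio down to 16 kHz, e.g. from 44.1 kHz:
# wav, sr = torchaudio.load(audio_path)      # wav: (channels, samples)
# wav16k = ChangeSampleRate(sr, 16_000)(wav)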
# MOS predictor (UTMOS-strong; see the description in the commented-out demo below).
# NOTE: this name is shadowed by the CTC model defined further down.
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "./src/epoch=3-step=7459.ckpt"
).eval()
# ASR pipeline and jiwer text normalization used by calc_wer()
# (the same definitions appear in the commented-out __main__ block below)
p = pipeline("automatic-speech-recognition")
transformation = jiwer.Compose(
    [
        jiwer.ToLowerCase(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.RemoveMultipleSpaces(),
        jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
    ]
)


def calc_wer(audio_path, ref):
    # ASR: the pipeline loads and resamples the file itself
    trans = p(audio_path)["text"]
    # WER of the hypothesis against the reference transcription
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    return trans, wer
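
# Example usage (the reference text here is hypothetical):
# trans, wer = calc_wer(audio_path, "THIS IS A SAMPLE REFERENCE SENTENCE")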
# if __name__ == "__main__":
#     # Argparse
#     parser = argparse.ArgumentParser(
#         prog="get_ref_PPM",
#         description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
#         epilog="",
#     )
#     parser.add_argument(
#         "--tag",
#         type=str,
#         default=None,
#         required=False,
#         help="ID tag for output *.csv",
#     )
#     parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
#     parser.add_argument(
#         "--ref_wavs", type=str, required=True, help="Reference WAVs"
#     )
#     parser.add_argument(
#         "--output_dir",
#         type=str,
#         required=True,
#         help="Output directory for *.csv",
#     )
#     parser.add_argument(
#         "--to_config",
#         choices=["True", "False"],
#         default="False",
#         help="Generate config from .txt and wavs/*.wav",
#     )
#     args = parser.parse_args()
#     refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
#     refs_ids = [x.split()[0] for x in refs]
#     refs_txt = [" ".join(x.split()[1:]) for x in refs]
#     ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
#     if len(refs) != len(ref_wavs):
#         print("Error: Text and WAVs don't match")
#         exit()
#     # ASR part
#     p = pipeline("automatic-speech-recognition")
#     # WER part
#     transformation = jiwer.Compose(
#         [
#             jiwer.ToLowerCase(),
#             jiwer.RemoveWhiteSpace(replace_by_space=True),
#             jiwer.RemoveMultipleSpaces(),
#             jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
#         ]
#     )
#     # WPM part
#     processor = Wav2Vec2Processor.from_pretrained(
#         "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
#     )
#     phoneme_model = Wav2Vec2ForCTC.from_pretrained(
#         "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
#     )
#     # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
#     description = """
#     MOS prediction demo using the UTMOS-strong model (without the phoneme encoder), \
#     which is trained on the main-track dataset.
#     This demo only accepts .wav format and works best at a 16 kHz sampling rate.
#     The paper is available [here](https://arxiv.org/abs/2204.02152).
#     ASR is based on wav2vec2-base-960h; currently only English is available.
#     A WER interface is included.
#     """
#     reference_id = gr.Textbox(
#         value="ID", placeholder="Utterance ID", label="Reference_ID"
#     )
#     reference_textbox = gr.Textbox(
#         value="", placeholder="Input reference here", label="Reference"
#     )
#     # Set up interface
#     result = []
#     result.append("id, trans, wer")
#     for id, x, y in track(
#         zip(refs_ids, ref_wavs, refs_txt),
#         total=len(refs_ids),
#         description="Loading reference information",
#     ):
#         trans, wer = calc_wer(x, y)
#         record = ",".join([id, str(trans), str(wer)])
#         result.append(record)
#     # Output
#     if args.tag is None:
#         args.tag = Path(args.ref_wavs).stem
#     # Make output_dir
#     Path.mkdir(Path(args.output_dir), exist_ok=True)
#     with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
#         print("\n".join(result), file=f)
#     # Generate config
#     if args.to_config == "True":
#         config_dict = {
#             "exp_id": args.tag,
#             "ref_txt": args.ref_txt,
#             "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
#             "ref_wavs": args.ref_wavs,
#             "thre": {
#                 "minppm": 100,
#                 "maxppm": 100,
#                 "WER": 0.1,
#                 "AUTOMOS": 4.0,
#             },
#             "auth": {"username": None, "password": None},
#         }
#         with open("./config/%s.yaml" % args.tag, "w") as config_f:
#             rprint("Dumping config to ./config/%s.yaml" % args.tag)
#             rprint(config_dict)
#             yaml.dump(config_dict, stream=config_f)
#             rprint("Change parameters in ./config/%s.yaml if necessary" % args.tag)
#         print("Reference dumping finished")
def dataclean(example):
    # Normalize transcriptions: uppercase and strip all punctuation
    return {
        "transcription": example["transcription"]
        .upper()
        .translate(str.maketrans("", "", string.punctuation))
    }
# processor = AutoFeatureExtractor.from_pretrained(
#     "facebook/wav2vec2-base-960h"
# )
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
def prepare_dataset(batch):
    audio = batch["audio"]
    # Extract input_values from the raw waveform and encode the transcription as labels
    batch = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=batch["transcription"],
    )
    batch["input_length"] = len(batch["input_values"][0])
    return batch
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
src_dataset = src_dataset.map(dataclean)

# train_dev / test
ds = src_dataset.train_test_split(test_size=0.1)
train_dev = ds["train"]
# train / dev
train_dev = train_dev.train_test_split(test_size=int(len(src_dataset) * 0.1))
# train / dev / test
train = train_dev["train"]
test = ds["test"]
dev = train_dev["test"]
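
# Rough sanity check on the resulting ~80/10/10 split (exact sizes depend on the
# contents of audio_dir):
# rprint("train/dev/test:", len(train), len(dev), len(test))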
# pdb.set_trace()
import evaluate

# WER metric used for evaluation during training
wer_metric = evaluate.load("wer")


def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Restore the pad token id where the Trainer masked labels with -100
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
# pdb.set_trace()
# TOKENIZER("data/samples/5_Laronix1.wav")
# tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
# Preprocess train and dev splits; drop the raw columns so only model inputs remain
encoded_train = train.map(prepare_dataset, remove_columns=train.column_names, num_proc=4)
encoded_dev = dev.map(prepare_dataset, remove_columns=dev.column_names, num_proc=4)
from transformers import TrainingArguments, Trainer

model = AutoModelForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)
# pdb.set_trace()
training_args = TrainingArguments(
    output_dir="my_awesome_asr_mind_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=2000,
    gradient_checkpointing=True,
    fp16=True,
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,  # requires a prior `huggingface-cli login`
)
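
# The data collator was still a TODO in this script. Below is a minimal sketch of the
# padding collator from the Hugging Face CTC fine-tuning example: it pads input_values
# and labels separately and masks label padding with -100 so the CTC loss ignores it.
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class DataCollatorCTCWithPadding:
    processor: AutoProcessor
    padding: Union[bool, str] = True

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths, so pad them separately
        input_features = [
            {"input_values": feature["input_values"][0]} for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features, padding=self.padding, return_tensors="pt"
        )
        labels_batch = self.processor.pad(
            labels=label_features, padding=self.padding, return_tensors="pt"
        )
        # Replace label padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        batch["labels"] = labels
        return batch


data_collator = DataCollatorCTCWithPadding(processor=processor)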
# pdb.set_trace()
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train,
    eval_dataset=encoded_dev,
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
# x = tokenizer(test["transcription"][0])