{ "cells": [ { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n", "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n", "speechbrain.core - Beginning experiment!\n", "speechbrain.core - Experiment folder: TunisianASR/semi_wavlm_large_tunisian_ctc/1234\n", "speechbrain.pretrained.fetching - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml.\n", "speechbrain.pretrained.fetching - Fetch custom.py: Linking to local file in /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py.\n", "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen.\n", "speechbrain.pretrained.fetching - Fetch wav2vec2.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt.\n", "speechbrain.pretrained.fetching - Fetch asr.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt.\n", "speechbrain.pretrained.fetching - Fetch tokenizer.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt.\n", "speechbrain.utils.parameter_transfer - Loading pretrained files for: wav2vec2, asr, tokenizer\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at wav2vec2-large-lv60/ were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight']\n", "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n", "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n", "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n", "speechbrain.core - 314.4M trainable parameters in ASRCV\n", "speechbrain.utils.checkpoints - Loading a checkpoint from EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00\n", "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n", "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n", "speechbrain.core - 314.4M trainable parameters in ASR\n", "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00\n" ] } ], "source": [ "import os\n", "import sys\n", "import torch\n", "import logging\n", "import speechbrain as sb\n", "from speechbrain.utils.distributed import run_on_main\n", "from hyperpyyaml import load_hyperpyyaml\n", "from pathlib import Path\n", "import torchaudio.transforms as T\n", "from cv_train import ASRCV\n", "import torchaudio\n", "import numpy as np\n", "import kenlm\n", "from pyctcdecode import build_ctcdecoder\n", "import re\n", "from torch.nn.utils.rnn import pad_sequence\n", "import torch.optim as optim\n", "import torch.nn as nn\n", "\n", "\n", "# Commented out IPython magic to ensure Python compatibility.\n", "hparams_file, run_opts, overrides = sb.parse_arguments([\"TunisianASR/train_semi.yaml\"])\n", "\n", "# If distributed_launch=True then\n", "# create ddp_group with the right communication protocol\n", "sb.utils.distributed.ddp_init_group(run_opts)\n", "\n", "with open(hparams_file) as fin:\n", " hparams = load_hyperpyyaml(fin, overrides)\n", "\n", "# Create experiment directory\n", "sb.create_experiment_directory(\n", " experiment_directory=hparams[\"output_folder\"],\n", " hyperparams_to_save=hparams_file,\n", " overrides=overrides,\n", ")\n", "# Dataset prep (parsing Librispeech)\n", "\n", "def dataio_prepare(hparams):\n", " \"\"\"This function prepares the datasets to be used in the brain class.\n", " It also defines the data processing pipeline through user-defined functions.\"\"\"\n", "\n", " # 1. Define datasets\n", " data_folder = hparams[\"data_folder\"]\n", "\n", " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", " csv_path=hparams[\"train_csv\"], replacements={\"data_root\": data_folder},\n", " )\n", "\n", " if hparams[\"sorting\"] == \"ascending\":\n", " # we sort training data to speed up training and get better results.\n", " train_data = train_data.filtered_sorted(\n", " sort_key=\"duration\",\n", " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n", " )\n", " # when sorting do not shuffle in dataloader ! 
otherwise is pointless\n", " hparams[\"dataloader_options\"][\"shuffle\"] = False\n", "\n", " elif hparams[\"sorting\"] == \"descending\":\n", " train_data = train_data.filtered_sorted(\n", " sort_key=\"duration\",\n", " reverse=True,\n", " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n", " )\n", " # when sorting do not shuffle in dataloader ! otherwise is pointless\n", " hparams[\"dataloader_options\"][\"shuffle\"] = False\n", "\n", " elif hparams[\"sorting\"] == \"random\":\n", " pass\n", "\n", " else:\n", " raise NotImplementedError(\n", " \"sorting must be random, ascending or descending\"\n", " )\n", "\n", " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", " csv_path=hparams[\"valid_csv\"], replacements={\"data_root\": data_folder},\n", " )\n", " # We also sort the validation data so it is faster to validate\n", " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n", " test_datasets = {}\n", " for csv_file in hparams[\"test_csv\"]:\n", " name = Path(csv_file).stem\n", " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(\n", " csv_path=csv_file, replacements={\"data_root\": data_folder}\n", " )\n", " test_datasets[name] = test_datasets[name].filtered_sorted(\n", " sort_key=\"duration\"\n", " )\n", "\n", " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n", "\n", "\n", " # 2. Define audio pipeline:\n", " @sb.utils.data_pipeline.takes(\"wav\")\n", " @sb.utils.data_pipeline.provides(\"sig\")\n", " def audio_pipeline(wav):\n", " info = torchaudio.info(wav)\n", " sig = sb.dataio.dataio.read_audio(wav)\n", " if len(sig.shape)>1 :\n", " sig = torch.mean(sig, dim=1)\n", " resampled = torchaudio.transforms.Resample(\n", " info.sample_rate, hparams[\"sample_rate\"],\n", " )(sig)\n", " return resampled\n", "\n", " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n", " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n", "\n", " # 3. Define text pipeline:\n", " @sb.utils.data_pipeline.takes(\"wrd\")\n", " @sb.utils.data_pipeline.provides(\n", " \"wrd\", \"char_list\", \"tokens_list\", \"tokens\"\n", " )\n", " def text_pipeline(wrd):\n", " yield wrd\n", " char_list = list(wrd)\n", " yield char_list\n", " tokens_list = label_encoder.encode_sequence(char_list)\n", " yield tokens_list\n", " tokens = torch.LongTensor(tokens_list)\n", " yield tokens\n", "\n", " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n", " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n", " special_labels = {\n", " \"blank_label\": hparams[\"blank_index\"],\n", " \"unk_label\": hparams[\"unk_index\"]\n", " }\n", " label_encoder.load_or_create(\n", " path=lab_enc_file,\n", " from_didatasets=[train_data],\n", " output_key=\"char_list\",\n", " special_labels=special_labels,\n", " sequence_input=True,\n", " )\n", "\n", " # 4. 
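Set output:\n",
"    # set_output_keys decides which fields each batch exposes to the Brain class: \"sig\" (resampled waveform),\n",
"    # \"tokens\" (encoded characters for the CTC loss) and \"wrd\"/\"char_list\" for WER/CER scoring.\n",
"    # 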
Set output:\n", " sb.dataio.dataset.set_output_keys(\n", " datasets, [\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],\n", " )\n", " return train_data, valid_data,test_datasets, label_encoder\n", "\n", "class ASR(sb.core.Brain):\n", " def compute_forward(self, batch, stage):\n", " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n", "\n", " batch = batch.to(self.device)\n", " wavs, wav_lens = batch.sig\n", " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n", "\n", " if stage == sb.Stage.TRAIN:\n", " if hasattr(self.hparams, \"augmentation\"):\n", " wavs = self.hparams.augmentation(wavs, wav_lens)\n", "\n", " # Forward pass\n", " feats = self.modules.wav2vec2(wavs, wav_lens)\n", " x = self.modules.enc(feats)\n", " logits = self.modules.ctc_lin(x)\n", " p_ctc = self.hparams.log_softmax(logits)\n", "\n", " return p_ctc, wav_lens\n", "\n", " def custom_encode(self,wavs,wav_lens) :\n", " wavs = wavs.to(\"cpu\")\n", " if(wav_lens is not None): wav_lens.to(self.device)\n", "\n", " feats = self.modules.wav2vec2(wavs, wav_lens)\n", " x = self.modules.enc(feats)\n", " logits = self.modules.ctc_lin(x)\n", " p_ctc = self.hparams.log_softmax(logits)\n", "\n", " return feats,p_ctc\n", "\n", "\n", "\n", " def compute_objectives(self, predictions, batch, stage):\n", " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n", "\n", " p_ctc, wav_lens = predictions\n", "\n", " ids = batch.id\n", " tokens, tokens_lens = batch.tokens\n", "\n", " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n", "\n", " if stage != sb.Stage.TRAIN:\n", " predicted_tokens = sb.decoders.ctc_greedy_decode(\n", " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n", " )\n", " # Decode token terms to words\n", " if self.hparams.use_language_modelling:\n", " predicted_words = []\n", " for logs in p_ctc:\n", " text = decoder.decode(logs.detach().cpu().numpy())\n", " predicted_words.append(text.split(\" \"))\n", " else:\n", " predicted_words = [\n", " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n", " for utt_seq in predicted_tokens\n", " ]\n", " # Convert indices to words\n", " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n", "\n", " self.wer_metric.append(ids, predicted_words, target_words)\n", " self.cer_metric.append(ids, predicted_words, target_words)\n", "\n", " return loss\n", "\n", " def fit_batch(self, batch):\n", " \"\"\"Train the parameters given a single batch in input\"\"\"\n", " should_step = self.step % self.grad_accumulation_factor == 0\n", " # Managing automatic mixed precision\n", " # TOFIX: CTC fine-tuning currently is unstable\n", " # This is certainly due to CTC being done in fp16 instead of fp32\n", " if self.auto_mix_prec:\n", " with torch.cuda.amp.autocast():\n", " with self.no_sync():\n", " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n", " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n", " with self.no_sync(not should_step):\n", " self.scaler.scale(\n", " loss / self.grad_accumulation_factor\n", " ).backward()\n", " if should_step:\n", "\n", " if not self.hparams.wav2vec2.freeze:\n", " self.scaler.unscale_(self.wav2vec_optimizer)\n", " self.scaler.unscale_(self.model_optimizer)\n", " if self.check_gradients(loss):\n", " if not self.hparams.wav2vec2.freeze:\n", " if self.optimizer_step >= self.hparams.warmup_steps:\n", " self.scaler.step(self.wav2vec_optimizer)\n", " self.scaler.step(self.model_optimizer)\n", " self.scaler.update()\n", " self.zero_grad()\n", " 
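# count the completed update: optimizer_step drives the warmup_steps check above, so the\n",
"                # (unfrozen) wav2vec optimizer only starts stepping after the warm-up period.\n",
"                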
self.optimizer_step += 1\n", " else:\n", " # This is mandatory because HF models have a weird behavior with DDP\n", " # on the forward pass\n", " with self.no_sync():\n", " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n", "\n", " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n", "\n", " with self.no_sync(not should_step):\n", " (loss / self.grad_accumulation_factor).backward()\n", " if should_step:\n", " if self.check_gradients(loss):\n", " if not self.hparams.wav2vec2.freeze:\n", " if self.optimizer_step >= self.hparams.warmup_steps:\n", " self.wav2vec_optimizer.step()\n", " self.model_optimizer.step()\n", " self.zero_grad()\n", " self.optimizer_step += 1\n", "\n", " self.on_fit_batch_end(batch, outputs, loss, should_step)\n", " return loss.detach().cpu()\n", "\n", " def evaluate_batch(self, batch, stage):\n", " \"\"\"Computations needed for validation/test batches\"\"\"\n", " predictions = self.compute_forward(batch, stage=stage)\n", " with torch.no_grad():\n", " loss = self.compute_objectives(predictions, batch, stage=stage)\n", " return loss.detach()\n", "\n", " def on_stage_start(self, stage, epoch):\n", " \"\"\"Gets called at the beginning of each epoch\"\"\"\n", " if stage != sb.Stage.TRAIN:\n", " self.cer_metric = self.hparams.cer_computer()\n", " self.wer_metric = self.hparams.error_rate_computer()\n", "\n", " def on_stage_end(self, stage, stage_loss, epoch):\n", " \"\"\"Gets called at the end of an epoch.\"\"\"\n", " # Compute/store important stats\n", " stage_stats = {\"loss\": stage_loss}\n", " if stage == sb.Stage.TRAIN:\n", " self.train_stats = stage_stats\n", " else:\n", " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n", " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n", "\n", " # Perform end-of-iteration things, like annealing, logging, etc.\n", " if stage == sb.Stage.VALID:\n", " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n", " stage_stats[\"loss\"]\n", " )\n", " old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(\n", " stage_stats[\"loss\"]\n", " )\n", " sb.nnet.schedulers.update_learning_rate(\n", " self.model_optimizer, new_lr_model\n", " )\n", " if not self.hparams.wav2vec2.freeze:\n", " sb.nnet.schedulers.update_learning_rate(\n", " self.wav2vec_optimizer, new_lr_wav2vec\n", " )\n", " self.hparams.train_logger.log_stats(\n", " stats_meta={\n", " \"epoch\": epoch,\n", " \"lr_model\": old_lr_model,\n", " \"lr_wav2vec\": old_lr_wav2vec,\n", " },\n", " train_stats=self.train_stats,\n", " valid_stats=stage_stats,\n", " )\n", " self.checkpointer.save_and_keep_only(\n", " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n", " )\n", " elif stage == sb.Stage.TEST:\n", " self.hparams.train_logger.log_stats(\n", " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n", " test_stats=stage_stats,\n", " )\n", " with open(self.hparams.wer_file, \"w\") as w:\n", " self.wer_metric.write_stats(w)\n", "\n", " def init_optimizers(self):\n", " \"Initializes the wav2vec2 optimizer and model optimizer\"\n", "\n", " # If the wav2vec encoder is unfrozen, we create the optimizer\n", " if not self.hparams.wav2vec2.freeze:\n", " self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(\n", " self.modules.wav2vec2.parameters()\n", " )\n", " if self.checkpointer is not None:\n", " self.checkpointer.add_recoverable(\n", " \"wav2vec_opt\", self.wav2vec_optimizer\n", " )\n", "\n", " self.model_optimizer = self.hparams.model_opt_class(\n", " self.hparams.model.parameters()\n", " 
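# model_opt_class is the optimizer constructor defined in the YAML hyperparameters\n",
"        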
)\n", "\n", " if self.checkpointer is not None:\n", " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n", "\n", " def zero_grad(self, set_to_none=False):\n", " if not self.hparams.wav2vec2.freeze:\n", " self.wav2vec_optimizer.zero_grad(set_to_none)\n", " self.model_optimizer.zero_grad(set_to_none)\n", "\n", "\n", "from speechbrain.pretrained import EncoderASR,EncoderDecoderASR\n", "french_asr_model = EncoderASR.from_hparams(source=\"asr-wav2vec2-commonvoice-fr\", savedir=\"pretrained_models/asr-wav2vec2-commonvoice-fr\").cuda()\n", "french_asr_model.to(\"cpu\")\n", "cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments([\"EnglishCV/train_en_with_wav2vec.yaml\"])\n", "with open(cvhparams_file) as cvfin:\n", " cvhparams = load_hyperpyyaml(cvfin, cvoverrides)\n", "english_asr_model = ASRCV(\n", " modules=cvhparams[\"modules\"],\n", " hparams=cvhparams,\n", " run_opts=cvrun_opts,\n", " checkpointer=cvhparams[\"checkpointer\"],\n", " )\n", "english_asr_model.modules.to(\"cpu\")\n", "english_asr_model.checkpointer.recover_if_possible()\n", "asr_brain = ASR(\n", " modules=hparams[\"modules\"],\n", " hparams=hparams,\n", " run_opts=run_opts,\n", " checkpointer=hparams[\"checkpointer\"],\n", ")\n", "asr_brain.modules.to(\"cpu\")\n", "asr_brain.checkpointer.recover_if_possible()\n", "asr_brain.modules.eval()\n", "english_asr_model.modules.eval()\n", "french_asr_model.mods.eval()\n", "asr_brain.modules.to(\"cpu\")\n", "\n", "# Commented out IPython magic to ensure Python compatibility.\n", "# %ls\n", "\n", "#UTILS FUNCTIOJNS\n", "def get_size_dimensions(arr):\n", " size_dimensions = []\n", " while isinstance(arr, list):\n", " size_dimensions.append(len(arr))\n", " arr = arr[0]\n", " return size_dimensions\n", "\n", "def scale_array(batch,n):\n", " scaled_batch = []\n", "\n", " for array in batch:\n", " if(n < len(array)): raise ValueError(\"Cannot scale Array down\")\n", "\n", " repeat = round(n/len(array))+1\n", " scaled_length_array= []\n", "\n", " for i in array:\n", " for j in range(repeat) :\n", " if(len(scaled_length_array) == n): break\n", " scaled_length_array.append(i)\n", "\n", " scaled_batch.append(scaled_length_array)\n", "\n", " return torch.tensor(scaled_batch)\n", "\n", "\n", "def load_paths(wavs_path):\n", " waveforms = []\n", " for path in wavs_path :\n", " waveform, _ = torchaudio.load(path)\n", " waveforms.append(waveform.squeeze(0))\n", " # normalize array length to the bigger arrays by pading with 0's\n", " padded_arrays = pad_sequence(waveforms, batch_first=True)\n", " return torch.tensor(padded_arrays)\n", "\n", "\n", "\n", "device = 'cuda'\n", "verbose = 0\n", "#FLOW LEVEL FUNCTIONS\n", "def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):\n", "\n", "\n", " post1 = post1.to(device)\n", " post2 = post2.to(device)\n", " post3 = post3.to(device)\n", " embeddings1 = embeddings1.to(device)\n", " embeddings2 = embeddings2.to(device)\n", " embeddings3 = embeddings3.to(device)\n", "\n", " posteriograms_merged = torch.cat((post1,post2,post3),dim=2)\n", " embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)\n", "\n", " if(verbose !=0):\n", " print('MERGED POST ',posteriograms_merged.shape)\n", " print('MERGED emb ',embeddings_merged.shape)\n", "\n", " return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)\n", "\n", "def decode(model,wavs,wav_lens):\n", "\n", " with torch.no_grad():\n", " wav_lens = wav_lens.to(model.device)\n", " encoder_out = model.encode_batch(wavs, wav_lens)\n", " 
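# run the pretrained model's own decoding (typically CTC greedy/beam search set in its hyperparams) on the encoded batch\n",
"        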
predictions = model.decoding_function(encoder_out, wav_lens)\n", " return predictions\n", "\n", "def middle_layer(batch, lens):\n", "\n", " tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)\n", "\n", " fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)\n", " fr_posteriogram =french_asr_model.encode_batch(batch,lens)\n", " en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)\n", " x = english_asr_model.modules.enc(en_embeddings)\n", " en_posteriogram = english_asr_model.modules.ctc_lin(x)\n", " #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)\n", " if(verbose !=0):\n", " print('[EMBEDDINGS] FR:',fr_embeddings.shape, \"EN:\",en_embeddings.shape, \"TN:\", tn_embeddings.shape)\n", " print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, \"EN:\",en_posteriogram.shape,\"TN:\",tn_posteriogram.shape)\n", "\n", "\n", " bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)\n", " return bilangual_sample\n", "\n", "class Mixer(sb.core.Brain):\n", "\n", " def compute_forward(self, batch, stage):\n", " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n", " wavs, wav_lens = batch.sig\n", " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n", "\n", " if stage == sb.Stage.TRAIN:\n", " if hasattr(self.hparams, \"augmentation\"):\n", " wavs = self.hparams.augmentation(wavs, wav_lens)\n", "\n", " multi_langual_feats = middle_layer(wavs, wav_lens)\n", " multi_langual_feats= multi_langual_feats.to(device)\n", " feats, _ = self.modules.enc(multi_langual_feats)\n", " logits = self.modules.ctc_lin(feats)\n", " p_ctc = self.hparams.log_softmax(logits)\n", " \n", " if stage!= sb.Stage.TRAIN:\n", " p_tokens = sb.decoders.ctc_greedy_decode(\n", " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n", " )\n", " else : \n", " p_tokens = None\n", " return p_ctc, wav_lens, p_tokens\n", " \n", " \n", " def treat_wav(self,sig):\n", " multi_langual_feats = middle_layer(sig.to(\"cpu\"), torch.tensor([1]).to(\"cpu\"))\n", " multi_langual_feats= multi_langual_feats.to(device)\n", " feats, _ = self.modules.enc(multi_langual_feats)\n", " logits = self.modules.ctc_lin(feats)\n", " p_ctc = self.hparams.log_softmax(logits)\n", " predicted_words =[]\n", " for logs in p_ctc:\n", " text = decoder.decode(logs.detach().cpu().numpy())\n", " predicted_words.append(text.split(\" \"))\n", " return \" \".join(predicted_words[0])\n", " \n", "\n", " def compute_objectives(self, predictions, batch, stage):\n", " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n", "\n", " p_ctc, wav_lens , predicted_tokens= predictions\n", "\n", " ids = batch.id\n", " tokens, tokens_lens = batch.tokens\n", "\n", " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n", "\n", "\n", " if stage == sb.Stage.VALID:\n", " predicted_words = [\n", " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n", " for utt_seq in predicted_tokens\n", " ]\n", " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n", " self.wer_metric.append(ids, predicted_words, target_words)\n", " self.cer_metric.append(ids, predicted_words, target_words)\n", " if stage ==sb.Stage.TEST : \n", " if self.hparams.language_modelling:\n", " predicted_words = []\n", " for logs in p_ctc:\n", " text = decoder.decode(logs.detach().cpu().numpy())\n", " predicted_words.append(text.split(\" \"))\n", " else : \n", " predicted_words = [\n", " 
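# no LM in this branch: map the greedy CTC token ids back to characters, then split on spaces to get words\n",
"                    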
\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n", " for utt_seq in predicted_tokens\n", " ]\n", "\n", " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n", " self.wer_metric.append(ids, predicted_words, target_words)\n", " self.cer_metric.append(ids, predicted_words, target_words)\n", "\n", " return loss\n", "\n", " def fit_batch(self, batch):\n", " \"\"\"Train the parameters given a single batch in input\"\"\"\n", " should_step = self.step % self.grad_accumulation_factor == 0\n", " # Managing automatic mixed precision\n", " # TOFIX: CTC fine-tuning currently is unstable\n", " # This is certainly due to CTC being done in fp16 instead of fp32\n", " if self.auto_mix_prec:\n", " with torch.cuda.amp.autocast():\n", " with self.no_sync():\n", " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n", " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n", " with self.no_sync(not should_step):\n", " self.scaler.scale(\n", " loss / self.grad_accumulation_factor\n", " ).backward()\n", " if should_step:\n", "\n", "\n", " self.scaler.unscale_(self.model_optimizer)\n", " if self.check_gradients(loss):\n", " self.scaler.step(self.model_optimizer)\n", " self.scaler.update()\n", " self.zero_grad()\n", " self.optimizer_step += 1\n", " else:\n", " # This is mandatory because HF models have a weird behavior with DDP\n", " # on the forward pass\n", " with self.no_sync():\n", " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n", "\n", " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n", "\n", " with self.no_sync(not should_step):\n", " (loss / self.grad_accumulation_factor).backward()\n", " if should_step:\n", " if self.check_gradients(loss):\n", " self.model_optimizer.step()\n", " self.zero_grad()\n", " self.optimizer_step += 1\n", "\n", " self.on_fit_batch_end(batch, outputs, loss, should_step)\n", " return loss.detach().cpu()\n", "\n", " def evaluate_batch(self, batch, stage):\n", " \"\"\"Computations needed for validation/test batches\"\"\"\n", " predictions = self.compute_forward(batch, stage=stage)\n", " with torch.no_grad():\n", " loss = self.compute_objectives(predictions, batch, stage=stage)\n", " return loss.detach()\n", "\n", " def on_stage_start(self, stage, epoch):\n", " \"\"\"Gets called at the beginning of each epoch\"\"\"\n", " if stage != sb.Stage.TRAIN:\n", " self.cer_metric = self.hparams.cer_computer()\n", " self.wer_metric = self.hparams.error_rate_computer()\n", "\n", " def on_stage_end(self, stage, stage_loss, epoch):\n", " \"\"\"Gets called at the end of an epoch.\"\"\"\n", " # Compute/store important stats\n", " stage_stats = {\"loss\": stage_loss}\n", " if stage == sb.Stage.TRAIN:\n", " self.train_stats = stage_stats\n", " else:\n", " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n", " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n", "\n", " # Perform end-of-iteration things, like annealing, logging, etc.\n", " if stage == sb.Stage.VALID:\n", " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n", " stage_stats[\"loss\"]\n", " )\n", " sb.nnet.schedulers.update_learning_rate(\n", " self.model_optimizer, new_lr_model\n", " )\n", " self.hparams.train_logger.log_stats(\n", " stats_meta={\n", " \"epoch\": epoch,\n", " \"lr_model\": old_lr_model,\n", " },\n", " train_stats=self.train_stats,\n", " valid_stats=stage_stats,\n", " )\n", " self.checkpointer.save_and_keep_only(\n", " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n", " )\n", " elif stage == sb.Stage.TEST:\n", " 
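# test stage: log the final loss/CER/WER and write per-utterance alignments to wer_file below\n",
"            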
self.hparams.train_logger.log_stats(\n", " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n", " test_stats=stage_stats,\n", " )\n", " with open(self.hparams.wer_file, \"w\") as w:\n", " self.wer_metric.write_stats(w)\n", "\n", " def init_optimizers(self):\n", "\n", " self.model_optimizer = self.hparams.model_opt_class(\n", " self.hparams.model.parameters()\n", " )\n", "\n", " if self.checkpointer is not None:\n", " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n", "\n", " def zero_grad(self, set_to_none=False):\n", "\n", " self.model_optimizer.zero_grad(set_to_none)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n", "speechbrain.core - Beginning experiment!\n", "speechbrain.core - Experiment folder: results/non_semi_final_stac\n", "speechbrain.dataio.encoder - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.\n", "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n", "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n", "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n", "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n", "speechbrain.core - 60.1M trainable parameters in Mixer\n", "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n", "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n", "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n", "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":119: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n", " inputs=[gr.Audio(source=\"microphone\", type='filepath', label = \"record\", optional = True),\n", ":120: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n", " gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n", " warnings.warn(warning.format(data.dtype))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0075, -0.0042, -0.0031]])\n", "tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0518e-05,\n", " -3.0518e-05, 0.0000e+00]])\n" ] } ], "source": [ "hparams_file, run_opts, overrides = sb.parse_arguments([\"cs.yaml\"])\n", "\n", "# If distributed_launch=True then\n", "# create ddp_group with the right communication protocol\n", "sb.utils.distributed.ddp_init_group(run_opts)\n", "\n", "with open(hparams_file) as fin:\n", " hparams = load_hyperpyyaml(fin, overrides)\n", "\n", "# Create experiment directory\n", "sb.create_experiment_directory(\n", " experiment_directory=hparams[\"output_folder\"],\n", " hyperparams_to_save=hparams_file,\n", " overrides=overrides,\n", ")\n", "def read_labels_file(labels_file):\n", " with open(labels_file, \"r\",encoding=\"utf-8\") as lf:\n", " lines = lf.read().splitlines()\n", " division = \"===\"\n", " numbers = {}\n", " for line in lines :\n", " if division in line :\n", " break\n", " string, number = line.split(\"=>\")\n", " number = int(number)\n", " string = string[1:-2]\n", " numbers[number] = string\n", " return [numbers[x] for x in range(len(numbers))]\n", "\n", "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n", "\n", "lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n", "special_labels = {\n", " \"blank_label\": hparams[\"blank_index\"],\n", " \"unk_label\": hparams[\"unk_index\"]\n", "}\n", "label_encoder.load_or_create(\n", " path=lab_enc_file,\n", " from_didatasets=[[]],\n", " output_key=\"char_list\",\n", " special_labels=special_labels,\n", " sequence_input=True,\n", ")\n", "\n", "\n", "labels = read_labels_file(os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\"))\n", "labels = [\"\"] + labels[1:-1] + [\"1\"] \n", "if hparams[\"language_modelling\"]:\n", " decoder = build_ctcdecoder(\n", " labels,\n", " kenlm_model_path=hparams[\"ngram_lm_path\"], # either .arpa or .bin file\n", " alpha=0.5, # tuned on a val set\n", " beta=1, # tuned on a val set\n", " )\n", "\n", "\n", "\n", "\n", "mixer = Mixer(\n", " modules=hparams[\"modules\"],\n", " hparams=hparams,\n", " run_opts=run_opts,\n", " checkpointer=hparams[\"checkpointer\"],\n", ")\n", "mixer.tokenizer = label_encoder\n", "\n", "\n", "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n", "\n", "\n", "# We dynamicaly add the tokenizer to our brain class.\n", "# NB: This tokenizer corresponds to the one used for the LM!!\n", "\n", "decoder = build_ctcdecoder(\n", " labels,\n", " kenlm_model_path= \"arpas/everything.arpa\", # either .arpa or .bin file\n", " alpha=0.5, # tuned on a val set\n", " beta=1, # tuned on a val set\n", ")\n", "\n", "run_opts[\"device\"]=\"cpu\"\n", "\n", "\n", "device = \"cpu\"\n", "mixer.device= \"cpu\"\n", "mixer.modules.to(\"cpu\")\n", "\n", "from enum import Enum, auto\n", "class Stage(Enum):\n", " TRAIN = auto()\n", " VALID = auto()\n", " TEST = auto()\n", "\n", "asr_brain.on_evaluate_start()\n", "asr_brain.modules.eval()\n", "\n", "\n", "import gradio as gr\n", "\n", "def treat_wav_file(file_mic,file_upload ,asr=mixer, device=\"cpu\") :\n", " if (file_mic is not None) 
and (file_upload is not None):\n",
"        warn_output = \"WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\\n\"\n",
"        wav = file_mic\n",
"    elif (file_mic is None) and (file_upload is None):\n",
"        return \"ERROR: You have to either use the microphone or upload an audio file\"\n",
"    elif file_mic is not None:\n",
"        warn_output = \"\"\n",
"        wav = file_mic\n",
"    else:\n",
"        warn_output = \"\"\n",
"        wav = file_upload\n",
"    sig, sr = torchaudio.load(wav)\n",
"    tensor_wav = sig.to(device)\n",
"    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)\n",
"    sentence = asr.treat_wav(resampled)\n",
"    # return the warning (if any) together with the transcription instead of silently dropping it\n",
"    return warn_output + sentence\n",
"\n",
"gr.Interface(\n",
"    fn=treat_wav_file,\n",
"    inputs=[gr.Audio(source=\"microphone\", type='filepath', label=\"record\", optional=True),\n",
"            gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)],\n",
"    outputs=\"text\",\n",
").launch(share=False, debug=True)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }