|
|
|
|
|
import logging |
|
import sys |
|
from pathlib import Path |
|
import os |
|
|
|
import librosa |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from hyperpyyaml import load_hyperpyyaml |
|
|
|
import speechbrain as sb |
|
from speechbrain.utils.distributed import if_main_process, run_on_main |
|
|
|
from jiwer import wer, cer |
|
|
|
# Module-level logger, registered under this module's import name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
class ASR(sb.Brain):
    """Encoder-decoder ASR Brain with an optional auxiliary CTC loss.

    The encoder module is ``self.modules.encoder_w2v2``. Decoder outputs go
    through ``self.modules.seq_lin`` to give sequence log-probabilities.
    While CTC is active (training only, see ``is_ctc_active``),
    ``self.modules.ctc_lin`` also produces CTC log-probabilities from the
    encoder outputs.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Arguments
        ---------
        batch : PaddedBatch
            Must provide ``sig`` and ``tokens_bos`` as (data, lengths) pairs.
        stage : sb.Stage
            Currently executing stage.

        Returns
        -------
        dict
            Always ``"seq_logprobs"``. Adds ``"ctc_logprobs"`` while CTC is
            active, or ``"tokens"`` (decoded hypotheses) on VALID/TEST.

        Side effect: stores the (possibly augmented) relative signal lengths
        in ``self.sig_lens``; ``compute_objectives`` reuses them for the
        CTC loss.
        """
        batch = batch.to(self.device)
        sig, self.sig_lens = batch.sig
        tokens_bos, _ = batch.tokens_bos
        sig, self.sig_lens = sig.to(self.device), self.sig_lens.to(self.device)

        # Waveform augmentation is only applied while training.
        # NOTE(review): labels are not replicated here, so this assumes the
        # augmenter keeps the batch size unchanged -- confirm in the YAML.
        if stage == sb.Stage.TRAIN:
            sig, self.sig_lens = self.hparams.wav_augment(sig, self.sig_lens)

        encoded_outputs = self.modules.encoder_w2v2(sig.detach())
        embedded_tokens = self.modules.embedding(tokens_bos)
        decoder_outputs, _ = self.modules.decoder(
            embedded_tokens, encoded_outputs, self.sig_lens
        )

        logits = self.modules.seq_lin(decoder_outputs)
        predictions = {"seq_logprobs": self.hparams.log_softmax(logits)}

        if self.is_ctc_active(stage):
            ctc_logits = self.modules.ctc_lin(encoded_outputs)
            predictions["ctc_logprobs"] = self.hparams.log_softmax(ctc_logits)
        elif stage == sb.Stage.VALID:
            # Cheaper greedy decoding for validation.
            predictions["tokens"], _, _, _ = self.hparams.greedy_search(
                encoded_outputs, self.sig_lens
            )
        elif stage == sb.Stage.TEST:
            predictions["tokens"], _, _, _ = self.hparams.test_search(
                encoded_outputs, self.sig_lens
            )

        return predictions

    def is_ctc_active(self, stage):
        """Check if CTC is currently active.

        Arguments
        ---------
        stage : sb.Stage
            Currently executing stage.

        Returns
        -------
        bool
            True only during TRAIN while the current epoch is
            ``<= hparams.number_of_ctc_epochs``.
        """
        if stage != sb.Stage.TRAIN:
            return False
        current_epoch = self.hparams.epoch_counter.current
        return current_epoch <= self.hparams.number_of_ctc_epochs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets.

        When CTC is active the loss is
        ``(1 - ctc_weight) * nll + ctc_weight * ctc``; otherwise it is the NLL
        alone. On non-TRAIN stages the decoded hypotheses are also appended to
        the WER and CER metric trackers.
        """
        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.nll_cost(
            log_probabilities=predictions["seq_logprobs"],
            targets=tokens_eos,
            length=tokens_eos_lens,
        )

        if self.is_ctc_active(stage):
            # self.sig_lens was set by compute_forward for this batch.
            loss_ctc = self.hparams.ctc_cost(
                predictions["ctc_logprobs"], tokens, self.sig_lens, tokens_lens
            )
            loss *= 1 - self.hparams.ctc_weight
            loss += self.hparams.ctc_weight * loss_ctc

        if stage != sb.Stage.TRAIN:
            predicted_words = [
                self.hparams.tokenizer.decode_ids(prediction).split(" ")
                for prediction in predictions["tokens"]
            ]
            target_words = [words.split(" ") for words in batch.transcript]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each stage; resets WER/CER trackers on VALID/TEST."""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch.

        TRAIN: remembers the stage stats for the next VALID log line.
        VALID: anneals the learning rate on WER, logs stats, and saves a
        checkpoint (keeping the one with minimum WER).
        TEST: logs stats and, on the main process, writes WER details to
        ``hparams.test_wer_file``.
        """
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]},
                min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            if if_main_process():
                with open(self.hparams.test_wer_file, "w") as w:
                    self.wer_metric.write_stats(w)

    def run_inference(
        self,
        dataset,
        min_key,
        loader_kwargs,
    ):
        """Decode ``dataset`` with the best checkpoint and print WER/CER.

        Every utterance of every batch is scored. Hypotheses are passed
        through ``filter_repetitions(words, 3)`` before scoring.

        Arguments
        ---------
        dataset : Dataset or DataLoader
            Evaluation data. Anything that is not a ``DataLoader`` is wrapped
            with ``self.make_dataloader`` for ``sb.Stage.TEST``.
        min_key : str
            Checkpoint meta key to minimise when recovering (e.g. ``"WER"``).
        loader_kwargs : dict or None
            Keyword arguments for ``make_dataloader``. The caller's dict is
            not modified.
        """
        if not isinstance(dataset, DataLoader):
            # Work on a copy so the caller's options (e.g. hparams) stay intact;
            # ckpt_prefix=None keeps this loader out of the checkpointer.
            loader_kwargs = dict(loader_kwargs or {}, ckpt_prefix=None)
            dataset = self.make_dataloader(dataset, sb.Stage.TEST, **loader_kwargs)

        self.checkpointer.recover_if_possible(min_key=min_key)
        self.modules.eval()

        true_labels = []
        pred_labels = []
        with torch.no_grad():
            for batch in dataset:
                predictions = self.compute_forward(batch, stage=sb.Stage.TEST)
                for prediction in predictions["tokens"]:
                    words = self.hparams.tokenizer.decode_ids(prediction).split(" ")
                    pred_labels.append(" ".join(filter_repetitions(words, 3)))
                # Keep every reference in the batch (previously only index 0),
                # in the same order as the hypotheses above.
                true_labels.extend(batch.transcript)

        if not true_labels:
            # jiwer raises on empty inputs; report instead of crashing.
            logger.warning("run_inference: no utterances were decoded; nothing to score.")
            return

        print('WER: ', wer(true_labels, pred_labels) * 100)
        print('CER: ', cer(true_labels, pred_labels) * 100)
|
|
|
|
|
def filter_repetitions(seq, max_repetition_length):
    """Drop excessive consecutive repetitions of token patterns from ``seq``.

    Pattern lengths ``n`` are tried from ``len(seq) // 2`` down to 1. For each
    ``n``, the sequence is re-scanned and repeats of a length-``n`` pattern
    beyond ``max(max_repetition_length // n, 1)`` copies are removed.
    ``seq`` is rebuilt after each ``n``, so later (shorter) pattern sizes see
    the already-filtered sequence.

    NOTE(review): retention at pattern boundaries follows the buffer/flush
    logic below exactly and is subtle; verify with examples before relying on
    a precise count.

    Arguments
    ---------
    seq : iterable
        Tokens (e.g. words); compared with ``!=``. ``None`` must not occur in
        it, because ``None`` is used internally as a deletion marker.
    max_repetition_length : int
        Budget of repeated tokens; the allowed number of pattern copies is
        this divided by the pattern length (at least 1).

    Returns
    -------
    list
        The filtered tokens. Sequences whose length is
        ``<= max_repetition_length`` are returned unchanged (as a list).

    Example
    -------
    >>> filter_repetitions(["a"] * 5 + ["b"], 3)
    ['a', 'a', 'a', 'b']
    """
    seq = list(seq)
    output = []
    max_n = len(seq) // 2
    for n in range(max_n, 0, -1):
        # How many copies of a length-n pattern may be kept.
        max_repetitions = max(max_repetition_length // n, 1)

        # Skip pattern sizes that cannot repeat; len(seq) is re-checked every
        # iteration because seq is rebuilt (and may shrink) after each n.
        if (len(seq) <= n*2) or (len(seq) <= max_repetition_length):
            continue
        iterator = enumerate(seq)

        # One buffer per position modulo n, seeded with the first n tokens.
        buffers = [[next(iterator)[1]] for _ in range(n)]
        for seq_index, token in iterator:
            current_buffer = seq_index % n
            if token != buffers[current_buffer][-1]:

                # The repetition run is broken: flush buffered tokens in
                # original order, starting at the oldest buffered position.
                buf_len = sum(map(len, buffers))
                flush_start = (current_buffer-buf_len) % n

                # Flush whole pattern groups only. The last n-1 positions stay
                # buffered (to_flush = None) because they may start a new run.
                for flush_index in range(buf_len - buf_len%n):
                    if (buf_len - flush_index) > n-1:
                        to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                    else:
                        to_flush = None

                    # Groups before max_repetitions are kept. A still-buffered
                    # position that belongs to an excess group emits a None
                    # marker, which later deletes one following output token.
                    if (flush_index // n < max_repetitions) and to_flush is not None:
                        output.append(to_flush)
                    elif (flush_index // n >= max_repetitions) and to_flush is None:
                        output.append(to_flush)
            buffers[current_buffer].append(token)

        # End of sequence: flush everything still buffered, again keeping only
        # the first max_repetitions groups.
        current_buffer += 1
        buf_len = sum(map(len, buffers))
        flush_start = (current_buffer-buf_len) % n
        for flush_index in range(buf_len):
            to_flush = buffers[(flush_index + flush_start) % n].pop(0)

            if flush_index // n < max_repetitions:
                output.append(to_flush)
        # Apply deletion markers: each None removes the next real token.
        seq = []
        to_delete = 0
        for token in output:
            if token is None:
                to_delete += 1
            elif to_delete > 0:
                to_delete -= 1
            else:
                seq.append(token)
        output = []
    return seq
|
|
|
def dataio_prepare(hparams):
    """Create the train/dev/test datasets and attach their data pipelines.

    Each split is read from ``<data_folder>/{train,dev,test}.json``, with the
    ``data_root`` placeholder replaced by ``hparams["data_folder"]``. Train
    and dev are sorted by ``duration``. Because the train set is sorted,
    ``hparams["train_dataloader_opts"]["shuffle"]`` is forced to ``False``
    (the hparams dict is mutated in place).

    Dynamic items added to all three datasets:
      * ``sig``: audio loaded with ``librosa.load(..., sr=16000)``.
      * ``transcript``, ``tokens_list``, ``tokens_bos``, ``tokens_eos``,
        ``tokens``: text and tokenizer ids with BOS/EOS variants.

    Returns
    -------
    tuple
        ``(train_data, valid_data, test_data)``.
    """
    data_folder = hparams["data_folder"]

    def load_split(manifest_name):
        # Read one JSON manifest from data_folder.
        return sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=os.path.join(data_folder, manifest_name),
            replacements={"data_root": data_folder},
        )

    train_data = load_split("train.json").filtered_sorted(sort_key="duration")
    # Sorted training data must not be reshuffled by the DataLoader.
    hparams["train_dataloader_opts"]["shuffle"] = False
    valid_data = load_split("dev.json").filtered_sorted(sort_key="duration")
    test_data = load_split("test.json")

    all_splits = [train_data, valid_data, test_data]
    tokenizer = hparams["tokenizer"]

    @sb.utils.data_pipeline.takes("data_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(data_path):
        waveform, _sample_rate = librosa.load(data_path, sr=16000)
        return waveform

    sb.dataio.dataset.add_dynamic_item(all_splits, audio_pipeline)

    @sb.utils.data_pipeline.takes("transcript")
    @sb.utils.data_pipeline.provides(
        "transcript", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
    )
    def text_pipeline(transcript):
        # Yields must follow the order declared in provides(...).
        yield transcript
        token_ids = tokenizer.encode_as_ids(transcript)
        yield token_ids
        yield torch.LongTensor([hparams["bos_index"]] + (token_ids))
        yield torch.LongTensor(token_ids + [hparams["eos_index"]])
        yield torch.LongTensor(token_ids)

    sb.dataio.dataset.add_dynamic_item(all_splits, text_pipeline)

    output_keys = [
        "id",
        "sig",
        "transcript",
        "tokens_list",
        "tokens_bos",
        "tokens_eos",
        "tokens",
    ]
    sb.dataio.dataset.set_output_keys(all_splits, output_keys)

    return (train_data, valid_data, test_data)
|
|
|
|
|
if __name__ == "__main__":
    # Command line: hparams YAML file, followed by run options and overrides.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Distributed (DDP) setup, driven by run_opts.
    sb.utils.distributed.ddp_init_group(run_opts)

    # Explicit encoding so YAML parsing does not depend on the system locale.
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Note: dataio_prepare also sets train_dataloader_opts["shuffle"] = False.
    (train_data, valid_data, test_data) = dataio_prepare(hparams)

    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )
    asr_brain.tokenizer = hparams["tokenizer"]

    train_dataloader_opts = hparams["train_dataloader_opts"]
    valid_dataloader_opts = hparams["valid_dataloader_opts"]

    if not hparams["skip_training"]:
        print("Training...")
        asr_brain.fit(
            asr_brain.hparams.epoch_counter,
            train_data,
            valid_data,
            train_loader_kwargs=train_dataloader_opts,
            valid_loader_kwargs=valid_dataloader_opts,
        )
    else:
        # Evaluation only: decode the test set with the lowest-WER checkpoint.
        print("Evaluating")
        asr_brain.run_inference(test_data, "WER", hparams["test_dataloader_opts"])
|
|
|
|