# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from subprocess import CalledProcessError, run
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
import numpy as np
import soundfile
import torch
import torch.nn.functional as F

Pathlike = Union[str, Path]

# Audio/STFT hyperparameters (Whisper-style frontend).
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg fails to decode the file.
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Raw little-endian s16 PCM -> float32 normalized to [-1, 1).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def load_audio_wav_format(wav_path):
    """
    Read a 16 kHz .wav file directly via soundfile (no ffmpeg subprocess).

    Returns a tuple of (waveform, sample_rate).

    NOTE(review): unlike load_audio, this does not down-mix; a stereo file
    would be returned with its channel dimension intact — confirm callers
    only pass mono files.
    """
    # make sure audio in .wav format
    assert wav_path.endswith(
        '.wav'), f"Only support .wav format, but got {wav_path}"
    waveform, sample_rate = soundfile.read(wav_path)
    assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}"
    return waveform, sample_rate


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad (with zeros) or trim `array` along `axis` to exactly `length`
    samples (default N_SAMPLES, as expected by the encoder).

    Works on both torch.Tensor and np.ndarray inputs.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis,
                                       index=torch.arange(
                                           length, device=array.device))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # F.pad expects a flat list of pads in reverse axis order.
            array = F.pad(
                array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: Optional[str] = None) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.

    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if mel_filters_dir is None:
        mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
                                        "mel_filters.npz")
    else:
        mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
    with np.load(mel_filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: Optional[str] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the
        audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    return_duration: bool
        If True, also return the input duration in seconds (before
        padding/trimming)

    mel_filters_dir: Optional[str]
        Directory containing mel_filters.npz; defaults to the bundled assets

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the Mel spectrogram; if return_duration is
        True, a (spectrogram, duration) tuple instead.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            if audio.endswith('.wav'):
                audio, _ = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio,
                          np.ndarray), f"Unsupported audio type: {type(audio)}"
        duration = audio.shape[-1] / SAMPLE_RATE
        audio = pad_or_trim(audio, N_SAMPLES)
        audio = audio.astype(np.float32)
        audio = torch.from_numpy(audio)
    else:
        # Bug fix: `duration` was previously undefined on this path, so
        # passing a Tensor with return_duration=True raised NameError.
        # Tensor inputs are assumed to be already padded/trimmed by the
        # caller (behavior unchanged from the original).
        duration = audio.shape[-1] / SAMPLE_RATE

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    # Drop the last STFT frame so n_frames lines up with the chunk layout.
    magnitudes = stft[..., :-1].abs()**2

    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes

    # Log compression and normalization, clamped to an 8-dB dynamic range.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    if return_duration:
        return log_spec, duration
    else:
        return log_spec


def store_transcripts(filename: Pathlike,
                      texts: Iterable[Tuple[str, str, str]]) -> None:
    """Save predicted results and reference transcripts to a file.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cur_id, the second is
        the reference transcript and the third element is the predicted result.

    Returns:
      Return None.
    """
    with open(filename, "w") as f:
        for cut_id, ref, hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.
    https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).
      results:
        An iterable of tuples. The first element is the cur_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.

    Returns:
      The total error rate (WER) in percent, as a float rounded to two
      decimal places.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    # NOTE(review): an empty `results` raises ZeroDivisionError here —
    # confirm whether callers guarantee non-empty input.
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
                     f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
                     f"{del_errs} del, {sub_errs} sub ]")

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            # Merge runs of adjacent errors into single (ref->hyp) spans
            # so the per-utterance diff is easier to read.
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [[
                list(filter(lambda a: a != ERR, x)),
                list(filter(lambda a: a != ERR, y)),
            ] for x, y in ali]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [[
                ERR if x == [] else " ".join(x),
                ERR if y == [] else " ".join(y),
            ] for x, y in ali]

        print(
            f"{cut_id}:\t" +
            " ".join((ref_word if ref_word == hyp_word else
                      f"({ref_word}->{hyp_word})"
                      for ref_word, hyp_word in ali)),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
                                    reverse=True):
        print(f"{count}   {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp",
          file=f)
    for _, word, counts in sorted([(sum(v[1:]), k, v)
                                   for k, v in words.items()],
                                  reverse=True):
        # Renamed from `ins`/`dels` to avoid shadowing the defaultdicts above.
        (corr, ref_sub, hyp_sub, ins_cnt, del_cnt) = counts
        tot_errs = ref_sub + hyp_sub + ins_cnt + del_cnt
        ref_count = corr + ref_sub + del_cnt
        hyp_count = corr + hyp_sub + ins_cnt

        print(f"{word}   {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)