JustinLin610's picture
first commit
ee21b96
raw
history blame
12.2 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
from pathlib import Path
import zipfile
from functools import reduce
from multiprocessing import cpu_count
from typing import Any, Dict, List, Optional, Union
import io
import numpy as np
import pandas as pd
import sentencepiece as sp
from fairseq.data.audio.audio_utils import (
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data,
is_sf_audio_data
)
import torch
import soundfile as sf
from tqdm import tqdm
# Special symbols and the vocabulary ids they are pinned to. gen_vocab passes
# these ids to the SentencePiece trainer and checks the trained model against
# them.
BOS_TOKEN = "<s>"
BOS_TOKEN_ID = 0
PAD_TOKEN = "<pad>"
PAD_TOKEN_ID = 1
EOS_TOKEN = "</s>"
EOS_TOKEN_ID = 2
UNK_TOKEN = "<unk>"
UNK_TOKEN_ID = 3
def gen_vocab(
    input_path: Path, output_path_prefix: Path, model_type="bpe",
    vocab_size=1000, special_symbols: Optional[List[str]] = None
):
    """Train a SentencePiece model and export its pieces as a dictionary file.

    The model is trained from ``input_path`` and loaded back from
    ``<output_path_prefix>.model``. Every piece except the four special
    tokens is then written, in id order, to ``<output_path_prefix>.txt`` as
    ``"<piece> 1"`` lines.

    Args:
        input_path: Text file handed to the trainer as ``--input``.
        output_path_prefix: Prefix used for ``--model_prefix`` and for the
            exported ``.txt`` dictionary.
        model_type: Value for ``--model_type``.
        vocab_size: Value for ``--vocab_size``.
        special_symbols: Optional symbols forwarded, comma-joined, as
            ``--user_defined_symbols``.

    Raises:
        AssertionError: If the trained model does not map the special tokens
            to the ids declared at module level.
    """
    # Train SentencePiece Model
    arguments = [
        f"--input={input_path.as_posix()}",
        f"--model_prefix={output_path_prefix.as_posix()}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    ]
    if special_symbols is not None:
        _special_symbols = ",".join(special_symbols)
        arguments.append(f"--user_defined_symbols={_special_symbols}")
    sp.SentencePieceTrainer.Train(" ".join(arguments))

    # Export fairseq dictionary
    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix.as_posix() + ".model")
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    assert (
        vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
        and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
        and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
        and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
    )
    special_tokens = {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    vocab = {i: s for i, s in vocab.items() if s not in special_tokens}
    # Pieces are frequently non-ASCII, so the encoding is pinned instead of
    # being left to the platform default.
    dict_path = output_path_prefix.as_posix() + ".txt"
    with open(dict_path, "w", encoding="utf-8") as f_out:
        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
            f_out.write(f"{s} 1\n")
def extract_fbank_features(
    waveform: torch.FloatTensor,
    sample_rate: int,
    output_path: Optional[Path] = None,
    n_mel_bins: int = 80,
    overwrite: bool = False,
) -> Optional[np.ndarray]:
    """Compute filterbank features for a waveform and optionally save them.

    Args:
        waveform: Input audio passed to ``convert_waveform`` with
            ``to_mono=True``.
        sample_rate: Sample rate forwarded to the conversion and fbank helpers.
        output_path: If given, the features are saved there with ``np.save``.
        n_mel_bins: Number of mel bins forwarded to the fbank helpers.
        overwrite: When False and ``output_path`` already exists as a file,
            nothing is computed.

    Returns:
        The feature array, or None when an existing output file is kept.

    Raises:
        ImportError: If neither fbank helper produced features.
    """
    if output_path is not None and output_path.is_file() and not overwrite:
        return None

    converted = convert_waveform(waveform, sample_rate, to_mono=True)
    # NOTE(review): some fairseq versions of convert_waveform appear to return
    # a (waveform, sample_rate) tuple rather than the bare waveform; accept
    # both forms -- confirm against the installed audio_utils.
    if isinstance(converted, tuple):
        converted = converted[0]
    # Kaldi compliance: 16-bit signed integers
    _waveform = (converted * (2 ** 15)).numpy()

    # Try the Kaldi helper first and fall back to torchaudio; each helper
    # returns None when it cannot produce features.
    features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable fbank feature extraction"
        )

    if output_path is not None:
        np.save(output_path.as_posix(), features)
    return features
def create_zip(data_root: Path, zip_path: Path):
    """Bundle the ``*.npy`` files found directly in ``data_root`` into a ZIP.

    Entries are stored uncompressed (``ZIP_STORED``) under their bare file
    names.
    """
    npy_files = list(data_root.glob("*.npy"))
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for npy_file in tqdm(npy_files):
            archive.write(npy_file, arcname=npy_file.name)
def get_zip_manifest(
    zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
):
    """Index the entries of a ZIP archive by raw byte offset and length.

    For every entry the function records where its data starts inside the
    archive file, then reads ``file_size`` raw bytes from that position to
    validate them and measure the entry's length. Because the bytes are read
    without decompression, entries must be stored uncompressed.

    Args:
        zip_path: Archive path; used verbatim in the returned path strings.
        zip_root: Optional directory that ``zip_path`` is joined onto when
            opening the archive.
        is_audio: If True, entries are checked with ``is_sf_audio_data`` and
            measured with ``sf.info(...).frames``; otherwise they are checked
            with ``is_npy_data`` and measured with ``np.load(...).shape[0]``.

    Returns:
        ``(paths, lengths)``: two dicts keyed by the entry file name's stem.
        ``paths`` maps to ``"<zip_path>:<offset>:<file_size>"`` and
        ``lengths`` to the measured length.

    Raises:
        AssertionError: If a local header or an entry's data fails validation.
    """
    _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
    with zipfile.ZipFile(_zip_path, mode="r") as archive:
        info = archive.infolist()

    # Size of the fixed part of a ZIP local file header; the (variable-length)
    # file name and extra field follow it, and the data comes after those.
    local_header_size = 30
    paths, lengths = {}, {}
    # Open the archive once instead of once per entry.
    with open(_zip_path, "rb") as raw:
        for i in tqdm(info):
            utt_id = Path(i.filename).stem
            # Take the name/extra lengths from the local header itself:
            # len(i.filename) counts characters, not encoded bytes, and the
            # local extra field is not guaranteed to be empty.
            raw.seek(i.header_offset)
            header = raw.read(local_header_size)
            assert header[:4] == b"PK\x03\x04", i
            name_len = int.from_bytes(header[26:28], "little")
            extra_len = int.from_bytes(header[28:30], "little")
            offset = i.header_offset + local_header_size + name_len + extra_len
            file_size = i.file_size
            paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"

            raw.seek(offset)
            byte_data = raw.read(file_size)
            assert len(byte_data) > 1
            if is_audio:
                assert is_sf_audio_data(byte_data), i
            else:
                assert is_npy_data(byte_data), i
            byte_data_fp = io.BytesIO(byte_data)
            if is_audio:
                lengths[utt_id] = sf.info(byte_data_fp).frames
            else:
                lengths[utt_id] = np.load(byte_data_fp).shape[0]
    return paths, lengths
def gen_config_yaml(
    manifest_root: Path,
    spm_filename: Optional[str] = None,
    vocab_name: Optional[str] = None,
    yaml_filename: str = "config.yaml",
    specaugment_policy: Optional[str] = "lb",
    prepend_tgt_lang_tag: bool = False,
    sampling_alpha: Optional[float] = None,
    input_channels: Optional[int] = 1,
    input_feat_per_channel: Optional[int] = 80,
    audio_root: str = "",
    cmvn_type: str = "utterance",
    gcmvn_path: Optional[Path] = None,
    extra=None
):
    """Build a data config with S2TDataConfigWriter and write it as YAML.

    The file is written to ``manifest_root / yaml_filename`` (with
    ``manifest_root`` made absolute). At least one of ``spm_filename`` and
    ``vocab_name`` must be given; when ``vocab_name`` is missing it is derived
    from ``spm_filename`` by replacing ``".model"`` with ``".txt"``.

    Raises:
        AssertionError: If both ``spm_filename`` and ``vocab_name`` are None.
        NotImplementedError: If ``cmvn_type`` is neither ``"global"`` nor
            ``"utterance"``.
        ValueError: If ``cmvn_type`` is ``"global"`` and ``gcmvn_path`` is None.
    """
    manifest_root = manifest_root.absolute()
    writer = S2TDataConfigWriter(manifest_root / yaml_filename)

    # Vocabulary file name.
    assert spm_filename is not None or vocab_name is not None
    if vocab_name is None:
        vocab_name = spm_filename.replace(".model", ".txt")
    writer.set_vocab_filename(vocab_name)

    # Input layout.
    if input_channels is not None:
        writer.set_input_channels(input_channels)
    if input_feat_per_channel is not None:
        writer.set_input_feat_per_channel(input_feat_per_channel)

    # SpecAugment preset; a policy name outside the table sets nothing here.
    policy_table = {
        "lb": writer.set_specaugment_lb_policy,
        "ld": writer.set_specaugment_ld_policy,
        "sm": writer.set_specaugment_sm_policy,
        "ss": writer.set_specaugment_ss_policy,
    }
    if specaugment_policy in policy_table:
        policy_table[specaugment_policy]()

    # Tokenizer and target-side options.
    if spm_filename is not None:
        spm_path = (manifest_root / spm_filename).as_posix()
        writer.set_bpe_tokenizer(
            {"bpe": "sentencepiece", "sentencepiece_model": spm_path}
        )
    if prepend_tgt_lang_tag:
        writer.set_prepend_tgt_lang_tag(True)
    if sampling_alpha is not None:
        writer.set_sampling_alpha(sampling_alpha)

    # Feature transforms: CMVN for every split, plus specaugment on "_train"
    # whenever a policy name was supplied.
    if cmvn_type not in ["global", "utterance"]:
        raise NotImplementedError
    cmvn_transform = f"{cmvn_type}_cmvn"
    if specaugment_policy is not None:
        writer.set_feature_transforms("_train", [cmvn_transform, "specaugment"])
    writer.set_feature_transforms("*", [cmvn_transform])

    if cmvn_type == "global":
        if gcmvn_path is None:
            raise ValueError("Please provide path of global cmvn file.")
        writer.set_global_cmvn(gcmvn_path.as_posix())

    if len(audio_root) > 0:
        writer.set_audio_root(audio_root)
    if extra is not None:
        writer.set_extra(extra)

    writer.flush()
def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
    """Read a UTF-8 TSV file with a header row into a DataFrame.

    Quoting is disabled, backslash is the escape character, and empty fields
    are kept as empty strings (no NaN conversion).
    """
    if isinstance(path, str):
        source = path
    else:
        source = path.as_posix()
    read_options = {
        "sep": "\t",
        "header": 0,
        "encoding": "utf-8",
        "escapechar": "\\",
        "quoting": csv.QUOTE_NONE,
        "na_filter": False,
    }
    return pd.read_csv(source, **read_options)
def save_df_to_tsv(dataframe, path: Union[str, Path]):
    """Write ``dataframe`` to ``path`` as a UTF-8 TSV with a header row.

    The index is not written, quoting is disabled, and backslash is the
    escape character.
    """
    if isinstance(path, str):
        target = path
    else:
        target = path.as_posix()
    write_options = {
        "sep": "\t",
        "header": True,
        "index": False,
        "encoding": "utf-8",
        "escapechar": "\\",
        "quoting": csv.QUOTE_NONE,
    }
    dataframe.to_csv(target, **write_options)
def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
    """Read a TSV file with a header row into a list of per-row dicts.

    Quoting is disabled, so quote characters in fields are kept literally.

    Args:
        path: Location of the TSV file.

    Returns:
        One dict per data row, keyed by the header column names.
    """
    # Decode as UTF-8 to match the pandas-based TSV helpers in this file
    # instead of relying on the platform default, and pass newline="" so the
    # csv module does its own line-ending handling, as its docs recommend.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        rows = [dict(e) for e in reader]
    return rows
def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    """Drop manifest rows matched by any filter and print a summary.

    Built-in filters flag rows with an empty ``audio`` value, ``n_frames``
    below ``min_n_frames``, or an empty ``tgt_text``. For training splits,
    rows with ``n_frames`` above ``max_n_frames`` are flagged as well.
    ``extra_filters`` (name -> boolean mask) is merged on top of these.

    Returns:
        ``df`` restricted to the rows no filter flagged.
    """
    filters = {}
    filters["no speech"] = df["audio"] == ""
    filters[f"short speech (<{min_n_frames} frames)"] = df["n_frames"] < min_n_frames
    filters["empty sentence"] = df["tgt_text"] == ""
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)

    # OR all masks together, left to right.
    invalid = None
    for mask in filters.values():
        invalid = mask if invalid is None else invalid | mask
    valid = ~invalid

    per_filter = ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
    summary = f", total {invalid.sum()} filtered, {valid.sum()} remained."
    print("| " + per_filter + summary)
    return df[valid]
def cal_gcmvn_stats(features_list):
    """Compute per-column mean and standard deviation over stacked features.

    Args:
        features_list: Sequence of arrays that are concatenated along axis 0.

    Returns:
        Dict with float32 arrays ``"mean"`` and ``"std"``. The variance is
        computed as E[x^2] - E[x]^2 and floored at 1e-8 before the square
        root, so constant columns yield a small positive std, not zero.
    """
    features = np.concatenate(features_list)
    # Only the row count is needed besides the two reductions, so no centred
    # copy of the (potentially large) feature matrix is materialised.
    n_rows = features.shape[0]
    mean = features.mean(axis=0)
    square_sums = (features ** 2).sum(axis=0)
    var = square_sums / n_rows - mean ** 2
    std = np.sqrt(np.maximum(var, 1e-8))
    return {"mean": mean.astype("float32"), "std": std.astype("float32")}
class S2TDataConfigWriter(object):
    """Accumulate data-config entries in a dict and dump them to a YAML file.

    Each ``set_*`` method stores one entry in ``self.config``; ``flush``
    writes the whole dict to ``yaml_path``.
    """

    # Class-level defaults; not referenced by the methods of this class.
    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path: Path):
        """Create an empty config bound to ``yaml_path``.

        Raises:
            ImportError: If PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as err:
            # Fail with an actionable error here; printing and continuing
            # would only crash on the unbound ``yaml`` name just below.
            raise ImportError(
                "Please install PyYAML for S2T data config YAML files"
            ) from err
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        """Write the accumulated config to ``self.yaml_path``."""
        with open(self.yaml_path, "w", encoding="utf-8") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        """Store the ``audio_root`` entry."""
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
        """Store the ``vocab_filename`` entry."""
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        """Store the ``specaugment`` section built from the given values."""
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        """Apply the "lb" specaugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        """Apply the "ld" specaugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_sm_policy(self):
        """Apply the "sm" specaugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=15,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_specaugment_ss_policy(self):
        """Apply the "ss" specaugment preset."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=70,
            time_mask_p=0.2,
        )

    def set_input_channels(self, input_channels: int = 1):
        """Store the ``input_channels`` entry."""
        self.config["input_channels"] = input_channels

    def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
        """Store the ``input_feat_per_channel`` entry."""
        self.config["input_feat_per_channel"] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        """Store the ``bpe_tokenizer`` section."""
        self.config["bpe_tokenizer"] = bpe_tokenizer

    def set_global_cmvn(self, stats_npz_path: str):
        """Store the ``global_cmvn`` section pointing at ``stats_npz_path``."""
        self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}

    def set_feature_transforms(self, split: str, transforms: List[str]):
        """Store the transform list for ``split`` under ``transforms``."""
        if "transforms" not in self.config:
            self.config["transforms"] = {}
        self.config["transforms"][split] = transforms

    def set_prepend_tgt_lang_tag(self, flag: bool = True):
        """Store the ``prepend_tgt_lang_tag`` entry."""
        self.config["prepend_tgt_lang_tag"] = flag

    def set_sampling_alpha(self, sampling_alpha: float = 1.0):
        """Store the ``sampling_alpha`` entry."""
        self.config["sampling_alpha"] = sampling_alpha

    def set_extra(self, data):
        """Merge ``data`` into the config, overwriting existing keys."""
        self.config.update(data)