wav2vec2-ctc-spgispeech / get_ctc_tokenizer.py
sanchit-gandhi's picture
Add training scripts and weights
1a8f407
#!/usr/bin/env python3
from datasets import load_dataset
from collections import Counter
import json
import os
import tempfile
from transformers import Wav2Vec2CTCTokenizer
# --- run configuration -------------------------------------------------------
# Dataset config (on the "esb/datasets" Hub repo) whose text builds the vocab.
dataset_name = "spgispeech"
# Only the training split is used, so the vocab never sees evaluation text.
split = "train"
# Forwarded to `load_dataset`; set because the dataset may require access.
use_auth_token = True
# Hub repository name the finished tokenizer is pushed to.
tokenizer_name = "wav2vec2-ctc-" + dataset_name + "-tokenizer"
# Character-frequency cutoff expressed in PERCENT of all counted characters
# (0.01 means 0.01 %). The same value is used for every dataset, which keeps
# the procedure dataset-agnostic.
cutoff_freq = 0.01
# Download the configured split from the "esb/datasets" Hub repository.
dataset = load_dataset("esb/datasets", dataset_name, split=split, use_auth_token=use_auth_token)
# Only the "text" column is needed to build the vocab; drop every other column
# to keep memory usage down.
dataset = dataset.remove_columns(
    [column for column in dataset.column_names if column != "text"]
)
def create_vocabulary_from_data(dataset, word_delimiter_token="|", cutoff_freq=0.0):
    """Build a character-level CTC vocabulary from the ``text`` column of ``dataset``.

    Counts every character of ``dataset["text"]``, prints a frequency table,
    drops characters whose share of all counted characters is below
    ``cutoff_freq`` percent, and returns a ``{token: id}`` mapping suitable for
    writing to the ``vocab.json`` of a ``Wav2Vec2CTCTokenizer``.

    Args:
        dataset: a ``datasets.Dataset`` with a ``text`` column of strings.
        word_delimiter_token: token that takes over the id of the space
            character. If ``None``, the space character is kept as a token.
        cutoff_freq: threshold in PERCENT (``0.01`` means 0.01 %). Characters
            whose share is strictly below it are left out of the vocab.

    Returns:
        dict mapping token -> integer id. Ids 0-3 are ``<pad>``, ``<s>``,
        ``</s>`` and ``<unk>``. The remaining characters follow in descending
        frequency order (ties broken by the character itself), so the mapping
        is identical from run to run.
    """

    def extract_all_chars(batch):
        # Joining with " " inserts len(batch["text"]) - 1 separator spaces,
        # and these are counted too.
        all_text = " ".join(batch["text"])
        char_counts = Counter(all_text)
        if not char_counts:
            # All texts were empty: zip(*[]) below would fail to unpack.
            return {"vocab": [], "freqs": []}
        # Most frequent first; ties broken by the character itself.
        ordered = sorted(char_counts.items(), key=lambda item: (-item[1], item[0]))
        vocab, freqs = zip(*ordered)
        return {"vocab": list(vocab), "freqs": list(freqs)}

    # Per the `Dataset.map` docs, batch_size=-1 passes the whole dataset as a
    # single batch, so the counts are global. Removing all input columns lets
    # the output have a different number of rows (one per character).
    dataset = dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        remove_columns=dataset.column_names,
    )
    vocab, freqs = dataset["vocab"], dataset["freqs"]
    total_num_chars = sum(freqs)

    kept_chars = []
    print("Character Occurences")
    print(f"Total characters in dataset: {total_num_chars}")
    print(50 * "-")
    print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
    print(50 * "-")
    for char, freq in zip(vocab, freqs):
        freq_in_percent = freq / total_num_chars * 100
        print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
        if freq_in_percent < cutoff_freq:
            continue  # too rare: left out of the vocab
        kept_chars.append(char)
    print(50 * "-")

    # Keep the frequency-sorted order. Iterating a set of str depends on the
    # per-process hash seed (PYTHONHASHSEED), which would make the token ids
    # differ between runs.
    # Wav2Vec2CTC tokenizers always have these as the first tokens (important for CTC).
    vocab = ["<pad>", "<s>", "</s>", "<unk>"] + kept_chars
    vocab_dict = {token: idx for idx, token in enumerate(vocab)}

    # Replace the space character with the word delimiter token.
    if word_delimiter_token is not None:
        if " " in vocab_dict:
            # NOTE(review): if the delimiter also occurs as a character in the
            # text, its original id is orphaned here. The previous code had
            # the same behaviour.
            vocab_dict[word_delimiter_token] = vocab_dict.pop(" ")
        elif word_delimiter_token not in vocab_dict:
            # " " was never seen or fell below the cutoff. Give the delimiter
            # the next free id instead of raising KeyError.
            vocab_dict[word_delimiter_token] = len(vocab_dict)
    return vocab_dict
# Note that the function accepts the following important argument:
# `cutoff_freq`
# => This is very important! Lots of datasets contain "wrong" characters in the training set, e.g.
# characters that occur only a couple of times.
# By default, the CTC vocab creation would simply add them to the vocab even if their occurrence is negligible
# compared to the "super frequent" letters. We can treat such characters as errors, or as irrelevant to the
# dataset, so we should delete them from the vocab. During training, they are then simply mapped to the
# unknown <unk> token, which the model can handle.
# In this script, we remove all chars whose frequency (in % of all characters) is below a certain threshold.
# We FIX this threshold for all datasets (i.e. the approach is dataset-agnostic).
vocab_dict = create_vocabulary_from_data(dataset, cutoff_freq=cutoff_freq)

# The tokenizer is loaded from a directory containing vocab.json, so write
# the vocab into a throwaway directory and load the tokenizer from there.
with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = os.path.join(tmp_dir, "vocab.json")
    with open(vocab_path, "w") as vocab_file:
        vocab_file.write(json.dumps(vocab_dict))
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp_dir)

# Upload the finished tokenizer to the Hub under `tokenizer_name`.
tokenizer.push_to_hub(tokenizer_name)