wav2vec2-large-xlsr-53-th
Finetuning wav2vec2-large-xlsr-53 on Thai Common Voice 7.0
We finetune wav2vec2-large-xlsr-53 on the Thai examples of Common Voice Corpus 7.0, following the approach in Fine-tuning Wav2Vec2 for English ASR. The notebooks and scripts can be found in vistec-ai/wav2vec2-large-xlsr-53-th. The pretrained model and processor can be found at airesearch/wav2vec2-large-xlsr-53-th.
robust-speech-event
We added syllable_tokenize and word_tokenize (PyThaiNLP) and deepcut tokenizers to eval.py from robust-speech-event.
```shell
# --thai_tokenizer takes one of: newmm, syllable, deepcut, cer
python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm
```
Eval results on the Common Voice 7 "test" split (SER is the syllable error rate, computed with syllable_tokenize):
Setting | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER
---|---|---|---|---
Only Tokenization | 0.9524% | 2.5316% | 1.2346% | 0.1623%
Cleaning rules and Tokenization | TBD | TBD | TBD | TBD
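Because Thai is written without spaces between words, the word error rate depends on which tokenizer splits the reference and the prediction, which is why results are reported under several tokenizers. Below is a minimal sketch of a corpus-level error rate with a pluggable tokenizer, assuming the standard Levenshtein definition (total edits divided by total reference tokens). It is not the code in eval.py; edit_distance, error_rate and char_tokenize are hypothetical names, and dropping spaces for CER is an assumption.

```python
from typing import Callable, List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions + deletions + insertions) between token lists."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of a reference token
                curr[j - 1] + 1,         # insertion of a hypothesis token
                prev[j - 1] + (r != h),  # substitution (free when tokens match)
            )
        prev = curr
    return prev[-1]


def error_rate(refs: List[str], hyps: List[str],
               tokenize: Callable[[str], List[str]]) -> float:
    """Corpus-level error rate: total edits divided by total reference tokens."""
    edits = 0
    total = 0
    for ref, hyp in zip(refs, hyps):
        ref_tokens = tokenize(ref)
        edits += edit_distance(ref_tokens, tokenize(hyp))
        total += len(ref_tokens)
    return edits / total


def char_tokenize(text: str) -> List[str]:
    """Character tokens with whitespace removed (assumed convention for Thai CER)."""
    return [c for c in text if not c.isspace()]
```

For a PyThaiNLP-based WER one could pass a wrapper around pythainlp.tokenize.word_tokenize (engine="newmm") as tokenize; since the model output contains spaces between words, whitespace tokens may need to be filtered out first.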
Usage
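```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

#load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")

#load test data (assumption: any datasets.Dataset with "path" and "sentence" columns
#works; the Common Voice 7.0 Thai split is gated on the Hugging Face Hub)
test_dataset = load_dataset("mozilla-foundation/common_voice_7_0", "th", split="test")

#function to resample to 16_000
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    #keep the first channel as a 1-D float array
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    return batch

#get 2 examples as sample input
test_dataset = test_dataset.select(range(2)).map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

#infer
with torch.no_grad():
    logits = model(inputs.input_values).logits

#greedy decoding: most likely token per frame, then CTC post-processing in batch_decode
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"])
```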
```
>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
```
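In the usage example, torch.argmax picks the most likely token for each audio frame, and processor.batch_decode turns that frame sequence into text. A minimal pure-Python sketch approximating that greedy CTC post-processing, assuming the Wav2Vec2CTCTokenizer defaults of <pad> as the CTC blank and | as the word delimiter; ctc_greedy_collapse is a hypothetical helper, not part of transformers.

```python
from typing import Dict, List


def ctc_greedy_collapse(frame_ids: List[int], id_to_token: Dict[int, str],
                        pad_token: str = "<pad>", word_delimiter: str = "|") -> str:
    """Turn per-frame argmax ids into text: merge consecutive repeats, drop the
    blank (pad) token, and map the word delimiter to a space."""
    tokens = []
    prev = None
    for idx in frame_ids:
        if idx != prev:  # collapse runs of the same id into one token
            tokens.append(id_to_token[idx])
        prev = idx
    # remove blanks only after collapsing, so a blank between two identical
    # tokens keeps them as a genuine repeat
    tokens = [t for t in tokens if t != pad_token]
    text = "".join(" " if t == word_delimiter else t for t in tokens)
    return " ".join(text.split())  # trim and squeeze spaces


vocab = {0: "<pad>", 1: "|", 2: "a", 3: "b"}
print(ctc_greedy_collapse([2, 2, 0, 2, 1, 3, 3, 0], vocab))  # aa b
```

The blank is what lets CTC emit a doubled character: without the <pad> frame between the two runs of "a", they would merge into a single "a".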
Datasets
[Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) contains 133 validated hours of Thai (255 total hours, 5GB). We pre-tokenize the transcripts with pythainlp.tokenize.word_tokenize. We preprocess the dataset using the cleaning rules described in notebooks/cv-preprocess.ipynb by @tann9949. We then deduplicate and split as described in ekapolc/Thai_commonvoice_split in order to 1) avoid data leakage due to random splits after cleaning in Common Voice Corpus 7.0 and 2) preserve the majority of the data for the training set.

The dataset loading script is scripts/th_common_voice_70.py. You can use this script together with train_cleand.tsv, validation_cleaned.tsv and test_cleaned.tsv to reproduce our splits. The resulting dataset is as follows:
```
DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 86586
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 2502
    })
    validation: Dataset({
        features: ['path', 'sentence'],
        num_rows: 3027
    })
})
```
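Leakage can arise when the same sentence text, read by several speakers or made identical by the cleaning rules, lands in both train and test after a random split. Below is a minimal sketch of a group-aware split that keeps every recording of a sentence in one split. It is a hypothetical illustration, not the procedure in ekapolc/Thai_commonvoice_split; group_split and its parameters are assumed names, and unlike the described split it uses fixed fractions rather than trying to maximize the training set.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[str, str]  # (path, sentence)


def group_split(rows: List[Row], test_frac: float, val_frac: float,
                seed: int = 0) -> Dict[str, List[Row]]:
    """Split rows so that all recordings of a given sentence share one split."""
    # drop exact duplicate rows while preserving order
    unique_rows = list(dict.fromkeys(rows))
    by_sentence: Dict[str, List[Row]] = defaultdict(list)
    for path, sentence in unique_rows:
        by_sentence[sentence].append((path, sentence))
    # shuffle sentences (not rows) deterministically
    sentences = sorted(by_sentence)
    random.Random(seed).shuffle(sentences)
    n_test = round(len(sentences) * test_frac)
    n_val = round(len(sentences) * val_frac)
    groups = {
        "test": sentences[:n_test],
        "validation": sentences[n_test:n_test + n_val],
        "train": sentences[n_test + n_val:],
    }
    return {name: [row for s in sents for row in by_sentence[s]]
            for name, sents in groups.items()}
```

For scale, the splits above put 86,586 of 92,115 rows (about 94%) in the training set.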
Training
We finetuned the model on a single V100 GPU using the following configuration and chose the checkpoint with the lowest validation loss. The finetuning script is scripts/wav2vec2_finetune.py.
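```python
from transformers import TrainingArguments, Wav2Vec2ForCTC

# assumes `processor` is the Thai Wav2Vec2Processor created beforehand (not shown)

# create model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    # deprecated as a config flag in newer transformers versions,
    # which use model.gradient_checkpointing_enable() instead
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
# keep the convolutional feature encoder frozen
# (deprecated in favor of freeze_feature_encoder() in newer transformers versions)
model.freeze_feature_extractor()

training_args = TrainingArguments(
    output_dir="../data/wav2vec2-large-xlsr-53-thai",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=16,
    metric_for_best_model='wer',
    evaluation_strategy="steps",  # renamed eval_strategy in newer transformers versions
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    num_train_epochs=100,
    fp16=True,
    learning_rate=1e-4,
    warmup_steps=1000,
    save_total_limit=3,
    report_to="tensorboard",
)
```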
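To see what this configuration implies, the sketch below works out the schedule length and the learning rate over time. It assumes all 86,586 training rows are used (no length filtering), one GPU, and the Trainer's default "linear" scheduler, which warms up linearly to learning_rate over warmup_steps and then decays linearly to zero; linear_warmup_decay_lr is a hypothetical helper mirroring that rule, and the checkpoint actually selected may come from well before the final step.

```python
import math


def linear_warmup_decay_lr(step: int, base_lr: float, warmup_steps: int,
                           total_steps: int) -> float:
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


train_rows = 86586   # train split size from the Datasets section (assumed fully used)
batch_size = 32 * 1  # per_device_train_batch_size * gradient_accumulation_steps on 1 GPU
steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * 100  # num_train_epochs=100

print(steps_per_epoch, total_steps)  # 2706 270600
```

Under these assumptions, warmup covers only the first 1,000 of about 270,600 optimizer steps, and with eval_steps=1000 the model is evaluated roughly 270 times.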
Evaluation
We benchmark on the test set using WER, with words tokenized by PyThaiNLP 2.3.1 and by deepcut, and CER. We also measure performance when spell correction using TNC (Thai National Corpus) n-grams is applied. Evaluation code can be found in notebooks/wav2vec2_finetuning_tutorial.ipynb. The benchmark is performed on the test-unique split.
System | WER PyThaiNLP 2.3.1 | WER deepcut | CER
---|---|---|---
Kaldi from scratch | 23.04 | | 7.57
Ours without spell correction | 13.634024 | 8.152052 | 2.813019
Ours with spell correction | 17.996397 | 14.167975 | 5.225761
Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340
Microsoft Bing Speech API※ | 12.578819 | 9.620991 | 5.016620
Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562
NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027

※ APIs are not finetuned with Common Voice 7.0 data
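Because the ranking of systems depends on how Thai text is segmented, it can help to check which system wins under each metric. A small sketch that hard-codes the numbers from the benchmark table above (RESULTS, METRICS and best_system are hypothetical names); Kaldi is left out because not all three of its metrics are reported.

```python
# (WER PyThaiNLP 2.3.1, WER deepcut, CER) copied from the benchmark table
RESULTS = {
    "Ours without spell correction": (13.634024, 8.152052, 2.813019),
    "Ours with spell correction": (17.996397, 14.167975, 5.225761),
    "Google Web Speech API": (13.711234, 10.860058, 7.357340),
    "Microsoft Bing Speech API": (12.578819, 9.620991, 5.016620),
    "Amazon Transcribe": (21.86334, 14.487553, 7.077562),
    "NECTEC AI for Thai Partii API": (20.105887, 15.515631, 9.551027),
}
METRICS = ("WER PyThaiNLP 2.3.1", "WER deepcut", "CER")


def best_system(metric_index: int) -> str:
    """Name of the system with the lowest error for one metric column."""
    return min(RESULTS, key=lambda name: RESULTS[name][metric_index])


for i, metric in enumerate(METRICS):
    print(f"{metric}: {best_system(i)}")
```

On these numbers, Microsoft Bing Speech API has the lowest PyThaiNLP WER, while ours without spell correction has the lowest deepcut WER and CER.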
LICENSE
Acknowledgements
- model training and validation notebooks/scripts @cstorm125
- dataset cleaning scripts @tann9949
- dataset splits @ekapolc and @14mss
- running the training @mrpeerat
- spell correction @wannaphong