|
import os |
|
import random |
|
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset |
|
from datasets.features import Audio |
|
import pandas as pd |
|
import numpy as np |
|
from tqdm import tqdm |
|
from IPython.display import display, HTML |
|
|
|
|
|
def load_custom_dataset(data_dir):
    """Build an audio/text ``Dataset`` from paired files under ``data_dir``.

    Every ``*.wav`` file in ``<data_dir>/wav`` is paired with the transcript
    of the same base name in ``<data_dir>/transcription`` (``foo.wav`` ->
    ``foo.txt``).

    Args:
        data_dir: Directory containing the ``wav`` and ``transcription``
            sub-directories.

    Returns:
        A ``Dataset`` with two columns: ``audio`` (the wav path, cast to
        ``Audio(sampling_rate=16_000)``) and ``text`` (the transcript with
        surrounding whitespace stripped).

    Raises:
        FileNotFoundError: If the ``wav`` directory does not exist or a wav
            file has no matching transcript file.
    """
    wav_suffix = '.wav'
    wav_dir = os.path.join(data_dir, 'wav')
    txt_dir = os.path.join(data_dir, 'transcription')

    data = {
        "audio": [],
        "text": [],
    }

    # os.listdir() yields entries in arbitrary order; sorting makes the row
    # order (and therefore any seeded train/test split) reproducible.
    for wav_file in sorted(os.listdir(wav_dir)):
        if not wav_file.endswith(wav_suffix):
            continue

        # Swap only the trailing extension. str.replace() would also rewrite
        # a ".wav" that happens to appear earlier in the file name.
        txt_file = wav_file[:-len(wav_suffix)] + '.txt'
        wav_path = os.path.join(wav_dir, wav_file)
        txt_path = os.path.join(txt_dir, txt_file)

        with open(txt_path, 'r', encoding='utf-8') as f:
            transcription = f.read().strip()

        data["audio"].append(wav_path)
        data["text"].append(transcription)

    df = pd.DataFrame(data)
    dataset = Dataset.from_pandas(df)

    # Turn the column of paths into an Audio feature at 16 kHz.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

    return dataset
|
|
|
# Load every (audio, transcript) pair found under the current directory.
custom_train_dataset = load_custom_dataset("./")

dataset_dict = DatasetDict({
    "train": custom_train_dataset,
})

# Hold out 975 randomly chosen rows as the test split.  random.sample raises
# ValueError if the dataset has fewer than 975 rows.
train_size = len(dataset_dict["train"])
sample_indices = random.sample(range(train_size), 975)

test_samples = dataset_dict["train"].select(sample_indices)

# The predicate below runs once per row; testing membership in a set is O(1),
# whereas `idx not in sample_indices` would scan the 975-element list each time.
held_out_indices = set(sample_indices)
remaining_train_samples = dataset_dict["train"].filter(
    lambda example, idx: idx not in held_out_indices,
    with_indices=True,
)

dataset_dict["test"] = test_samples
dataset_dict["train"] = remaining_train_samples

print(dataset_dict)
|
|
|
def show_random_elements(dataset, num_examples=10):
    """Display ``num_examples`` distinct, randomly chosen rows of ``dataset``.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of
            integer positions (e.g. a ``datasets.Dataset``).
        num_examples: How many distinct rows to show.

    Raises:
        AssertionError: If ``num_examples`` exceeds ``len(dataset)``.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # random.sample draws without replacement, replacing the manual
    # "retry until the index is new" loop.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])

    # Actually render the table; previously the frame was built and then
    # discarded, so the function showed nothing.
    display(HTML(df.to_html()))
|
|
|
# Spot-check a random handful of rows from the training split.
show_random_elements(dataset=dataset_dict["train"])
|
|
|
import re

# Punctuation characters stripped from every transcript.  This must be a raw
# string: the backslashes are regex escapes, and in a normal literal "\,",
# "\?" etc. are invalid string escapes that newer Python versions warn about.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'


def remove_special_characters(batch):
    """Strip punctuation from ``batch["text"]`` and lower-case it.

    Args:
        batch: Mapping with a string under the ``"text"`` key.

    Returns:
        The same ``batch`` object, with ``"text"`` replaced in place by the
        cleaned, lower-cased string.
    """
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch
|
|
|
# Apply the punctuation/lower-case clean-up to every split.
dataset_dict = dataset_dict.map(remove_special_characters)

# Look at a random sample of the cleaned training rows.
show_random_elements(dataset=dataset_dict["train"])
|
|
|
def extract_all_chars(batch):
    """Collect the distinct characters used across a batch of transcripts.

    Args:
        batch: Mapping whose ``"text"`` entry is an iterable of strings.

    Returns:
        A dict of two single-element lists: ``"vocab"`` wraps the list of
        unique characters (in no particular order) found in the space-joined
        text, and ``"all_text"`` wraps that joined text itself.
    """
    joined_text = " ".join(batch["text"])
    unique_chars = list(set(joined_text))

    result = {}
    result["vocab"] = [unique_chars]
    result["all_text"] = [joined_text]
    return result
|
|
|
# Run extract_all_chars over each split with batch_size=-1 and drop the
# original columns, leaving only the "vocab" / "all_text" outputs.
vocabs = dataset_dict.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    keep_in_memory=True,
    remove_columns=dataset_dict.column_names["train"],
)

vocab_chars = set(vocabs["train"]["vocab"][0])

# The tokenizer in this script is built with word_delimiter_token="|", so the
# word boundary must live in the vocabulary as "|" rather than " ".  Without
# this swap "|" is missing from vocab.json and every space is encoded as [UNK].
if " " in vocab_chars:
    vocab_chars.discard(" ")
    vocab_chars.add("|")

# Sort so that token ids are identical from run to run; iterating a set of
# strings gives an order that can change between interpreter sessions.
vocab_list = sorted(vocab_chars)

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

# Special tokens take the next two free ids.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))

import json

with open('vocab.json', 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab_dict, vocab_file)
|
|
|
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
)

# Tokenizer backed by the vocab.json file in the working directory.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
    vocab_size=len(vocab_dict),
)

# Feature extractor for single-channel input at 16 kHz, normalised, without
# an attention mask in its output.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

# Bundle both halves so audio and text go through a single object.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)
|
|
|
# random.randint includes BOTH endpoints, so the upper bound has to be the
# last valid index; using len(...) itself could raise IndexError below.
rand_int = random.randint(0, len(dataset_dict["train"]) - 1)

# Fetch the row once instead of indexing the dataset (and loading its audio)
# separately for each print.
rand_example = dataset_dict["train"][rand_int]

print("Target text:", rand_example["text"])
print("Input array shape:", np.asarray(rand_example["audio"]["array"]).shape)
print("Sampling rate:", rand_example["audio"]["sampling_rate"])
|
|
|
def prepare_dataset(batch):
    """Convert one raw example into model inputs and label ids.

    Args:
        batch: A single example with an ``"audio"`` entry (providing
            ``"array"`` and ``"sampling_rate"``) and a ``"text"`` entry.

    Returns:
        The same ``batch`` with two entries added: ``"input_values"`` (first
        element of the processor's ``input_values`` for the audio) and
        ``"labels"`` (the ``input_ids`` the processor produces for the text).
    """
    audio = batch["audio"]

    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values[0]

    # Passing `text=` sends the transcript to the tokenizer directly; this
    # replaces the deprecated `with processor.as_target_processor():` idiom.
    batch["labels"] = processor(text=batch["text"]).input_ids
    return batch
|
|
|
# Featurise every split and drop the raw columns listed for "train".
dataset_dict = dataset_dict.map(
    prepare_dataset,
    remove_columns=dataset_dict.column_names["train"],
    num_proc=None,
)
|
|
|
import torch |
|
|
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the ``input_values`` to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as ``pad_to_multiple_of`` but applied to the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into one batch of tensors.

        Args:
            features: Examples, each providing ``"input_values"`` and
                ``"labels"``.

        Returns:
            The padded input batch with an added ``"labels"`` tensor in which
            every padded label position is set to ``-100``.
        """
        # Inputs and labels are padded separately, each with its own
        # max_length / pad_to_multiple_of settings.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # `labels=` routes the padding to the tokenizer; this replaces the
        # deprecated `with self.processor.as_target_processor():` idiom.
        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Overwrite every position whose attention mask is not 1 (i.e. the
        # padding) with -100, the label value Wav2Vec2ForCTC documents as ignored.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
|
|
|
# Batch collator built on the shared processor, with padding enabled.
data_collator = DataCollatorCTCWithPadding(
    padding=True,
    processor=processor,
)

import evaluate

# The "wer" metric, used both during training and for the final test score.
wer_metric = evaluate.load("wer")
|
|
|
def compute_metrics(pred):
    """Compute the word error rate for a set of predictions.

    Args:
        pred: Object exposing ``predictions`` (logits) and ``label_ids``
            arrays.

    Returns:
        ``{"wer": <value returned by wer_metric.compute>}``.
    """
    pred_logits = pred.predictions
    # Pick the highest-scoring token id along the last axis.
    pred_ids = np.argmax(pred_logits, axis=-1)

    # -100 entries cannot be decoded, so substitute the pad token id.  This
    # is done on a fresh array so the caller's `pred.label_ids` is left
    # untouched (the previous in-place assignment modified it).
    label_ids = np.where(
        pred.label_ids == -100,
        processor.tokenizer.pad_token_id,
        pred.label_ids,
    )

    pred_str = processor.batch_decode(pred_ids)
    # group_tokens=False: reference ids are decoded without merging repeats.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
|
|
|
from transformers import Wav2Vec2ForCTC

# Start from the pretrained base checkpoint, with the vocabulary size and pad
# token id taken from the vocabulary and tokenizer built in this script.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    vocab_size=len(vocab_dict),
    pad_token_id=processor.tokenizer.pad_token_id,
    ctc_loss_reduction="mean",
)

# Freeze the feature encoder, then turn on gradient checkpointing.
model.freeze_feature_encoder()
model.gradient_checkpointing_enable()
|
|
|
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Where checkpoints are written, and how many to keep.
    output_dir='wav2vec2-mal',
    save_total_limit=2,
    # Batching.
    per_device_train_batch_size=24,
    group_by_length=True,
    # Optimisation schedule.
    num_train_epochs=30,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    fp16=True,
    # Evaluate, save and log on the same 500-step cadence.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=500,
)
|
|
|
from transformers import Trainer

# Wire the model, data, collator and metric function into one Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

# Run the training loop.
trainer.train()
|
|
|
def map_to_result(batch):
    """Transcribe one featurised example with the trained model.

    Args:
        batch: A single example providing ``"input_values"`` and
            ``"labels"``.

    Returns:
        The same ``batch`` with ``"pred_str"`` (decoded model prediction) and
        ``"text"`` (reference decoded from ``"labels"``) added.
    """
    # Make sure train-only behaviour such as dropout is switched off; the
    # model may still be in training mode after `trainer.train()`.
    model.eval()

    with torch.no_grad():
        # Place the input on whatever device the model lives on instead of
        # assuming "cuda", so CPU-only runs work too.
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    # group_tokens=False: reference ids are decoded without merging repeats.
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch
|
|
|
# Transcribe the whole test split, keeping only the columns map_to_result adds.
results = dataset_dict["test"].map(
    map_to_result,
    remove_columns=dataset_dict["test"].column_names,
)

# Score the predictions against the references and report the result.
test_wer = wer_metric.compute(
    predictions=results["pred_str"],
    references=results["text"],
)
print("Test WER: {:.3f}".format(test_wer))