# wav2vec2-base-mal / trainer.py
# Author: aoxo
# Commit: "Create trainer.py" (57cf122, verified)
import os
import random
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
from datasets.features import Audio
import pandas as pd
import numpy as np
from tqdm import tqdm
from IPython.display import display, HTML
# Function to load your custom dataset
def load_custom_dataset(data_dir):
    """Build a Hugging Face ``Dataset`` of (audio path, transcription) pairs.

    Expects ``data_dir/wav/<name>.wav`` files, each with a matching UTF-8
    transcript at ``data_dir/transcription/<name>.txt``.

    Args:
        data_dir: Directory containing the ``wav`` and ``transcription`` folders.

    Returns:
        A ``Dataset`` with an ``audio`` column cast to
        ``Audio(sampling_rate=16_000)`` and a whitespace-stripped ``text`` column.
        Rows are ordered by wav filename.

    Raises:
        FileNotFoundError: if a ``.wav`` file has no matching transcript.
    """
    data = {
        "audio": [],
        "text": []
    }
    wav_dir = os.path.join(data_dir, 'wav')
    txt_dir = os.path.join(data_dir, 'transcription')
    # Filenames in 'wav' and 'transcription' are assumed to match (stem-wise).
    # os.listdir order is arbitrary; sorting makes row order deterministic.
    for wav_file in sorted(os.listdir(wav_dir)):
        if not wav_file.endswith('.wav'):
            continue
        # Swap only the extension; str.replace would also rewrite a '.wav'
        # occurring elsewhere in the filename.
        txt_file = os.path.splitext(wav_file)[0] + '.txt'
        wav_path = os.path.join(wav_dir, wav_file)
        txt_path = os.path.join(txt_dir, txt_file)
        # Read the transcription text
        with open(txt_path, 'r', encoding='utf-8') as f:
            transcription = f.read().strip()
        data["audio"].append(wav_path)
        data["text"].append(transcription)
    # Create a pandas dataframe, then convert to a Hugging Face dataset
    df = pd.DataFrame(data)
    dataset = Dataset.from_pandas(df)
    # Define the audio feature (for .wav files)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))  # Adjust the sampling rate if needed
    return dataset
custom_train_dataset = load_custom_dataset("./")
# Combine them into a DatasetDict
dataset_dict = DatasetDict({
    "train": custom_train_dataset,
})
# Hold out a fixed number of random utterances as the test split.
num_test_samples = 975
train_size = len(dataset_dict["train"])
if train_size < num_test_samples:
    raise ValueError(
        f"Need at least {num_test_samples} examples to build the test split, found {train_size}."
    )
sample_indices = random.sample(range(train_size), num_test_samples)
test_samples = dataset_dict["train"].select(sample_indices)
# Keep the remaining rows in ascending index order. A set gives O(1)
# membership tests (a list scan is O(n) per row), and select() by index avoids
# running a Python callback over every example, which may decode its audio.
held_out = set(sample_indices)
remaining_train_samples = dataset_dict["train"].select(
    [idx for idx in range(train_size) if idx not in held_out]
)
dataset_dict["test"] = test_samples
dataset_dict["train"] = remaining_train_samples
print(dataset_dict)
def show_random_elements(dataset, num_examples=10):
    """Display ``num_examples`` distinct, randomly chosen rows of ``dataset``.

    Args:
        dataset: An indexable dataset (e.g. a Hugging Face ``Dataset``); indexing
            it with a list of row indices must return a column -> values mapping.
        num_examples: Number of distinct rows to show.

    Returns:
        The ``pandas.DataFrame`` that was displayed.

    Raises:
        ValueError: if ``num_examples`` exceeds the number of rows.
    """
    if num_examples > len(dataset):
        raise ValueError("Can't pick more elements than there are in the dataset.")
    # random.sample draws distinct indices directly (no retry-until-unique loop).
    picks = random.sample(range(len(dataset)), num_examples)
    df = pd.DataFrame(dataset[picks])
    # Actually show the frame: rendered as a table in a notebook; outside an
    # IPython session display() falls back to printing it.
    display(df)
    return df


show_random_elements(dataset_dict["train"])
import re
# Punctuation stripped from transcriptions before building the vocabulary.
# Raw string: in a normal literal, sequences like "\," and "\?" are invalid
# escapes (SyntaxWarning on Python 3.12+); the regex itself is unchanged.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'
def remove_special_characters(batch):
    """Strip ignored punctuation from ``batch["text"]`` and lower-case it.

    Args:
        batch: A single example mapping with a string ``"text"`` entry.

    Returns:
        The same mapping, with ``"text"`` updated in place.
    """
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch
# Normalize every split's transcriptions, then preview a few training rows.
dataset_dict = dataset_dict.map(function=remove_special_characters)
show_random_elements(dataset=dataset_dict["train"])
def extract_all_chars(batch):
    """Summarise a batch of transcriptions as one joined string and its characters.

    Args:
        batch: A batched mapping whose ``"text"`` entry is a list of strings.

    Returns:
        ``{"vocab": [chars], "all_text": [joined]}`` where ``joined`` is the
        texts joined by single spaces and ``chars`` lists its distinct
        characters (in no particular order). Each value is a one-element list,
        so a batched map emits exactly one row per batch.
    """
    joined = " ".join(batch["text"])
    distinct_chars = list(set(joined))
    return {"vocab": [distinct_chars], "all_text": [joined]}
# With batch_size=-1 each split is handed to extract_all_chars as a single
# batch, yielding one row of distinct characters per split.
vocabs = dataset_dict.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=dataset_dict.column_names["train"])
# Union over *all* splits: characters that only occur in held-out test
# utterances would otherwise be encoded as [UNK] in the test labels.
vocab_chars = set()
for split_vocabs in vocabs.values():
    for row_vocab in split_vocabs["vocab"]:
        vocab_chars.update(row_vocab)
# The tokenizer is built with word_delimiter_token="|", and
# Wav2Vec2CTCTokenizer rewrites " " to that token before vocab lookup, so the
# vocabulary must contain "|" instead of " ". Otherwise every word boundary
# encodes to [UNK] and decoded text never contains spaces.
vocab_chars.discard(" ")
vocab_chars.add("|")
# Sort so token ids do not depend on set iteration order, which varies
# between processes because string hashing is randomized.
vocab_list = sorted(vocab_chars)
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)
# Special tokens take the next free ids.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))
import json
with open('vocab.json', 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab_dict, vocab_file)
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

# Character-level CTC tokenizer backed by ./vocab.json.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
    vocab_size=len(vocab_dict),
)
# Raw mono waveform at 16 kHz, normalized, padded with 0.0, no attention mask.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)
# Bundle both so one object handles audio inputs and text labels.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)
# Sanity-check one random training example. randrange excludes the upper
# bound; randint(0, len) is inclusive and could index one past the end.
rand_int = random.randrange(len(dataset_dict["train"]))
# Fetch (and decode) the row once instead of once per printed field.
random_example = dataset_dict["train"][rand_int]
print("Target text:", random_example["text"])
print("Input array shape:", np.asarray(random_example["audio"]["array"]).shape)
print("Sampling rate:", random_example["audio"]["sampling_rate"])
def prepare_dataset(batch):
    """Turn one example into model inputs and CTC label ids.

    Args:
        batch: A single (non-batched) example with an ``"audio"`` entry
            (``"array"`` and ``"sampling_rate"``) and a ``"text"`` entry.

    Returns:
        The same mapping with ``"input_values"`` (feature-extracted audio) and
        ``"labels"`` (tokenizer ids for the transcription) added.
    """
    audio = batch["audio"]
    # The feature extractor returns a batch of one; [0] un-batches it so the
    # mapping stays one row per example.
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    # Passing text= routes to the tokenizer, replacing the deprecated
    # processor.as_target_processor() context manager.
    batch["labels"] = processor(text=batch["text"]).input_ids
    return batch


dataset_dict = dataset_dict.map(prepare_dataset, remove_columns=dataset_dict.column_names["train"], num_proc=None)
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    ``input_values`` are padded with ``processor.feature_extractor`` and
    ``labels`` with ``processor.tokenizer``; padded label positions are then set
    to -100 so they are ignored by the loss.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as ``pad_to_multiple_of``, applied to the ``labels``.
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into a PyTorch batch.

        Args:
            features: Examples each carrying ``"input_values"`` and ``"labels"``.

        Returns:
            The padded audio batch with an added ``"labels"`` tensor in which
            padding positions are -100.
        """
        # Inputs and labels have different lengths and need different padding
        # methods, so they are split and padded separately.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly; this replaces the deprecated
        # processor.as_target_processor() context manager.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch
# Collator that pads audio inputs and CTC labels separately for each batch.
data_collator = DataCollatorCTCWithPadding(padding=True, processor=processor)

import evaluate

# Word error rate metric, used by compute_metrics and for the final test report.
wer_metric = evaluate.load(path="wer")
def compute_metrics(pred):
    """Compute word error rate (WER) for a Trainer evaluation pass.

    Args:
        pred: Evaluation output exposing ``predictions`` (CTC logits; argmax is
            taken over the last axis, presumably (batch, frames, vocab)) and
            ``label_ids`` (label ids with -100 marking ignored positions).

    Returns:
        ``{"wer": float}``.
    """
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pad_id = processor.tokenizer.pad_token_id
    # When eval batches are concatenated, shorter logit sequences are padded
    # with -100 across the whole vocab. argmax on such a frame returns id 0, a
    # real character, so map those frames to [PAD] (the CTC blank), which
    # decoding drops.
    padded_frames = np.all(pred_logits == -100, axis=-1)
    pred_ids = np.where(padded_frames, pad_id, pred_ids)
    # Build a new array instead of mutating pred.label_ids in place.
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
from transformers import Wav2Vec2ForCTC

# Start from the pretrained wav2vec2-base checkpoint with a CTC head sized to
# the character vocabulary; the tokenizer's [PAD] id becomes the model's
# pad_token_id.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    vocab_size=len(vocab_dict),
    pad_token_id=processor.tokenizer.pad_token_id,
    ctc_loss_reduction="mean",
)
# Keep the feature encoder's weights fixed during fine-tuning.
model.freeze_feature_encoder()
# Save memory by recomputing activations during the backward pass.
model.gradient_checkpointing_enable()
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='wav2vec2-mal',
    # Group samples of similar length into batches to reduce padding.
    group_by_length=True,
    per_device_train_batch_size=24,
    eval_strategy="steps",
    num_train_epochs=30,
    # fp16 mixed precision needs a GPU (TrainingArguments rejects it on a
    # CPU-only machine), so only enable it when CUDA is available.
    fp16=torch.cuda.is_available(),
    # Gradient checkpointing is enabled directly with
    # model.gradient_checkpointing_enable(), so it is not repeated here.
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    # Cap the number of checkpoints kept on disk.
    save_total_limit=2,
)
from transformers import Trainer

# Wire model, arguments, data, collator and metric together, then fine-tune.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)
trainer.train()
def map_to_result(batch):
    """Greedy-decode one test example and decode its reference transcription.

    Args:
        batch: A single (non-batched) example with ``"input_values"`` and
            ``"labels"`` as produced by ``prepare_dataset``.

    Returns:
        The mapping with ``"pred_str"`` (model transcription) and ``"text"``
        (reference transcription decoded without token grouping) set.
    """
    with torch.no_grad():
        # Put the input on whatever device the model lives on instead of
        # hard-coding "cuda"; unsqueeze adds the batch dimension.
        input_values = torch.tensor(batch["input_values"], device=model.device).unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)
    return batch


# The model may still be in training mode after trainer.train(); switch to
# eval mode so dropout is disabled for the final test decoding.
model.eval()
results = dataset_dict["test"].map(map_to_result, remove_columns=dataset_dict["test"].column_names)
test_wer = wer_metric.compute(predictions=results["pred_str"], references=results["text"])
print("Test WER: {:.3f}".format(test_wer))