# Vietnamese_ASR/src/training.py
from huggingface_hub import interpreter_login
from datasets import load_dataset, DatasetDict, load_from_disk
from transformers import (
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import prepare_model_for_int8_training
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import os
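

# Callback that, at every checkpoint, saves only the LoRA adapter weights under
# "adapter_model/" and deletes the full pytorch_model.bin written by the Trainer,
# so checkpoints stay small.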
class SavePeftModelCallback(TrainerCallback):
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path)
pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
if os.path.exists(pytorch_model_path):
os.remove(pytorch_model_path)
return control
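

# Collator that pads the log-mel input features and the tokenized labels
# separately (they need different padding) and masks label padding with -100
# so it is ignored by the loss.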
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
        # NOTE: only needed when each example already stores an "attention_mask" (e.g. augmented data); remove otherwise
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if a bos token was appended in the previous tokenization step,
        # cut it here as it is appended again later anyway
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
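

# Computes word error rate (WER), reported as a percentage. `processor` and
# `metric` are module-level globals defined in the __main__ block below.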
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
if __name__ == "__main__":
early_stopping_callback = EarlyStoppingCallback(
early_stopping_patience=3, # Stop training if the metric doesn't improve for 3 evaluations
early_stopping_threshold=0.0005, # Minimum change in the metric to be considered an improvement
)
# Load Dataset
processed_dataset = DatasetDict()
processed_dataset = load_from_disk("./vin_clean")
print(processed_dataset)
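    # the dataset saved on disk is expected to expose "train" and "test" splits,
    # which are used for training and evaluation below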
# load processor
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")
    # initialize data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# download metric
metric = evaluate.load("wer")
# Download model in 8bit
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# preparing model with PEFT
    model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
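    # LoRA adapters on the attention query/value projections; with r=32 and
    # lora_alpha=64 the adapter update is scaled by lora_alpha / r = 2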
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)
model.print_trainable_parameters()
    # Define training arguments
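    # per_device_train_batch_size=32 with gradient_accumulation_steps=2 gives
    # an effective batch size of 64 per device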
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-medium-Lora", # change to a repo name of your choice
per_device_train_batch_size=32,
gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size
learning_rate=5e-5,
warmup_steps=500,
max_steps=10000,
evaluation_strategy="steps",
gradient_checkpointing=True,
optim="adamw_torch",
fp16=True,
per_device_eval_batch_size=8,
generation_max_length=225,
save_steps=2000,
eval_steps=500,
logging_steps=25,
report_to=["tensorboard"],
predict_with_generate=True,
        load_best_model_at_end=True,  # required by EarlyStoppingCallback
metric_for_best_model="wer",
greater_is_better=False,
# required as the PeftModel forward doesn't have the signature of the wrapped model's forward
remove_unused_columns=False,
label_names=["labels"], # same reason as above
push_to_hub=False,
)
# initialize trainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=processed_dataset["train"],
eval_dataset=processed_dataset["test"],
data_collator=data_collator,
tokenizer=processor.feature_extractor,
callbacks=[early_stopping_callback, SavePeftModelCallback],
)
# start training
trainer.train()
# set up args and push to hub
kwargs = {
"dataset": "vin100h",
"language": "vi",
"model_name": "Whisper Medium LoRA - Clean Data",
"finetuned_from": "openai/whisper-medium",
"tasks": "automatic-speech-recognition",
}
    trainer.push_to_hub(**kwargs)
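
    # Illustrative sketch (not part of the original script): reloading a saved
    # adapter for inference. The checkpoint path is hypothetical and depends on
    # which checkpoint step was kept; PeftModel is already imported above.
    #
    # base_model = WhisperForConditionalGeneration.from_pretrained(
    #     "openai/whisper-medium", load_in_8bit=True, device_map="auto"
    # )
    # inference_model = PeftModel.from_pretrained(
    #     base_model, "./whisper-medium-Lora/checkpoint-2000/adapter_model"
    # )
    # inference_model.eval()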