from huggingface_hub import interpreter_login
from datasets import load_dataset, DatasetDict, load_from_disk
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import prepare_model_for_int8_training
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import os


class SavePeftModelCallback(TrainerCallback):
    """Save only the LoRA adapter weights at each checkpoint and drop the full state dict."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # NOTE: only needed when the dataset is augmented and carries an attention_mask; remove otherwise
        batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


def compute_metrics(pred):
    # `processor` and `metric` are module-level globals created in the __main__ block below
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


if __name__ == "__main__":
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=3,        # stop training if the metric doesn't improve for 3 evaluations
        early_stopping_threshold=0.0005,  # minimum change in the metric to count as an improvement
    )

    # Load dataset
    processed_dataset = load_from_disk("./vin_clean")
    print(processed_dataset)

    # Load processor
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-medium",
        language="Vietnamese",
        task="transcribe",
    )

    # Initialize data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Download metric
    metric = evaluate.load("wer")

    # Download model in 8-bit
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-medium", load_in_8bit=True, device_map="auto"
    )
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Prepare the model for int8 / PEFT training
    # (newer PEFT versions replace this with prepare_model_for_kbit_training and drop this argument)
    model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    # Define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-medium-Lora",  # change to a repo name of your choice
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
        learning_rate=5e-5,
        warmup_steps=500,
        max_steps=10000,
        evaluation_strategy="steps",
        gradient_checkpointing=True,
        optim="adamw_torch",
        fp16=True,
        per_device_eval_batch_size=8,
        generation_max_length=225,
        save_steps=2000,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        predict_with_generate=True,
        load_best_model_at_end=True,  # required by EarlyStoppingCallback
        metric_for_best_model="wer",
        greater_is_better=False,
        # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
        remove_unused_columns=False,
        label_names=["labels"],  # same reason as above
        push_to_hub=False,
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,  # needed so eval reports "wer" for best-model tracking and early stopping
        tokenizer=processor.feature_extractor,
        callbacks=[early_stopping_callback, SavePeftModelCallback],
    )

    # Start training
    trainer.train()

    # Set up model card kwargs and push to the Hub (requires prior Hub authentication)
    kwargs = {
        "dataset": "vin100h",
        "language": "vi",
        "model_name": "Whisper Medium LoRA - Clean Data",
        "finetuned_from": "openai/whisper-medium",
        "tasks": "automatic-speech-recognition",
    }
    trainer.push_to_hub(**kwargs)
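
# --- Reloading the saved adapter for inference (illustrative sketch) ---
# SavePeftModelCallback stores only the LoRA weights under
# "<output_dir>/checkpoint-<step>/adapter_model". Assuming such a checkpoint
# exists (the step number below is hypothetical), the adapter can be
# re-attached to the base model with PeftModel.from_pretrained, e.g.:
#
#   base_model = WhisperForConditionalGeneration.from_pretrained(
#       "openai/whisper-medium", load_in_8bit=True, device_map="auto"
#   )
#   inference_model = PeftModel.from_pretrained(
#       base_model, "./whisper-medium-Lora/checkpoint-10000/adapter_model"
#   )
#   inference_model.eval()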