# File size: 6,359 bytes
# Revision: c6b1960
from huggingface_hub import interpreter_login
from datasets import load_dataset, DatasetDict, load_from_disk
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import EarlyStoppingCallback
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import prepare_model_for_int8_training
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import os
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that writes the model's adapter into each checkpoint folder."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter under the current checkpoint and drop the full weights file.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Trainer control object, returned unchanged.
            **kwargs: Must contain ``"model"``, an object with ``save_pretrained``.

        Returns:
            The ``control`` object that was passed in.
        """
        # The checkpoint folder is named "<prefix>-<global step>" inside the output directory.
        checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = os.path.join(args.output_dir, checkpoint_name)

        # Store the adapter in its own sub-folder of the checkpoint.
        kwargs["model"].save_pretrained(os.path.join(checkpoint_dir, "adapter_model"))

        # Remove the full weights file from the checkpoint if one was written.
        full_weights_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)

        return control
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate pre-processed speech examples into a single padded seq2seq batch.

    Audio features and label token ids are padded separately because they have
    different lengths and need different padding methods.
    """

    # Object exposing ``feature_extractor.pad``, ``tokenizer.pad`` and
    # ``tokenizer.bos_token_id`` (a WhisperProcessor in this script).
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build one training batch from a list of examples.

        Args:
            features: Examples, each with ``"input_features"`` and ``"labels"``.
                Examples may additionally carry ``"attention_mask"`` (the
                augmented-data case); it is then forwarded in the batch.

        Returns:
            The padded batch with ``"labels"`` added, and ``"attention_mask"``
            when the examples provide one.

        Raises:
            KeyError: If only some of the examples carry ``"attention_mask"``.
        """
        # First treat the audio inputs by simply returning torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Only augmented data carries an attention mask. Attach it when present
        # instead of requiring the line to be deleted by hand for other datasets.
        # A batch where only some examples have a mask still fails loudly.
        if any("attention_mask" in feature for feature in features):
            batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        # Get the tokenized label sequences and pad them to the longest one.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If a bos token was prepended in the previous tokenization step,
        # cut it here since it is appended again later anyway.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
def compute_metrics(pred):
    """Compute the word error rate, as a percentage, for one evaluation pass.

    Relies on the module-level ``processor`` and ``metric`` objects.

    Args:
        pred: Object exposing ``predictions`` and ``label_ids``. ``label_ids``
            is modified in place: every -100 entry is overwritten with the
            tokenizer's pad token id.

    Returns:
        A dict with the single key ``"wer"``.
    """
    tokenizer = processor.tokenizer
    predicted_ids, reference_ids = pred.predictions, pred.label_ids

    # -100 marks ignored label positions; swap in the pad token id so they decode.
    reference_ids[reference_ids == -100] = tokenizer.pad_token_id

    # We do not want to group tokens when computing the metrics.
    hypotheses = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    references = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    return {"wer": 100 * metric.compute(predictions=hypotheses, references=references)}
if __name__ == "__main__":
    # Script entry point: fine-tune openai/whisper-medium with LoRA on the
    # dataset stored at ./vin_clean, then publish the result to the Hub.
    import inspect

    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=3,  # Stop training if the metric doesn't improve for 3 evaluations
        early_stopping_threshold=0.0005,  # Minimum change in the metric to be considered an improvement
    )

    # Load the pre-processed dataset
    processed_dataset = load_from_disk("./vin_clean")
    print(processed_dataset)

    # Load processor (module-level on purpose: compute_metrics reads it)
    processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")

    # Initialize data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Download metric (module-level on purpose: compute_metrics reads it)
    metric = evaluate.load("wer")

    # Download model in 8bit
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Prepare the model with PEFT.
    # The keyword naming the output projection is `output_embedding_layer_name`,
    # and only some PEFT releases accept it, so pass it only when the installed
    # helper's signature has that parameter.
    prepare_kwargs = {}
    if "output_embedding_layer_name" in inspect.signature(prepare_model_for_int8_training).parameters:
        prepare_kwargs["output_embedding_layer_name"] = "proj_out"
    model = prepare_model_for_int8_training(model, **prepare_kwargs)

    config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    # Define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-medium-Lora",  # change to a repo name of your choice
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
        learning_rate=5e-5,
        warmup_steps=500,
        max_steps=10000,
        evaluation_strategy="steps",
        gradient_checkpointing=True,
        optim="adamw_torch",
        fp16=True,
        per_device_eval_batch_size=8,
        generation_max_length=225,
        save_steps=2000,  # must stay a multiple of eval_steps for load_best_model_at_end
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        predict_with_generate=True,
        # EarlyStoppingCallback refuses to start unless this is enabled.
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
        remove_unused_columns=False,
        label_names=["labels"],  # same reason as above
        push_to_hub=False,
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        data_collator=data_collator,
        # Without compute_metrics no "wer" value is produced, so neither early
        # stopping nor best-model tracking would have a metric to look at.
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
        callbacks=[early_stopping_callback, SavePeftModelCallback()],
    )

    # Start training
    trainer.train()

    # Set up model-card metadata and push to the Hub
    kwargs = {
        "dataset": "vin100h",
        "language": "vi",
        "model_name": "Whisper Medium LoRA - Clean Data",
        "finetuned_from": "openai/whisper-medium",
        "tasks": "automatic-speech-recognition",
    }
    # These keys are model-card fields understood by Trainer.push_to_hub;
    # model.push_to_hub expects a repo id and would fail on them.
    trainer.push_to_hub(**kwargs)