|
from huggingface_hub import interpreter_login |
|
|
|
from datasets import load_dataset, DatasetDict, load_from_disk |
|
|
|
from transformers import WhisperProcessor |
|
from transformers import WhisperForConditionalGeneration |
|
from transformers import Seq2SeqTrainingArguments |
|
from transformers import Seq2SeqTrainer |
|
from transformers import EarlyStoppingCallback |
|
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl |
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR |
|
|
|
from peft import prepare_model_for_int8_training |
|
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model |
|
|
|
import torch |
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Union |
|
|
|
import evaluate |
|
|
|
import os |
|
|
|
class SavePeftModelCallback(TrainerCallback):
    """Keep only the PEFT adapter weights in each Trainer checkpoint.

    On every ``on_save`` event, the adapter of the model handed to the callback is written
    to ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``. Any
    ``pytorch_model.bin`` found in that same checkpoint folder is then deleted.
    """

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        step_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        # Persist only the adapter; the model arrives via the callback kwargs.
        adapter_dir = os.path.join(step_dir, "adapter_model")
        peft_model = kwargs["model"]
        peft_model.save_pretrained(adapter_dir)

        # Drop the full weight dump, when present, to save disk space.
        full_weights_file = os.path.join(step_dir, "pytorch_model.bin")
        if not os.path.exists(full_weights_file):
            return control
        os.remove(full_weights_file)
        return control
|
|
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of pre-processed speech examples into one seq2seq training batch.

    Each feature must provide ``input_features`` and ``labels`` (token ids). If every
    feature also provides ``attention_mask``, the masks are stacked into the batch.

    Padded label positions are set to ``-100`` so they are ignored by the loss.
    A leading token shared by all label rows is dropped when it is the tokenizer's
    bos token or the decoder start token. For HF seq2seq models trained from
    ``labels``, the decoder start token is prepended when labels are shifted into
    decoder inputs, so keeping it in the labels would duplicate it.
    """

    # Object exposing ``feature_extractor.pad`` and ``tokenizer.pad`` (a WhisperProcessor in this script).
    processor: Any
    # Decoder start token id to strip from the front of the labels. When None it is looked up as
    # "<|startoftranscript|>" in the tokenizer (Whisper's start token). ``bos_token_id`` is always checked too.
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Input side: let the feature extractor pad/stack the audio features into tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Forward a precomputed attention mask only when the dataset actually provides one.
        if all("attention_mask" in feature for feature in features):
            batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        # Label side: pad token ids, then hide padding from the loss with -100.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

        # Drop a leading start token that every row shares (labels are right-padded, so column 0 is real).
        if labels.shape[1] > 0 and any(
            bool((labels[:, 0] == token_id).all()) for token_id in self._leading_token_ids()
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

    def _leading_token_ids(self) -> List[int]:
        """Return the token ids that count as a strippable leading start token."""
        tokenizer = self.processor.tokenizer
        candidates = [getattr(tokenizer, "bos_token_id", None)]

        start_id = self.decoder_start_token_id
        if start_id is None and hasattr(tokenizer, "convert_tokens_to_ids"):
            looked_up = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
            # An unknown token maps to the unk id; ignore that so non-Whisper tokenizers are unaffected.
            if looked_up is not None and looked_up != getattr(tokenizer, "unk_token_id", None):
                start_id = looked_up
        candidates.append(start_id)

        return [token_id for token_id in candidates if token_id is not None]
|
|
|
def compute_metrics(pred, *, tokenizer=None, wer_metric=None):
    """Compute the word error rate (as a percentage) for a seq2seq evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and ``label_ids``
            (reference token ids). ``-100`` marks ignored/padded positions; the collator
            in this file uses it for label padding, and padded predictions may carry it too.
        tokenizer: Tokenizer used to decode ids. Defaults to the module-level
            ``processor.tokenizer`` created in the ``__main__`` block.
        wer_metric: Object with a ``compute(predictions=..., references=...)`` method.
            Defaults to the module-level ``metric`` created in the ``__main__`` block.

    Returns:
        ``{"wer": 100 * wer_metric.compute(...)}``.

    The arrays held by ``pred`` are not modified.
    """
    import numpy as np  # local import: numpy is not imported at the top of this file

    if tokenizer is None:
        tokenizer = processor.tokenizer
    if wer_metric is None:
        wer_metric = metric

    pad_id = tokenizer.pad_token_id

    # -100 is not a decodable id; replace it with the pad token on copies (np.where allocates a new array).
    pred_array = np.asarray(pred.predictions)
    label_array = np.asarray(pred.label_ids)
    pred_ids = np.where(pred_array == -100, pad_id, pred_array)
    label_ids = np.where(label_array == -100, pad_id, label_array)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Stop training once the monitored eval metric ("wer", see training_args below) has not
    # improved by more than the threshold for 3 consecutive evaluations.
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=3,
        early_stopping_threshold=0.0005,
    )

    # Dataset previously written with `save_to_disk`; its "train" and "test" splits are used below.
    processed_dataset = load_from_disk("./vin_clean")
    print(processed_dataset)

    # `processor` and `metric` stay module-level names because `compute_metrics` reads them.
    processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
    metric = evaluate.load("wer")

    # 8-bit Whisper; clear forced decoder ids and suppressed tokens in the model config.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # The previous keyword `output_imbedding_layer` is not accepted by this helper (TypeError).
    # Older peft releases take `output_embedding_layer_name` for the output head (`proj_out` in Whisper);
    # newer releases removed that parameter, so fall back to the plain call.
    try:
        model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
    except TypeError:
        model = prepare_model_for_int8_training(model)

    # LoRA adapters (r=32, alpha=64) on modules named q_proj / v_proj.
    config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-medium-Lora",
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,
        learning_rate=5e-5,
        warmup_steps=500,
        max_steps=10000,
        evaluation_strategy="steps",
        gradient_checkpointing=True,
        optim="adamw_torch",
        fp16=True,
        per_device_eval_batch_size=8,
        generation_max_length=225,
        save_steps=2000,  # must stay a multiple of eval_steps for load_best_model_at_end
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        predict_with_generate=True,
        # EarlyStoppingCallback requires load_best_model_at_end and metric_for_best_model;
        # "wer" is produced by compute_metrics (lower is better).
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        # Keep every dataset column for the custom collator and name the label column explicitly
        # for the PEFT-wrapped model.
        remove_unused_columns=False,
        label_names=["labels"],
        push_to_hub=False,
    )

    # NOTE(review): SavePeftModelCallback deletes pytorch_model.bin from each checkpoint, so depending on
    # the transformers version the end-of-training "best model" reload may not find full weights —
    # verify which weights `model` holds after training.
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
        callbacks=[early_stopping_callback, SavePeftModelCallback()],
    )

    trainer.train()

    # Model-card metadata: these keys are `Trainer.create_model_card` arguments, which
    # `trainer.push_to_hub` forwards. (PeftModel.push_to_hub expects a repo id instead.)
    kwargs = {
        "dataset": "vin100h",
        "language": "vi",
        "model_name": "Whisper Medium LoRA - Clean Data",
        "finetuned_from": "openai/whisper-medium",
        "tasks": "automatic-speech-recognition",
    }

    # Write the LoRA adapter into output_dir so it is part of the upload, then push.
    # NOTE: pushing requires Hugging Face Hub credentials (e.g. `huggingface-cli login`).
    model.save_pretrained(training_args.output_dir)
    trainer.push_to_hub(**kwargs)