from huggingface_hub import interpreter_login

from datasets import load_from_disk

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from peft import prepare_model_for_int8_training
from peft import LoraConfig, get_peft_model

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate

import os

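# Callback that keeps checkpoints small: it saves only the LoRA adapter weights
# under "adapter_model/" and removes the full pytorch_model.bin written by the
# Trainer, since the frozen base Whisper weights don't need to be duplicated.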
class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control

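# Pads the log-mel input features and the tokenized labels separately, since the
# audio inputs and the text labels require different padding strategies.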
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # NOTE: only needed when the preprocessed dataset stores an attention_mask
        # (e.g. for augmented data); remove this line if your features don't include it
        batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so that padded positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in a previous tokenization step,
        # cut it here as it gets appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

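# Computes the word error rate (WER) on generated predictions; the Trainer calls
# this after each evaluation, and the result drives best-model tracking and early stopping.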
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}



if __name__ == "__main__":
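    # Authenticate with the Hugging Face Hub so the adapter can be pushed after
    # training; skip this if you are already logged in via `huggingface-cli login`.
    interpreter_login()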


    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=3,  # Stop training if the metric doesn't improve for 3 evaluations
        early_stopping_threshold=0.0005,  # Minimum change in the metric to be considered an improvement
    )

    # Load the preprocessed dataset from disk
    processed_dataset = load_from_disk("./vin_clean")
    print(processed_dataset)

    # load processor
    processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")


    # initialize data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # download metric
    metric = evaluate.load("wer")

    # Download model in 8bit
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # prepare the 8-bit model for training (upcasts layer norms to fp32, enables
    # gradient checkpointing, and casts the proj_out output back to fp32)
    model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

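    # LoRA adapters of rank 32 (alpha 64) on the attention query/value projections;
    # only these low-rank matrices are trained while the base model stays frozen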
    config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")

    model = get_peft_model(model, config)
    model.print_trainable_parameters()


    # Define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-medium-Lora",  # change to a repo name of your choice
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
        learning_rate=5e-5,
        warmup_steps=500,
        max_steps=10000,
        evaluation_strategy="steps",
        gradient_checkpointing=True,
        optim="adamw_torch",
        fp16=True,
        per_device_eval_batch_size=8,
        generation_max_length=225,
        save_steps=2000,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        predict_with_generate=True,
        load_best_model_at_end=True,  # required by EarlyStoppingCallback; the adapter itself is saved by SavePeftModelCallback
        metric_for_best_model="wer",
        greater_is_better=False,
        # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
        remove_unused_columns=False,
        label_names=["labels"],  # same reason as above
        push_to_hub=False,
    )

    # initialize trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
        callbacks=[early_stopping_callback, SavePeftModelCallback()],
    )


    # start training
    trainer.train()


    # set up model card metadata and push the trained adapter to the Hub
    kwargs = {
        "dataset": "vin100h",
        "language": "vi",
        "model_name": "Whisper Medium LoRA - Clean Data",
        "finetuned_from": "openai/whisper-medium",
        "tasks": "automatic-speech-recognition",
    }

    trainer.push_to_hub(**kwargs)