import argparse
import functools
import os
import platform

import torch
from peft import LoraConfig, get_peft_model, AdaLoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperForConditionalGeneration, WhisperProcessor

from utils.callback import SavePeftModelCallback
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
from utils.model_utils import load_from_checkpoint
from utils.reader import CustomDataset
from utils.utils import print_arguments, make_inputs_require_grad, add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("train_data", type=str, default="dataset/train.json", help="Path to the training data list (JSON)")
add_arg("test_data", type=str, default="dataset/test.json", help="Path to the evaluation data list (JSON)")
add_arg("base_model", type=str, default="openai/whisper-tiny", help="Base Whisper model to fine-tune")
add_arg("output_dir", type=str, default="output/", help="Directory to save checkpoints and the final model")
add_arg("warmup_steps", type=int, default=50, help="Number of learning-rate warm-up steps")
add_arg("logging_steps", type=int, default=100, help="Number of steps between log outputs")
add_arg("eval_steps", type=int, default=1000, help="Number of steps between evaluations")
add_arg("save_steps", type=int, default=1000, help="Number of steps between checkpoint saves")
add_arg("num_workers", type=int, default=8, help="Number of dataloader workers")
add_arg("learning_rate", type=float, default=1e-3, help="Learning rate")
add_arg("min_audio_len", type=float, default=0.5, help="Minimum audio length in seconds")
add_arg("max_audio_len", type=float, default=30, help="Maximum audio length in seconds")
add_arg("use_adalora", type=bool, default=True, help="Use AdaLoRA instead of plain LoRA")
add_arg("fp16", type=bool, default=True, help="Train with fp16 mixed precision")
add_arg("use_8bit", type=bool, default=False, help="Load the base model in 8-bit")
add_arg("timestamps", type=bool, default=False, help="Use timestamp data during training")
add_arg("local_files_only", type=bool, default=False, help="Only load models from local files")
add_arg("num_train_epochs", type=int, default=3, help="Number of training epochs")
add_arg("language", type=str, default="bn", help="Language code for fine-tuning")
add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="Task for the model")
add_arg("augment_config_path", type=str, default=None, help="Path to the data augmentation config file")
add_arg("resume_from_checkpoint", type=str, default=None, help="Checkpoint path to resume training from")
add_arg("per_device_train_batch_size", type=int, default=8, help="Training batch size per device")
add_arg("per_device_eval_batch_size", type=int, default=8, help="Evaluation batch size per device")
add_arg("gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
args = parser.parse_args()
print_arguments(args)

# Get the Whisper processor (feature extractor + tokenizer)
processor = WhisperProcessor.from_pretrained(args.base_model,
                                             language=args.language,
                                             task=args.task,
                                             no_timestamps=not args.timestamps,
                                             local_files_only=args.local_files_only)

# Load the training and evaluation datasets
train_dataset = CustomDataset(data_list_path=args.train_data,
                              processor=processor,
                              language=args.language,
                              timestamps=args.timestamps,
                              min_duration=args.min_audio_len,
                              max_duration=args.max_audio_len,
                              augment_config_path=args.augment_config_path)
test_dataset = CustomDataset(data_list_path=args.test_data,
                             processor=processor,
                             language=args.language,
                             timestamps=args.timestamps,
                             min_duration=args.min_audio_len,
                             max_duration=args.max_audio_len)
print(f"Training samples: {len(train_dataset)}, test samples: {len(test_dataset)}")

# Data collator that pads audio features and label sequences
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Get the Whisper model, mapping devices for single-process or DDP training
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
# Load the base model
model = WhisperForConditionalGeneration.from_pretrained(args.base_model,
                                                        load_in_8bit=args.use_8bit,
                                                        device_map=device_map,
                                                        local_files_only=args.local_files_only)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Prepare the (possibly quantized) model for k-bit training
model = prepare_model_for_kbit_training(model)
# Register a forward hook so encoder inputs require gradients (needed for gradient checkpointing with frozen embeddings)
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

print('Loading LoRA modules...')
if args.resume_from_checkpoint:
    # Resume from existing adapters
    print("Loading adapters from checkpoint.")
    model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
else:
    print('Adding LoRA modules...')
    target_modules = ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"]
    print(target_modules)
    if args.use_adalora:
        config = AdaLoraConfig(init_r=12, target_r=4, beta1=0.85, beta2=0.85, tinit=200, tfinal=1000,
                               deltaT=10, lora_alpha=32, lora_dropout=0.1, orth_reg_weight=0.5,
                               target_modules=target_modules)
    else:
        config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)

output_dir = os.path.join(args.output_dir, os.path.basename(args.base_model))
# Define the training arguments
training_args = \
    Seq2SeqTrainingArguments(output_dir=output_dir,  # Directory to save checkpoints
                             per_device_train_batch_size=args.per_device_train_batch_size,  # Training batch size per device
                             per_device_eval_batch_size=args.per_device_eval_batch_size,  # Evaluation batch size per device
                             gradient_accumulation_steps=args.gradient_accumulation_steps,  # Gradient accumulation steps
                             learning_rate=args.learning_rate,  # Learning rate
                             warmup_steps=args.warmup_steps,  # Warm-up steps
                             num_train_epochs=args.num_train_epochs,  # Number of training epochs
                             save_strategy="steps",  # Save checkpoints by step
                             evaluation_strategy="steps",  # Evaluate by step
                             load_best_model_at_end=True,  # Load the best checkpoint at the end of training
                             fp16=args.fp16,  # Use fp16 mixed precision
                             report_to=["tensorboard"],  # Log to TensorBoard
                             save_steps=args.save_steps,  # Steps between checkpoint saves
                             eval_steps=args.eval_steps,  # Steps between evaluations
                             save_total_limit=5,  # Keep at most 5 checkpoints
                             optim='adamw_torch',  # Optimizer
                             ddp_find_unused_parameters=False if ddp else None,  # Required for DDP with PEFT
                             dataloader_num_workers=args.num_workers,  # Dataloader workers
                             logging_steps=args.logging_steps,  # Steps between log outputs
                             remove_unused_columns=False,  # Keep the columns the data collator needs
                             label_names=["labels"])  # Name of the label field

if training_args.local_rank == 0 or training_args.local_rank == -1:
    print('=' * 90)
    model.print_trainable_parameters()
    print('=' * 90)

# Use the PyTorch 2.0 compiler (torch.compile is not supported on Windows)
if torch.__version__ >= "2" and platform.system().lower() != 'windows':
    model = torch.compile(model)

# Define the trainer
trainer = Seq2SeqTrainer(args=training_args,
                         model=model,
                         train_dataset=train_dataset,
                         eval_dataset=test_dataset,
                         data_collator=data_collator,
                         tokenizer=processor.feature_extractor,
                         callbacks=[SavePeftModelCallback])
model.config.use_cache = False
trainer._load_from_checkpoint = load_from_checkpoint

# Start training
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

# Save the final state and the adapter weights
trainer.save_state()
if training_args.local_rank == 0 or training_args.local_rank == -1:
    model.save_pretrained(os.path.join(output_dir, "checkpoint-final"))
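# Example invocations (a sketch, not from the source): the script name "finetune.py" is an
# assumption, and the "--name=value" flag style assumes the add_arguments helper registers
# each argument as "--<name>".
#
# Single-GPU run:
#   CUDA_VISIBLE_DEVICES=0 python finetune.py --base_model=openai/whisper-tiny \
#       --train_data=dataset/train.json --test_data=dataset/test.json --output_dir=output/
#
# Multi-GPU DDP run; torchrun sets WORLD_SIZE and LOCAL_RANK, which activates the
# per-rank device_map branch above:
#   torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-tiny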