"""Fine-tuning entry point: wav2vec2 speech recognizer on the VLSP2020 tar sets.

Parses command-line options, builds the train/validation dataloaders and the
``SpeechRecognizer`` module, and runs a PyTorch Lightning ``Trainer`` with
checkpointing, learning-rate monitoring and Weights & Biases logging.
"""
import sys

# Must run before the project imports below so that ``src`` and ``finetuning``
# are resolvable from the parent directory.
sys.path.append("..")

from argparse import ArgumentParser, ArgumentTypeError
import os
import string

from transformers import (
    Wav2Vec2ForPreTraining,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
)
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from src.datamodule import VLSP2020TarDataset
from src.datamodule.vlsp2020 import get_dataloader
from finetuning.wav2vec2 import SpeechRecognizer

# Built once: remove_punctuation is applied to every transcript.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

# Value used to derive the total number of optimizer steps in main(); it is
# treated as the number of training samples seen per epoch.
# NOTE(review): hard-coded; confirm it matches the actual training set size.
_SAMPLES_PER_EPOCH = 42_000

_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _str2bool(value: str) -> bool:
    """Parse a command-line string into a boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so ``"False"`` (a non-empty string) evaluates to ``True``.

    Raises:
        ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def remove_punctuation(text: str) -> str:
    """Return *text* lower-cased with every ``string.punctuation`` char removed."""
    return text.translate(_PUNCTUATION_TABLE).lower()


def prepare_dataloader(data_dir: str, batch_size: int, num_workers: int):
    """Build the training and validation dataloaders.

    Args:
        data_dir: Directory containing ``vlsp2020_train_set.tar`` and
            ``vlsp2020_val_set.tar``.
        batch_size: Batch size forwarded to both dataloaders.
        num_workers: Worker count forwarded to both dataloaders.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple. Both loaders return
        transcripts, normalised through :func:`remove_punctuation`.
    """
    train_dataset = VLSP2020TarDataset(
        os.path.join(data_dir, "vlsp2020_train_set.tar")
    ).load()
    val_dataset = VLSP2020TarDataset(
        os.path.join(data_dir, "vlsp2020_val_set.tar")
    ).load()

    # Both splits use exactly the same loader options.
    loader_options = {
        "return_transcript": True,
        "target_transform": remove_punctuation,
        "batch_size": batch_size,
        "num_workers": num_workers,
    }
    train_dataloader = get_dataloader(train_dataset, **loader_options)
    val_dataloader = get_dataloader(val_dataset, **loader_options)

    return train_dataloader, val_dataloader


def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
    """Load the pretrained checkpoint and wrap it in a ``SpeechRecognizer``.

    Args:
        adam_config: Optimizer settings passed through to ``SpeechRecognizer``
            (``main`` supplies ``wav2vec2_lr``, ``classifier_lr`` and
            ``weight_decay``).
        tristate_scheduler_config: Scheduler settings passed through to
            ``SpeechRecognizer`` (``main`` supplies ``warmup_steps``,
            ``constant_steps``, ``total_steps`` and ``factor``).

    Returns:
        The constructed ``SpeechRecognizer``.
    """
    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"

    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

    model = SpeechRecognizer(
        wav2vec2, tokenizer, feature_extractor, adam_config, tristate_scheduler_config
    )

    return model


def main():
    """Parse command-line arguments and run training."""
    parser = ArgumentParser()

    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--classifier_lr", type=float, default=1e-4)
    parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--warmup_steps",
        type=float,
        default=0.1,
        help="fraction of the total step count used as warmup steps",
    )
    parser.add_argument(
        "--constant_steps",
        type=float,
        default=0.4,
        help="fraction of the total step count used as constant steps",
    )
    parser.add_argument("--scheduler_factor", type=float, default=1e-3)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--ckpt_dir", type=str, default="ckpt")
    parser.add_argument("--ckpt_path", type=str, default=None)
    # Accepts ``--detect_anomaly``, ``--detect_anomaly True`` and
    # ``--detect_anomaly False``; the last one now really yields False.
    parser.add_argument(
        "--detect_anomaly",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="forwarded to Trainer(detect_anomaly=...)",
    )
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument("--wandb_id", type=str, default=None)

    args = parser.parse_args()
    print(args)

    train_loader, val_loader = prepare_dataloader(
        args.data_dir, args.batch_size, args.num_workers
    )

    # The warmup/constant options are fractions; convert them to step counts.
    total_steps = args.max_epochs * _SAMPLES_PER_EPOCH // args.batch_size
    warmup_steps = int(total_steps * args.warmup_steps)
    constant_steps = int(total_steps * args.constant_steps)

    adam_config = {
        "wav2vec2_lr": args.wav2vec2_lr,
        "classifier_lr": args.classifier_lr,
        "weight_decay": args.weight_decay,
    }
    tristate_scheduler_config = {
        "warmup_steps": warmup_steps,
        "constant_steps": constant_steps,
        "total_steps": total_steps,
        "factor": args.scheduler_factor,
    }
    model = prepare_model(adam_config, tristate_scheduler_config)

    trainer = Trainer(
        accelerator=args.accelerator,
        callbacks=[
            ModelCheckpoint(
                args.ckpt_dir,
                monitor="val/wer",
                mode="min",
                save_top_k=1,
                save_last=True,
            ),
            LearningRateMonitor(logging_interval="step"),
        ],
        logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
        max_epochs=args.max_epochs,
        detect_anomaly=args.detect_anomaly,
        gradient_clip_val=args.grad_clip,
    )

    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    seed_everything(188)
    main()