File size: 4,387 Bytes
5381499 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import sys
sys.path.append("..")
from argparse import ArgumentParser
import os, string
from transformers import (
Wav2Vec2ForPreTraining,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
)
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from src.datamodule import VLSP2020TarDataset
from src.datamodule.vlsp2020 import get_dataloader
from finetuning.wav2vec2 import SpeechRecognizer
def remove_punctuation(text: str):
    """Return ``text`` lower-cased with ASCII punctuation stripped out.

    Only the characters listed in ``string.punctuation`` are removed; letters,
    digits, whitespace and any non-ASCII symbols are kept as they are.
    """
    # A code point mapped to ``None`` is deleted by ``str.translate``.
    deletions = dict.fromkeys(map(ord, string.punctuation))
    without_punctuation = text.translate(deletions)
    return without_punctuation.lower()
def prepare_dataloader(data_dir, batch_size, num_workers):
    """Build the training and validation dataloaders.

    Args:
        data_dir: Directory containing ``vlsp2020_train_set.tar`` and
            ``vlsp2020_val_set.tar``.
        batch_size: Forwarded unchanged to ``get_dataloader``.
        num_workers: Forwarded unchanged to ``get_dataloader``.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """
    archive_names = ("vlsp2020_train_set.tar", "vlsp2020_val_set.tar")

    # Load both splits up front (train first, then validation) before any
    # dataloader is built.
    datasets = [
        VLSP2020TarDataset(os.path.join(data_dir, archive)).load()
        for archive in archive_names
    ]

    # Both splits share the same loader settings; transcripts are requested
    # and normalised with ``remove_punctuation``.
    loader_options = {
        "return_transcript": True,
        "target_transform": remove_punctuation,
        "batch_size": batch_size,
        "num_workers": num_workers,
    }
    train_dataloader, val_dataloader = (
        get_dataloader(split, **loader_options) for split in datasets
    )
    return train_dataloader, val_dataloader
def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
    """Assemble the ``SpeechRecognizer`` from a pretrained checkpoint.

    The Wav2Vec2 model, CTC tokenizer and feature extractor are all loaded
    from the same hub identifier, then handed to ``SpeechRecognizer`` together
    with the two configuration dictionaries.

    Args:
        adam_config: Optimizer settings, passed through untouched.
        tristate_scheduler_config: Scheduler settings, passed through untouched.

    Returns:
        The constructed ``SpeechRecognizer`` instance.
    """
    pretrained_id = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"

    backbone = Wav2Vec2ForPreTraining.from_pretrained(pretrained_id)
    ctc_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_id)
    extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_id)

    return SpeechRecognizer(
        backbone,
        ctc_tokenizer,
        extractor,
        adam_config,
        tristate_scheduler_config,
    )
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty token -- including ``"False"`` -- becomes ``True``.
    This converter interprets the usual spellings instead.

    Args:
        value: The raw string received from the command line.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    from argparse import ArgumentTypeError

    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(argv=None):
    """Parse the command line, build data/model/trainer and run training.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behaviour.
    """
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--classifier_lr", type=float, default=1e-4)
    parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--warmup_steps",
        type=float,
        default=0.1,
        help="Fraction of the estimated total steps used for warm-up.",
    )
    parser.add_argument(
        "--constant_steps",
        type=float,
        default=0.4,
        help="Fraction of the estimated total steps used for the constant phase.",
    )
    parser.add_argument("--scheduler_factor", type=float, default=1e-3)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--ckpt_dir", type=str, default="ckpt")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=None,
        help="Checkpoint file handed to Trainer.fit to resume from.",
    )
    # ``type=bool`` would treat "--detect_anomaly False" as True.  A bare
    # "--detect_anomaly" (no value) is also accepted and means True.
    parser.add_argument(
        "--detect_anomaly",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Enable the trainer's anomaly detection (true/false).",
    )
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument("--wandb_id", type=str, default=None)
    args = parser.parse_args(argv)
    print(args)

    train_loader, val_loader = prepare_dataloader(
        args.data_dir, args.batch_size, args.num_workers
    )

    # NOTE(review): 42_000 is hard-coded; presumably the approximate number of
    # training samples seen per epoch -- confirm against the dataset.
    total_steps = args.max_epochs * 42_000 // args.batch_size
    # The CLI values are fractions of ``total_steps``; convert to step counts.
    warmup_steps = int(total_steps * args.warmup_steps)
    constant_steps = int(total_steps * args.constant_steps)

    adam_config = {
        "wav2vec2_lr": args.wav2vec2_lr,
        "classifier_lr": args.classifier_lr,
        "weight_decay": args.weight_decay,
    }
    scheduler_config = {
        "warmup_steps": warmup_steps,
        "constant_steps": constant_steps,
        "total_steps": total_steps,
        "factor": args.scheduler_factor,
    }
    model = prepare_model(adam_config, scheduler_config)

    trainer = Trainer(
        accelerator=args.accelerator,
        callbacks=[
            # Keep the single best checkpoint by validation WER plus the last one.
            ModelCheckpoint(
                args.ckpt_dir,
                monitor="val/wer",
                mode="min",
                save_top_k=1,
                save_last=True,
            ),
            LearningRateMonitor(logging_interval="step"),
        ],
        logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
        max_epochs=args.max_epochs,
        detect_anomaly=args.detect_anomaly,
        gradient_clip_val=args.grad_clip,
    )

    # ``--ckpt_path`` used to be parsed and then ignored; forward it so a run
    # can actually be resumed.  ``None`` keeps the previous fresh-start behaviour.
    trainer.fit(model, train_loader, val_loader, ckpt_path=args.ckpt_path)
if __name__ == "__main__":
    # Seed before main() so everything it builds starts from the same seed.
    seed_everything(seed=188)
    main()
|