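"""Fine-tune a pretrained Vietnamese wav2vec2 model for CTC speech recognition on the VLSP2020 dataset."""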
import sys

sys.path.append("..")  # make the project root importable so `src` and `finetuning` resolve

import os
import string
from argparse import ArgumentParser

from transformers import (
    Wav2Vec2ForPreTraining,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
)
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from src.datamodule import VLSP2020TarDataset
from src.datamodule.vlsp2020 import get_dataloader
from finetuning.wav2vec2 import SpeechRecognizer


def remove_punctuation(text: str) -> str:
    """Lowercase a transcript and strip punctuation so targets match the CTC vocabulary."""
    return text.translate(str.maketrans("", "", string.punctuation)).lower()
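# e.g. remove_punctuation("Xin chào, bạn!") == "xin chào bạn"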


def prepare_dataloader(data_dir, batch_size, num_workers):
    """Build train/val dataloaders from the VLSP2020 tar archives in `data_dir`."""
    train_dataset = VLSP2020TarDataset(
        os.path.join(data_dir, "vlsp2020_train_set.tar")
    ).load()
    val_dataset = VLSP2020TarDataset(
        os.path.join(data_dir, "vlsp2020_val_set.tar")
    ).load()

    train_dataloader = get_dataloader(
        train_dataset,
        return_transcript=True,
        target_transform=remove_punctuation,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    val_dataloader = get_dataloader(
        val_dataset,
        return_transcript=True,
        target_transform=remove_punctuation,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_dataloader, val_dataloader


def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
    """Wrap the pretrained Vietnamese wav2vec2 backbone in a SpeechRecognizer for fine-tuning."""
    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    model = SpeechRecognizer(
        wav2vec2, tokenizer, feature_extractor, adam_config, tristate_scheduler_config
    )
    return model
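

# For reference, a minimal sketch of the tri-state schedule configured above
# (linear warmup -> constant -> linear decay), assuming `factor` is the fraction
# of the peak LR at the start and end of training. The real schedule lives in
# finetuning.wav2vec2.SpeechRecognizer; this illustrative multiplier could be
# plugged into torch.optim.lr_scheduler.LambdaLR via functools.partial.
def tristate_lr_lambda(step, warmup_steps, constant_steps, total_steps, factor):
    if step < warmup_steps:
        # ramp linearly from factor * peak_lr up to peak_lr
        return factor + (1.0 - factor) * step / max(1, warmup_steps)
    if step < warmup_steps + constant_steps:
        # hold at peak_lr
        return 1.0
    # decay linearly from peak_lr back down to factor * peak_lr
    decay_steps = max(1, total_steps - warmup_steps - constant_steps)
    progress = (step - warmup_steps - constant_steps) / decay_steps
    return 1.0 - (1.0 - factor) * min(1.0, progress)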


def main():
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--classifier_lr", type=float, default=1e-4)
    parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # warmup/constant phases are given as fractions of the total training steps
    parser.add_argument("--warmup_steps", type=float, default=0.1)
    parser.add_argument("--constant_steps", type=float, default=0.4)
    parser.add_argument("--scheduler_factor", type=float, default=1e-3)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--ckpt_dir", type=str, default="ckpt")
    parser.add_argument("--ckpt_path", type=str, default=None)
    # argparse's `type=bool` would treat any non-empty string (even "False")
    # as True, so expose this as a store_true flag instead
    parser.add_argument("--detect_anomaly", action="store_true")
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument("--wandb_id", type=str, default=None)
    args = parser.parse_args()
    print(args)

    train_loader, val_loader = prepare_dataloader(
        args.data_dir, args.batch_size, args.num_workers
    )

    # 42_000 is assumed to be the number of training utterances, so this
    # approximates the total number of optimizer steps across all epochs
    total_steps = args.max_epochs * 42_000 // args.batch_size
    warmup_steps = int(total_steps * args.warmup_steps)
    constant_steps = int(total_steps * args.constant_steps)

    # discriminative learning rates: a small LR for the pretrained backbone,
    # a larger one for the freshly initialized classifier head
    model = prepare_model(
        {
            "wav2vec2_lr": args.wav2vec2_lr,
            "classifier_lr": args.classifier_lr,
            "weight_decay": args.weight_decay,
        },
        {
            "warmup_steps": warmup_steps,
            "constant_steps": constant_steps,
            "total_steps": total_steps,
            "factor": args.scheduler_factor,
        },
    )

    trainer = Trainer(
        accelerator=args.accelerator,
        callbacks=[
            # keep the checkpoint with the best validation WER, plus the latest one
            ModelCheckpoint(
                args.ckpt_dir,
                monitor="val/wer",
                mode="min",
                save_top_k=1,
                save_last=True,
            ),
            LearningRateMonitor(logging_interval="step"),
        ],
        logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
        max_epochs=args.max_epochs,
        detect_anomaly=args.detect_anomaly,
        gradient_clip_val=args.grad_clip,
    )
    # resume from --ckpt_path if one was given
    trainer.fit(model, train_loader, val_loader, ckpt_path=args.ckpt_path)


if __name__ == "__main__":
    seed_everything(188)
    main()
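

# Example invocation (assumes the VLSP2020 tar archives live under ./data;
# adjust paths and hyperparameters to your setup):
#   python train.py --batch_size 4 --num_workers 2 --max_epochs 10 --grad_clip 1.0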