# Source: wav2vec2/src/train.py (hoang1007, commit 5381499 "init", 788 bytes)
import sys
sys.path.append(".")
from src.config import model as conf
from src.model import Wav2Vec2PretrainingModule
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
if __name__ == "__main__":
    # Script entry point: build the pretraining module and the data module
    # from the project config, then hand both to a Lightning Trainer.
    # NOTE(review): conf is read by attribute (conf.dataset) and by key
    # (conf["checkpoint_dir"]) -- confirm the config object supports both.
    pretraining_module = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)

    # The converter receives the configured dataset path; the resulting
    # dataset is wrapped in the data module, with the whole conf.dataset
    # mapping unpacked into keyword arguments.
    converted_dataset = WebDatasetConverter(conf.dataset.path).get_dataset()
    datamodule = VLSP2020ForPretrainingDataModule(
        converted_dataset,
        **conf.dataset,
    )

    # Checkpointing monitors the "val/loss" metric and writes into the
    # configured checkpoint directory.
    checkpoint_callback = ModelCheckpoint(
        dirpath=conf["checkpoint_dir"],
        monitor="val/loss",
    )

    lightning_trainer = Trainer(
        accelerator="gpu",
        gradient_clip_val=1.0,
        callbacks=[checkpoint_callback],
    )
    lightning_trainer.fit(pretraining_module, datamodule)