# Source: wav2vec2/src/train.py (author: hoang1007, commit 5381499 "init", 788 bytes)
import sys
sys.path.append(".")
from src.config import model as conf
from src.model import Wav2Vec2PretrainingModule
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
def main():
    """Pretrain the wav2vec2 model with PyTorch Lightning.

    Builds ``Wav2Vec2PretrainingModule`` from ``conf.wav2vec2_pretraining``,
    loads the dataset at ``conf.dataset.path`` through ``WebDatasetConverter``,
    wraps it in ``VLSP2020ForPretrainingDataModule`` and fits the model on GPU.
    A ``ModelCheckpoint`` callback monitors ``val/loss`` and writes checkpoints
    to ``conf.checkpoint_dir``.
    """
    model = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)

    dataset = WebDatasetConverter(conf.dataset.path).get_dataset()
    # NOTE(review): ``**conf.dataset`` also forwards ``path`` (used above);
    # assumes the datamodule accepts or ignores that keyword — confirm.
    datamodule = VLSP2020ForPretrainingDataModule(dataset, **conf.dataset)

    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        # Attribute access, consistent with every other ``conf`` lookup here
        # (subscripting would fail if ``conf`` is a plain module).
        dirpath=conf.checkpoint_dir,
    )
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
        accelerator="gpu",  # hard-coded: this script expects a GPU to be available
    )
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    main()