|
import sys |
|
sys.path.append(".") |
|
|
|
from src.config import model as conf |
|
from src.model import Wav2Vec2PretrainingModule |
|
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule |
|
from pytorch_lightning import Trainer |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: assemble the pretraining module and the data
    # module from the project config, then hand both to a Lightning Trainer.

    # The model is configured from the `wav2vec2_pretraining` config section.
    pretraining_module = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)

    # The dataset is obtained through the converter pointed at
    # `conf.dataset.path`.
    converter = WebDatasetConverter(conf.dataset.path)
    dataset = converter.get_dataset()

    # The whole `dataset` config section is forwarded as keyword arguments.
    # NOTE(review): this presumably includes `path` as well -- confirm the
    # data module's constructor accepts every key in that section.
    datamodule = VLSP2020ForPretrainingDataModule(dataset, **conf.dataset)

    # Checkpoints are written under the configured directory and the
    # callback monitors the "val/loss" metric.
    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        dirpath=conf["checkpoint_dir"],
    )

    # Gradients are clipped at 1.0 and the "gpu" accelerator is requested.
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
        accelerator="gpu",
    )

    trainer.fit(pretraining_module, datamodule)
|
|