File size: 788 Bytes
5381499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
# Add the current working directory to the module search path so the `src`
# package imported below can be resolved. "." is relative to where the script
# is launched from, not to this file's location — presumably the repository
# root; confirm against how the script is invoked.
sys.path.append(".")

from src.config import model as conf
from src.model import Wav2Vec2PretrainingModule
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


if __name__ == "__main__":
    # Script entry point: assemble the pretraining module, the data module and
    # the trainer from the project config, then start fitting.

    # Model wrapper configured from the `wav2vec2_pretraining` config section.
    pretraining_module = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)

    # Obtain the dataset from the configured path, then hand it to the data
    # module. The whole `dataset` config section is forwarded as keyword
    # arguments (NOTE(review): this may include `path` as well, depending on
    # how the config object unpacks — confirm the data module accepts it).
    converter = WebDatasetConverter(conf.dataset.path)
    dataset = converter.get_dataset()
    data_module = VLSP2020ForPretrainingDataModule(dataset, **conf.dataset)

    # Checkpoints are written to the configured directory and selected by the
    # logged "val/loss" metric.
    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss", dirpath=conf["checkpoint_dir"]
    )

    # Gradients are clipped at 1.0; training is requested on GPU.
    trainer = Trainer(
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
        accelerator="gpu",
    )

    trainer.fit(pretraining_module, data_module)