Spaces:
Runtime error
Runtime error
File size: 2,189 Bytes
44db343 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
from dataset.util import load_dataset
from dataset.vocab import Vocab
_DESCRIPTION = '''
train.py:
    Usage: python train.py --model tfmwtr --start-epoch n --data_path ./data --dataset binhvq
    Params:
        --start-epoch n
            n = 0: training from beginning
            n > 0: continue training from the nth epoch
        --model
            tfmwtr - Transformer with Tokenization Repair
        --data_path: default to ./data
        --dataset: default to 'binhvq'
'''


def build_parser():
    """Return the command-line parser for this training script.

    ``--start-epoch`` (the spelling shown in the usage text) and
    ``--start_epoch`` are both accepted and stored as ``args.start_epoch``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=_DESCRIPTION,
        # Keep the hand-formatted usage text intact; the default formatter
        # would re-wrap it into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--model', type=str, default='tfmwtr')
    parser.add_argument('--start_epoch', '--start-epoch', dest='start_epoch',
                        type=int, default=0)
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='binhvq')
    return parser


def main(argv=None):
    """Load vocab and validation data, build the trainer, and run training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (standard ``argparse`` behaviour).
    """
    args = build_parser().parse_args(argv)

    # Layout: <data_path>/<dataset>/<dataset>.vocab.pkl, plus the
    # validation files loaded below.
    dataset_path = os.path.join(args.data_path, args.dataset)
    vocab_path = os.path.join(dataset_path, f'{args.dataset}.vocab.pkl')
    vocab = Vocab()
    vocab.load_vocab_dict(vocab_path)

    checkpoint_dir = os.path.join(args.data_path, 'checkpoints', args.model)

    valid_data = load_dataset(
        base_path=dataset_path,
        corr_file=f'{args.dataset}.valid',
        incorr_file=f'{args.dataset}.valid.noise',
        length_file=f'{args.dataset}.length.valid',
    )

    # Deferred until after argument parsing and data loading, presumably so
    # that --help and path errors surface before the model modules are
    # imported — confirm before hoisting these to the top of the file.
    from dataset.autocorrect_dataset import SpellCorrectDataset
    from models.trainer import Trainer
    from models.model import ModelWrapper

    valid_dataset = SpellCorrectDataset(dataset=valid_data)
    model_wrapper = ModelWrapper(args.model, vocab)
    # The training split is not opened here: Trainer gets dataset_path and
    # the dataset name (presumably it locates the training files itself).
    trainer = Trainer(model_wrapper, dataset_path, args.dataset, valid_dataset)
    trainer.load_checkpoint(checkpoint_dir, args.dataset, args.start_epoch)
    trainer.train()


if __name__ == '__main__':
    main()
|