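"""Pretraining entry point for the fengshen DeepVAE example.

Builds a dataset-specific collator, plugs it into FSDataModule, constructs a
DeepVAEModule, and trains it with PyTorch Lightning (checkpointing via
ModelCheckpoint, TensorBoard logging).
"""
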
import torch
import os
import random
import math
import argparse
from fengshen.data.fs_datasets.fs_datamodule import FSDataModule
from fengshen.example.deepVAE.vae_pl_module import DeepVAEModule

from pytorch_lightning import (
    Trainer,
    loggers,
)

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from torch.nn.utils.rnn import pad_sequence


class NER_RE_Collator:
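    """Pads NER/RE batches: each sample carries entity ids ('decoder_entities')
    and target token ids ('decoder_target'); both are truncated to max_len,
    wrapped with the special tokens given at construction, and padded with 0."""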
    def __init__(self, bos_token, eos_token, sep_token) -> None:
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.sep_token = sep_token

    def __call__(self, samples, max_len=128):
        # build padded batches of entity ids and target token ids
        inputs_tensors, entity_tensors = [], []
        for sp in samples:
            # NOTE: in TD-VAE both the encoder and the decoder are GPT-2, so the decoder sentence is used for both
            input_entities, input_ids = sp['decoder_entities'], sp['decoder_target']
            input_entities = input_entities[:max_len] + [self.sep_token]
            # truncate input_ids the same way; the sentence is always at least as long as its entity list
            input_ids = [self.bos_token] + input_ids[:max_len] + [self.eos_token]
            entity_tensors.append(torch.tensor(input_entities, dtype=torch.long))
            inputs_tensors.append(torch.tensor(input_ids, dtype=torch.long))
        if not inputs_tensors or not entity_tensors:
            return None  # guard against an empty batch
        inputs_tensors = pad_sequence(inputs_tensors, batch_first=True, padding_value=0)
        entity_tensors = pad_sequence(entity_tensors, batch_first=True, padding_value=0)
        return inputs_tensors, entity_tensors
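
# A minimal sketch of what NER_RE_Collator returns (the ids below are made up):
#   collator = NER_RE_Collator(bos_token=21128, eos_token=21129, sep_token=102)
#   inputs, entities = collator([{'decoder_entities': [8, 9],
#                                 'decoder_target': [5, 6, 7]}])
#   # inputs   -> tensor([[21128, 5, 6, 7, 21129]])
#   # entities -> tensor([[8, 9, 102]])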


class TDVAECollator:
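    """Pads TD-VAE batches: each sample is a paragraph with per-sentence
    lengths ('decoder_sent_lengths'); one sentence shorter than max_len is
    picked at random, wrapped with BOS/EOS, and padded with 0."""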
    def __init__(self, bos_token, eos_token) -> None:
        self.bos_token = bos_token
        self.eos_token = eos_token

    def __call__(self, samples, max_len=120):
        # each sample is a paragraph; the stored sentence lengths let us slice out a single sentence
        inputs = []
        for sp in samples:
            # NOTE: in TD-VAE both the encoder and the decoder are GPT-2, so the decoder sentence is used for both
            sent_lengths, input_ids = sp['decoder_sent_lengths'], sp['decoder_target']
            potential_indices = [idx for idx, slen in enumerate(sent_lengths) if slen < max_len]
            if len(potential_indices) == 0:
                continue  # skip paragraphs in which every sentence is at least max_len tokens long
            selected_idx = random.choice(potential_indices)
            start_pos = sum(sent_lengths[:selected_idx])
            end_pos = start_pos + sent_lengths[selected_idx]
            selected_input_ids = [self.bos_token] + input_ids[start_pos:end_pos] + [self.eos_token]
            inputs.append(torch.tensor(selected_input_ids, dtype=torch.long))
        if not inputs:
            return None  # every paragraph in the batch was skipped
        inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
        return inputs
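
# Sketch of the sentence selection (ids made up): given
#   sp = {'decoder_sent_lengths': [3, 2], 'decoder_target': [1, 2, 3, 4, 5]}
# the paragraph splits into sentences [1, 2, 3] and [4, 5]; one is chosen at
# random and returned as [bos] + sentence + [eos], padded across the batch.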


class ZH_Fin_Collator:
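    """Pads plain 'input_ids' batches (ZH financial-domain data, going by the
    class name): each non-empty sample is wrapped with BOS/EOS and padded
    with 0."""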
    def __init__(self, bos_token, eos_token) -> None:
        self.bos_token = bos_token
        self.eos_token = eos_token

    def __call__(self, samples, max_len=120):
        inputs = []
        for sp in samples:
            # wrap every non-empty paragraph with BOS/EOS
            input_ids = sp['input_ids']
            if len(input_ids) == 0:
                continue  # skip empty paragraphs
            selected_input_ids = [self.bos_token] + input_ids + [self.eos_token]
            inputs.append(torch.tensor(selected_input_ids, dtype=torch.long))
        if not inputs:
            return None
        inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
        return inputs
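
# The collator can be exercised directly before wiring it into the data
# module (token ids below match the hard-coded ones used in __main__):
#   collator = ZH_Fin_Collator(bos_token=21128, eos_token=21129)
#   batch = collator([{'input_ids': [11, 12]}, {'input_ids': [13]}])
#   # batch -> tensor([[21128, 11, 12, 21129],
#   #                  [21128, 13, 21129,     0]])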


class VAEModelCheckpoint:
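    """Groups the argparse options for Lightning's ModelCheckpoint and builds
    the callback from the parsed args."""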
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='total_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./log/', type=str)
        parser.add_argument('--filename', default='model-{epoch:2d}-{train_loss:.4f}', type=str)

        parser.add_argument('--save_top_k', default=-1, type=int)
        parser.add_argument('--every_n_train_steps', default=1000, type=int)
        # NOTE: argparse's type=bool treats any non-empty string as True; leave unset to keep the default
        parser.add_argument('--save_weights_only', default=True, type=bool)

        return parent_args

    @staticmethod
    def get_callback(args):
        return ModelCheckpoint(monitor=args.monitor,
                               save_top_k=args.save_top_k,
                               mode=args.mode,
                               every_n_train_steps=args.every_n_train_steps,
                               save_weights_only=args.save_weights_only,
                               dirpath=args.dirpath,
                               filename=args.filename)


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()

    args_parser = FSDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = DeepVAEModule.add_module_specific_args(args_parser)
    args_parser = VAEModelCheckpoint.add_argparse_args(args_parser)

    args = args_parser.parse_args()
    # TODO: update this to be tokenizer specific
    # collator = NER_RE_Collator(bos_token=21128, eos_token=21129, sep_token=102)
    # collator = TDVAECollator(bos_token=21128, eos_token=21129)
    collator = ZH_Fin_Collator(bos_token=21128, eos_token=21129)

    data_module = FSDataModule(args=args, collate_fn=collator)

    # estimated total optimizer steps (passed to DeepVAEModule, e.g. for LR scheduling):
    # samples * epochs / (per-device batch size * num_nodes * gpus per node)
    train_steps = math.ceil(len(data_module.train_dataset) * args.max_epochs /
                            args.train_batchsize / args.num_nodes / args.gpus)
    model = DeepVAEModule(args, train_steps)

    logger = loggers.TensorBoardLogger(save_dir=os.path.join(
        args.default_root_dir, 'logs/'), name='deepvae_lightning')

    save_cpt_callback = VAEModelCheckpoint.get_callback(args)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[save_cpt_callback, lr_monitor],
                                         logger=logger)
    trainer.fit(model, data_module)