|
import torch |
|
import os |
|
import random |
|
import math |
|
import argparse |
|
from fengshen.data.fs_datasets.fs_datamodule import FSDataModule |
|
from fengshen.example.deepVAE.vae_pl_module import DeepVAEModule |
|
|
|
from pytorch_lightning import ( |
|
Trainer, |
|
loggers, |
|
) |
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
class NER_RE_Collator: |
|
def __init__(self, bos_token, eos_token, sep_token) -> None: |
|
self.bos_token = bos_token |
|
self.eos_token = eos_token |
|
self.sep_token = sep_token |
|
|
|
def __call__(self, samples, max_len=128): |
|
|
|
inputs_tensors, entity_tensors = [], [] |
|
for sp in samples: |
|
|
|
input_entities, input_ids = sp['decoder_entities'], sp['decoder_target'] |
|
input_entities = input_entities[:max_len] + [self.sep_token] |
|
|
|
input_ids = [self.bos_token] + input_ids[:max_len] + [self.eos_token] |
|
entity_tensors.append(torch.tensor(input_entities, dtype=torch.long)) |
|
inputs_tensors.append(torch.tensor(input_ids, dtype=torch.long)) |
|
if not inputs_tensors or not entity_tensors: |
|
return None |
|
inputs_tensors = pad_sequence(inputs_tensors, batch_first=True, padding_value=0) |
|
entity_tensors = pad_sequence(entity_tensors, batch_first=True, padding_value=0) |
|
return inputs_tensors, entity_tensors |
|
|
|
|
|
class TDVAECollator: |
|
def __init__(self, bos_token, eos_token) -> None: |
|
self.bos_token = bos_token |
|
self.eos_token = eos_token |
|
|
|
def __call__(self, samples, max_len=120): |
|
|
|
inputs = [] |
|
for sp in samples: |
|
|
|
sent_lengths, input_ids = sp['decoder_sent_lengths'], sp['decoder_target'] |
|
potential_indices = [idx for idx, slen in enumerate(sent_lengths) if slen < max_len] |
|
if len(potential_indices) == 0: |
|
continue |
|
selected_idx = random.choice(potential_indices) |
|
start_pos, end_pos = sum(sent_lengths[:selected_idx]), sum(sent_lengths[:selected_idx])+sent_lengths[selected_idx] |
|
selected_input_ids = [self.bos_token] + input_ids[start_pos:end_pos] + [self.eos_token] |
|
inputs.append(torch.tensor(selected_input_ids, dtype=torch.long)) |
|
if not inputs: |
|
return None |
|
inputs = pad_sequence(inputs, batch_first=True, padding_value=0) |
|
return inputs |
|
|
|
|
|
class ZH_Fin_Collator: |
|
def __init__(self, bos_token, eos_token) -> None: |
|
self.bos_token = bos_token |
|
self.eos_token = eos_token |
|
|
|
def __call__(self, samples, max_len=120): |
|
inputs = [] |
|
for sp in samples: |
|
|
|
input_ids = sp['input_ids'] |
|
if len(input_ids) == 0: |
|
continue |
|
selected_input_ids = [self.bos_token] + input_ids + [self.eos_token] |
|
inputs.append(torch.tensor(selected_input_ids, dtype=torch.long)) |
|
if not inputs: |
|
return None |
|
inputs = pad_sequence(inputs, batch_first=True, padding_value=0) |
|
return inputs |
|
|
|
|
|
class VAEModelCheckpoint: |
|
@ staticmethod |
|
def add_argparse_args(parent_args): |
|
parser = parent_args.add_argument_group('BaseModel') |
|
|
|
parser.add_argument('--monitor', default='total_loss', type=str) |
|
parser.add_argument('--mode', default='min', type=str) |
|
parser.add_argument('--dirpath', default='./log/', type=str) |
|
parser.add_argument('--filename', default='model-{epoch:2d}-{train_loss:.4f}', type=str) |
|
|
|
parser.add_argument('--save_top_k', default=-1, type=int) |
|
parser.add_argument('--every_n_train_steps', default=1000, type=float) |
|
parser.add_argument('--save_weights_only', default=True, type=bool) |
|
|
|
return parent_args |
|
|
|
@staticmethod |
|
def get_callback(args): |
|
return ModelCheckpoint(monitor=args.monitor, |
|
save_top_k=args.save_top_k, |
|
mode=args.mode, |
|
every_n_train_steps=args.every_n_train_steps, |
|
save_weights_only=args.save_weights_only, |
|
dirpath=args.dirpath, |
|
filename=args.filename) |
|
|
|
|
|
if __name__ == '__main__': |
|
args_parser = argparse.ArgumentParser() |
|
|
|
args_parser = FSDataModule.add_data_specific_args(args_parser) |
|
args_parser = Trainer.add_argparse_args(args_parser) |
|
args_parser = DeepVAEModule.add_module_specific_args(args_parser) |
|
args_parser = VAEModelCheckpoint.add_argparse_args(args_parser) |
|
|
|
args = args_parser.parse_args() |
|
|
|
|
|
|
|
collator = ZH_Fin_Collator(bos_token=21128, eos_token=21129) |
|
|
|
data_module = FSDataModule(args=args, collate_fn=collator) |
|
|
|
train_steps = math.ceil(len(data_module.train_dataset)*args.max_epochs / |
|
args.train_batchsize / args.num_nodes / args.gpus) |
|
model = DeepVAEModule(args, train_steps) |
|
|
|
logger = loggers.TensorBoardLogger(save_dir=os.path.join( |
|
args.default_root_dir, 'logs/'), name='deepvae_lightning') |
|
|
|
save_cpt_callback = VAEModelCheckpoint.get_callback(args) |
|
lr_monitor = LearningRateMonitor(logging_interval='step') |
|
trainer = Trainer.from_argparse_args(args, |
|
callbacks=[save_cpt_callback, lr_monitor], |
|
logger=logger) |
|
trainer.fit(model, data_module) |
|
|