from dataclasses import dataclass
from transformers import (
MegatronBertConfig,
MegatronBertForPreTraining,
AutoTokenizer,
)
from pytorch_lightning import (
LightningModule,
Trainer,
)
from pytorch_lightning.callbacks import (
LearningRateMonitor,
)
import argparse
import torch
import os
import numpy as np
import time
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.data.data_utils.sop_utils import get_a_and_b_segments
from fengshen.data.data_utils.truncate_utils import truncate_segments
from fengshen.data.data_utils.token_type_utils import create_tokens_and_tokentypes
from fengshen.data.data_utils.mask_utils import create_masked_lm_predictions
from fengshen.models.model_utils import (
add_module_args,
configure_optimizers,
get_total_steps,
)
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from torch.utils.data._utils.collate import default_collate
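# Rank-0 flag: print the config, the model and the first training batch once as a sanity check.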
SHOW_DATA = False
@dataclass
class ErLangShenCollator:
    '''
    Turns the raw input samples into the final model inputs.
    The main processing logic lives in __call__ and covers both the
    masked-LM (MLM) and sentence-order-prediction (SOP) tasks.
    '''
    tokenizer: None  # tokenizer used to turn text into token ids
    max_seq_length: int = 512
    masked_lm_prob: float = 0.15
    content_key: str = 'text'
    # One-off preprocessing: sentence splitter, RNG and vocab lookup tables.
def setup(self):
from fengshen.data.data_utils.sentence_split import ChineseSentenceSplitter
self.sentence_split = ChineseSentenceSplitter()
self.np_rng = np.random.RandomState(seed=((int(time.time()) % 2**32)))
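        # Build the vocab id list and id->token map that create_masked_lm_predictions expects.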
inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
self.vocab_id_list = list(inv_vocab.keys())
self.vocab_id_to_token_dict = inv_vocab
def __call__(self, samples):
'''
        samples: each sample looks like {"text": "hello world"}
'''
model_inputs = []
for s in samples:
sentences = self.sentence_split.tokenize(s[self.content_key])
            # Tokenize each sentence into ids; the sentences are then divided
            # into two segments (A and B) for the SOP task.
            tokenized_sentences = [self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(sent)) for sent in sentences]
if len(tokenized_sentences) == 0:
print('find empty sentence')
continue
if len(tokenized_sentences) > 1:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(tokenized_sentences,
self.np_rng)
else:
tokens_a = tokenized_sentences[0]
tokens_b = []
is_next_random = False
            # max_seq_length - 3 because [CLS], [SEP] and [SEP] still have to be added.
if len(tokens_a) == 0:
continue
_ = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), self.max_seq_length-3, self.np_rng)
            # Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
self.tokenizer.cls_token_id, self.tokenizer.sep_token_id)
# Masking.
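            # Upper bound on how many tokens may be masked in this sample.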
max_predictions_per_seq = self.masked_lm_prob * len(tokens)
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokens, self.vocab_id_list, self.vocab_id_to_token_dict, self.masked_lm_prob,
self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.mask_token_id,
max_predictions_per_seq, self.np_rng,
masking_style='bert')
# Some checks.
num_tokens = len(tokens)
padding_length = self.max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [self.tokenizer.pad_token_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
            # Labels: -100 (ignored by the loss) everywhere except the masked positions.
labels = [-100] * self.max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
labels_np = np.array(labels, dtype=np.int64)
model_inputs.append(
{
'input_ids': tokens_np,
'attention_mask': padding_mask_np,
'token_type_ids': tokentypes_np,
'labels': labels_np,
'next_sentence_label': int(is_next_random)
}
)
return default_collate(model_inputs)
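
# A minimal sketch of exercising ErLangShenCollator on its own (assumed usage, for
# inspecting the MLM/SOP features it emits; not part of the training flow below):
#   collator = ErLangShenCollator(tokenizer=tokenizer, max_seq_length=128, masked_lm_prob=0.15)
#   collator.setup()
#   batch = collator([{'text': '今天天气很好。我们去公园散步。'}])
#   # -> dict of tensors: input_ids, attention_mask, token_type_ids, labels, next_sentence_label
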
class ErLangShenBert(LightningModule):
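    '''
    LightningModule that pretrains MegatronBertForPreTraining with the MLM and
    SOP objectives produced by ErLangShenCollator.
    '''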
@staticmethod
def add_module_specific_args(parent_parser):
parser = parent_parser.add_argument_group('Erlangshen Bert')
parser.add_argument('--masked_lm_prob', type=float, default=0.15)
parser.add_argument('--max_seq_length', type=int, default=512)
parser.add_argument('--sample_content_key', type=str, default='text')
return parent_parser
def __init__(self, args, tokenizer, **kwargs) -> None:
super().__init__()
self.save_hyperparameters(args)
config = MegatronBertConfig.from_pretrained(args.model_path)
self.config = config
self.tokenizer = tokenizer
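        # Built from config only, so the weights start randomly initialized;
        # resuming from a checkpoint is handled via Trainer's ckpt_path below.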
self.model = MegatronBertForPreTraining(config)
def setup(self, stage) -> None:
if stage == 'fit':
self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))
def configure_optimizers(self):
return configure_optimizers(self)
def forward(self, **batch):
return self.model(**batch)
def detokenize(self, token_ids):
toks = self.tokenizer.convert_ids_to_tokens(token_ids)
return self.tokenizer.convert_tokens_to_string(toks)
    def compute_metrics(self, logits, labels):
y_pred = torch.argmax(logits, dim=-1)
y_pred = y_pred.view(size=(-1,))
y_true = labels.view(size=(-1,)).float()
corr = torch.eq(y_pred, y_true)
acc = torch.sum(corr.float())/labels.shape[0]
return acc
def training_step(self, batch, batch_idx):
if self.trainer.global_rank == 0:
global SHOW_DATA
if not SHOW_DATA:
print(self.config)
print(self.model)
SHOW_DATA = True
print('source: {}'.format(batch['input_ids'][0]))
print('target: {}'.format(batch['labels'][0]))
print('source: {}'.format(self.detokenize(batch['input_ids'][0])))
label_idx = batch['labels'][0] != -100
print('target: {}'.format(self.detokenize(
batch['labels'][0][label_idx])))
output = self(**batch)
self.log('train_loss', output.loss, sync_dist=True)
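        # Accuracy is computed over the masked positions only (labels are -100 elsewhere).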
label_idx = batch['labels'] != -100
        acc = self.compute_metrics(
            output.prediction_logits[label_idx].view(-1, output.prediction_logits.size(-1)),
            batch['labels'][label_idx])
self.log('train_acc', acc, sync_dist=True)
return output.loss
def validation_step(self, batch, batch_idx):
output = self(**batch)
self.log('val_loss', output.loss, sync_dist=True)
return output.loss
def on_load_checkpoint(self, checkpoint) -> None:
        # Compatibility with older versions of lightning, which reset the step count
        # to 0 when resuming from a checkpoint.
global_step_offset = checkpoint["global_step"]
if 'global_samples' in checkpoint:
self.consumed_samples = checkpoint['global_samples']
self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
if __name__ == '__main__':
args_parser = argparse.ArgumentParser()
args_parser = add_module_args(args_parser)
args_parser = UniversalDataModule.add_data_specific_args(args_parser)
args_parser = Trainer.add_argparse_args(args_parser)
args_parser = ErLangShenBert.add_module_specific_args(args_parser)
args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
args = args_parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
collate_fn = ErLangShenCollator(
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
masked_lm_prob=args.masked_lm_prob,
content_key=args.sample_content_key,
)
collate_fn.setup()
data_module = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collate_fn)
print('data load complete')
model = ErLangShenBert(args, tokenizer=tokenizer)
print('model load complete')
lr_monitor = LearningRateMonitor(logging_interval='step')
checkpoint_callback = UniversalCheckpoint(args)
    # Compatibility shim: if the checkpoint path does not exist, drop the argument,
    # otherwise resuming would raise an error.
if args.load_ckpt_path is not None and \
not os.path.exists(args.load_ckpt_path):
        print('-------- warning: no checkpoint found, removing load_ckpt_path --------')
args.load_ckpt_path = None
trainer = Trainer.from_argparse_args(args,
callbacks=[
lr_monitor,
checkpoint_callback])
trainer.fit(model, data_module, ckpt_path=args.load_ckpt_path)
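
    # Example launch (a sketch only; the full flag set is assembled above by add_module_args,
    # UniversalDataModule, Trainer and UniversalCheckpoint -- paths are placeholders, and
    # --model_path / --load_ckpt_path are assumed to be registered by those helpers):
    #   python this_script.py \
    #       --model_path /path/to/megatron-bert-model-dir \
    #       --max_seq_length 512 --masked_lm_prob 0.15 --sample_content_key text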