from dataclasses import dataclass
from transformers import (
    MegatronBertConfig,
    MegatronBertForPreTraining,
    AutoTokenizer,
)
from pytorch_lightning import (
    LightningModule,
    Trainer,
)
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
)
import argparse
import torch
import os
import numpy as np
import time
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.data.data_utils.sop_utils import get_a_and_b_segments
from fengshen.data.data_utils.truncate_utils import truncate_segments
from fengshen.data.data_utils.token_type_utils import create_tokens_and_tokentypes
from fengshen.data.data_utils.mask_utils import create_masked_lm_predictions
from fengshen.models.model_utils import (
    add_module_args,
    configure_optimizers,
    get_total_steps,
)
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from torch.utils.data._utils.collate import default_collate

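# Flipped to True after the first training batch has been inspected on rank 0,
# so the config / model / sample dump in training_step happens only once.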
SHOW_DATA = False

@dataclass
class ErLangShenCollator:
    '''
    Turns raw input samples into the final model inputs.
    The main processing logic lives in __call__ and covers both the
    masked-LM task and the SOP (sentence order prediction) task.
    '''
    tokenizer: AutoTokenizer = None  # tokenizer used for splitting and id conversion
    max_seq_length: int = 512
    masked_lm_prob: float = 0.15
    content_key: str = 'text'

    # Some one-off preprocessing setup.
    def setup(self):
        from fengshen.data.data_utils.sentence_split import ChineseSentenceSplitter
        self.sentence_split = ChineseSentenceSplitter()
        self.np_rng = np.random.RandomState(seed=(int(time.time()) % 2**32))
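        # Build the id -> token map and the list of valid token ids; both are
        # consumed by create_masked_lm_predictions when sampling replacement tokens.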
        inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
        self.vocab_id_list = list(inv_vocab.keys())
        self.vocab_id_to_token_dict = inv_vocab

    def __call__(self, samples):
        '''
        samples: each sample looks like {"text": "hello world"}
        '''
        model_inputs = []
        for s in samples:
            sentences = self.sentence_split.tokenize(s[self.content_key])
            # Divide the sample into two segments (A and B).
            tokenized_sentences = [self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(sent)) for sent in sentences]
            if len(tokenized_sentences) == 0:
                print('found empty sentence, skipping sample')
                continue
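            # With at least two sentences we can build a sentence-order-prediction
            # pair; with a single sentence, segment B stays empty and the SOP label
            # is always 0 (ordered).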
            if len(tokenized_sentences) > 1:
                tokens_a, tokens_b, is_next_random = get_a_and_b_segments(tokenized_sentences,
                                                                          self.np_rng)
            else:
                tokens_a = tokenized_sentences[0]
                tokens_b = []
                is_next_random = False
            # max_seq_length - 3 because [CLS] [SEP] [SEP] still need to be appended.
            if len(tokens_a) == 0:
                continue
            _ = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), self.max_seq_length - 3, self.np_rng)
            # Build tokens and tokentypes.
            tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                              self.tokenizer.cls_token_id,
                                                              self.tokenizer.sep_token_id)
            # Masking.
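            # create_masked_lm_predictions applies BERT-style masking (typically
            # 80% [MASK], 10% random token, 10% left unchanged) and returns the
            # masked positions together with their original ids as labels.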
            max_predictions_per_seq = self.masked_lm_prob * len(tokens)
            (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
                tokens, self.vocab_id_list, self.vocab_id_to_token_dict, self.masked_lm_prob,
                self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.mask_token_id,
                max_predictions_per_seq, self.np_rng,
                masking_style='bert')
            # Some checks.
            num_tokens = len(tokens)
            padding_length = self.max_seq_length - num_tokens
            assert padding_length >= 0
            assert len(tokentypes) == num_tokens
            assert len(masked_positions) == len(masked_labels)
            # Tokens and token types.
            filler = [self.tokenizer.pad_token_id] * padding_length
            tokens_np = np.array(tokens + filler, dtype=np.int64)
            tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
            # Padding mask.
            padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                                       dtype=np.int64)
            # Labels and loss mask: positions that were not masked stay at -100
            # so the MLM loss ignores them.
            labels = [-100] * self.max_seq_length
            for i in range(len(masked_positions)):
                assert masked_positions[i] < num_tokens
                labels[masked_positions[i]] = masked_labels[i]
            labels_np = np.array(labels, dtype=np.int64)
            model_inputs.append(
                {
                    'input_ids': tokens_np,
                    'attention_mask': padding_mask_np,
                    'token_type_ids': tokentypes_np,
                    'labels': labels_np,
                    'next_sentence_label': int(is_next_random)
                }
            )
        return default_collate(model_inputs)

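# Minimal usage sketch for the collator (the tokenizer path is a placeholder; any
# BERT-style Chinese tokenizer exposing [CLS]/[SEP]/[MASK]/[PAD] ids should work):
#   tokenizer = AutoTokenizer.from_pretrained('/path/to/erlangshen-megatron-bert')
#   collator = ErLangShenCollator(tokenizer=tokenizer, max_seq_length=128,
#                                 masked_lm_prob=0.15)
#   collator.setup()
#   batch = collator([{'text': '今天天气不错。我们一起出去散步吧。'}])
#   # expected: batch['input_ids'].shape == torch.Size([1, 128])
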
class ErLangShenBert(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('Erlangshen Bert')
        parser.add_argument('--masked_lm_prob', type=float, default=0.15)
        parser.add_argument('--max_seq_length', type=int, default=512)
        parser.add_argument('--sample_content_key', type=str, default='text')
        return parent_parser
    def __init__(self, args, tokenizer, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)
        config = MegatronBertConfig.from_pretrained(args.model_path)
        self.config = config
        self.tokenizer = tokenizer
        self.model = MegatronBertForPreTraining(config)

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        return configure_optimizers(self)

    def forward(self, **batch):
        return self.model(**batch)

    def detokenize(self, token_ids):
        toks = self.tokenizer.convert_ids_to_tokens(token_ids)
        return self.tokenizer.convert_tokens_to_string(toks)
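    # Token-level accuracy over the positions selected by the caller
    # (argmax of the MLM logits vs. the gold label ids).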
    def comput_metrix(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / labels.shape[0]
        return acc
    def training_step(self, batch, batch_idx):
        if self.trainer.global_rank == 0:
            global SHOW_DATA
            if not SHOW_DATA:
                print(self.config)
                print(self.model)
                SHOW_DATA = True
                print('source: {}'.format(batch['input_ids'][0]))
                print('target: {}'.format(batch['labels'][0]))
                print('source: {}'.format(self.detokenize(batch['input_ids'][0])))
                label_idx = batch['labels'][0] != -100
                print('target: {}'.format(self.detokenize(
                    batch['labels'][0][label_idx])))
        output = self(**batch)
        self.log('train_loss', output.loss, sync_dist=True)
        label_idx = batch['labels'] != -100
        acc = self.comput_metrix(
            output.prediction_logits[label_idx].view(-1, output.prediction_logits.size(-1)),
            batch['labels'][label_idx])
        self.log('train_acc', acc, sync_dist=True)
        return output.loss
    def validation_step(self, batch, batch_idx):
        output = self(**batch)
        self.log('val_loss', output.loss, sync_dist=True)
        return output.loss
    def on_load_checkpoint(self, checkpoint) -> None:
        # Compatibility with older Lightning versions, which reset the step
        # counter to 0 when resuming from a checkpoint; restore it here.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset

if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = add_module_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = ErLangShenBert.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args = args_parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    collate_fn = ErLangShenCollator(
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        masked_lm_prob=args.masked_lm_prob,
        content_key=args.sample_content_key,
    )
    collate_fn.setup()
    data_module = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collate_fn)
    print('data load complete')

    model = ErLangShenBert(args, tokenizer=tokenizer)
    print('model load complete')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = UniversalCheckpoint(args)
    # Compatibility: if the checkpoint path does not exist, drop the argument,
    # otherwise resuming would raise an error.
    if args.load_ckpt_path is not None and \
            not os.path.exists(args.load_ckpt_path):
        print('-------- warning: no checkpoint found, ignoring --load_ckpt_path --------')
        args.load_ckpt_path = None
    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])

    trainer.fit(model, data_module, ckpt_path=args.load_ckpt_path)
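# Example launch (a sketch only: the script name is illustrative, and --model_path
# plus the Trainer/data/checkpoint flags are assumed to be registered by
# add_module_args, UniversalDataModule, Trainer.add_argparse_args and
# UniversalCheckpoint, so they may differ across fengshen versions):
#   python pretrain_erlangshen.py \
#       --model_path /path/to/Erlangshen-MegatronBert \
#       --max_seq_length 512 \
#       --masked_lm_prob 0.15 \
#       --sample_content_key text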