|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
from functools import partial |
|
from typing import Dict, Optional, Sequence |
|
|
|
|
|
import torch |
|
import transformers |
|
|
|
from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup |
|
|
|
from modelling_RW import RWForCausalLM |
|
|
|
|
|
|
|
from torch.distributed import barrier |
|
import os |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
# Label id that PyTorch's CrossEntropyLoss skips by default.
# NOTE(review): not referenced in this file's visible code -- possibly used by
# modelling_RW or kept for parity with sibling scripts; confirm before removing.
IGNORE_INDEX = -100

# Fallback special tokens, registered only when the loaded tokenizer lacks them
# (see the tokenizer setup in train()).
DEFAULT_PAD_TOKEN = "[PAD]"

DEFAULT_EOS_TOKEN = "</s>"

DEFAULT_BOS_TOKEN = "<s>"

DEFAULT_UNK_TOKEN = "<unk>"
|
|
|
|
|
@dataclass
class ModelArguments:
    """Command-line arguments selecting the base checkpoint to fine-tune."""

    # Hugging Face hub id or local path of the pretrained model.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
|
|
|
|
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF training arguments extended with landmark-attention options."""

    # Cache directory for downloaded models and datasets.
    cache_dir: Optional[str] = None

    # Optimizer name understood by transformers.
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=128,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # NOTE(review): use_flash is never consulted in this file; train() calls
    # model.enable_flash() unconditionally -- confirm that is intended.
    use_flash: bool = False
    # Passed to RWForCausalLM as mem_freq; presumably the spacing (in tokens)
    # between inserted landmark tokens -- TODO confirm against modelling_RW.
    mem_freq: int = 63
|
|
|
|
|
|
|
class TrainerCosine(Trainer):
    """Trainer that builds its cosine LR schedule with ``num_cycles=0.4``
    instead of the transformers default."""

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """Set up the learning-rate scheduler.

        The trainer's optimizer must already exist, or be supplied explicitly
        via ``optimizer``.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        # Any non-cosine schedule keeps the stock transformers behaviour.
        if self.args.lr_scheduler_type != "cosine":
            return super().create_scheduler(num_training_steps, optimizer)

        # Build at most once; reuse an already-created scheduler.
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        chosen_optimizer = optimizer if optimizer is not None else self.optimizer
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=chosen_optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
            num_cycles=0.4,
        )
        return self.lr_scheduler
|
|
|
|
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: "transformers.PreTrainedTokenizer",
    model: "transformers.PreTrainedModel",
):
    """Resize tokenizer and embedding.

    Registers ``special_tokens_dict`` on the tokenizer, grows the model's
    embedding matrices to the new vocabulary size, and initializes each new
    row with the mean of the pre-existing embeddings (a better starting
    point than the default random init).

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        # Mean over the original rows only (exclude the freshly added ones).
        input_embeddings[-num_new_tokens:] = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        # Fix: get_output_embeddings() can return None (tied or absent LM
        # head); the original code would crash on .weight in that case.
        output_embedding_layer = model.get_output_embeddings()
        if output_embedding_layer is not None:
            output_embeddings = output_embedding_layer.weight.data
            output_embeddings[-num_new_tokens:] = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
|
|
def tokenize_fn(tokenizer, example):
    """Join a batch of documents with the EOS token, tokenize the result,
    and chunk the flat id stream into rows of ``tokenizer.model_max_length``."""
    chunk_len = tokenizer.model_max_length
    joined_text = tokenizer.eos_token.join(example["text"])
    encoded = tokenizer(
        joined_text,
        truncation=False,
        return_tensors="pt",
        pad_to_multiple_of=chunk_len,
        padding=True,
    )
    # Padding to a multiple of chunk_len guarantees the reshape is exact.
    return {"input_ids": encoded["input_ids"].view(-1, chunk_len)}
|
|
|
def train():

    """Entry point: parse CLI args, build tokenizer/model, preprocess the
    RedPajama sample dataset, and fine-tune with the cosine-schedule trainer."""

    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))

    model_args, training_args = parser.parse_args_into_dataclasses()



    # Reserve room in the context window: subtract one slot per mem_freq
    # tokens, then round down to a multiple of mem_freq.
    # NOTE(review): presumably this accounts for one inserted landmark token
    # per mem_freq tokens -- confirm against RWForCausalLM.
    model_max_length = training_args.model_max_length - (training_args.model_max_length // training_args.mem_freq)

    model_max_length = model_max_length // training_args.mem_freq * training_args.mem_freq



    tokenizer = transformers.AutoTokenizer.from_pretrained(

        model_args.model_name_or_path,

        cache_dir=training_args.cache_dir,

        model_max_length=model_max_length,

        padding_side="right",

        use_fast=False,

    )

    # Register fallback special tokens only for those the checkpoint lacks.
    special_tokens_dict = dict()

    if tokenizer.pad_token is None:

        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN

    if tokenizer.eos_token is None:

        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN

    if tokenizer.bos_token is None:

        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN

    if tokenizer.unk_token is None:

        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    # The landmark marker is always added as a new special token.
    mem_token = "<landmark>"

    special_tokens_dict["additional_special_tokens"] = [mem_token]



    model = RWForCausalLM.from_pretrained(

        model_args.model_name_or_path,

        cache_dir=training_args.cache_dir,

        mem_freq=training_args.mem_freq,

        torch_dtype=torch.bfloat16,

    )



    # Grow tokenizer and embedding matrices together; new embedding rows are
    # mean-initialized from the existing vocabulary.
    smart_tokenizer_and_embedding_resize(

        special_tokens_dict=special_tokens_dict,

        tokenizer=tokenizer,

        model=model,

    )



    # Tell the model which vocabulary id marks a landmark position.
    mem_id = tokenizer.convert_tokens_to_ids(mem_token)

    model.set_mem_id(mem_id)

    print(f"Landmark token: {mem_token}: {mem_id}")



    # RANK is set by torch.distributed launchers; -1 means single-process run.
    rank = int(os.environ.get('RANK', -1))

    # Non-zero ranks wait here so only rank 0 downloads/maps the dataset;
    # rank 0 releases them at the matching barrier below.
    if rank > 0:

        barrier()



    dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir, split='train')



    # Pack EOS-joined documents into fixed-length rows of input_ids.
    dataset = dataset.map(partial(tokenize_fn, tokenizer), batched=True, num_proc=32, remove_columns=["text", "meta"])



    model.enable_landmark_insertion()

    # NOTE(review): training_args.use_flash is never consulted here --
    # flash attention is enabled unconditionally. Confirm this is intended.
    model.enable_flash()



    if rank == 0:

        barrier()

    print(dataset)



    # Standard causal-LM collator (mlm=False).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)



    trainer = TrainerCosine(

        model=model, tokenizer=tokenizer, args=training_args,

        train_dataset=dataset,

        eval_dataset=None,

        data_collator=data_collator)

    trainer.train()

    trainer.save_state()

    trainer.save_model(output_dir=training_args.output_dir)
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":

    train()
|
|