# Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#import copy
#import logging
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, Optional, Sequence
import torch
import transformers
#from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup
from modelling_RW import RWForCausalLM
#from transformers import AutoModelForCausalLM
from torch.distributed import barrier
import os
from datasets import load_dataset
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    # optim: str = field(default="adamw_hf")
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=128,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    use_flash: bool = field(default=False)
    mem_freq: int = field(default=63)  # one landmark token per mem_freq regular tokens
    # report_to: str = "none"  # disable logging
class TrainerCosine(Trainer):
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method
        is called or passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.args.lr_scheduler_type != "cosine":
            return super().create_scheduler(num_training_steps, optimizer)
        if self.lr_scheduler is None:
            self.lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                num_cycles=0.4,  # decay to ~10% of the initial lr instead of 0
            )
        return self.lr_scheduler
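# Note on num_cycles=0.4 above: after warmup, get_cosine_schedule_with_warmup
# scales the lr by 0.5 * (1 + cos(2 * pi * num_cycles * progress)), so at
# progress = 1 the factor is 0.5 * (1 + cos(0.8 * pi)) ~= 0.095, i.e. the lr
# ends near 10% of its peak rather than decaying to 0 (default num_cycles=0.5).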
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the rows for the new tokens with the mean of the existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
def tokenize_fn(tokenizer, example):
    """Join the batch's texts with the eos token and chunk into fixed-length rows."""
    context_length = tokenizer.model_max_length
    outputs = tokenizer(
        tokenizer.eos_token.join(example["text"]),
        truncation=False,
        return_tensors="pt",
        pad_to_multiple_of=context_length,
        padding=True,
    )
    # Pad to a multiple of context_length, then reshape into rows of exactly context_length ids.
    return {"input_ids": outputs["input_ids"].view(-1, context_length)}
def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Leave room for the landmark tokens that will be inserted every mem_freq
    # tokens, then round down to a multiple of mem_freq.
    model_max_length = training_args.model_max_length - (training_args.model_max_length // training_args.mem_freq)
    model_max_length = model_max_length // training_args.mem_freq * training_args.mem_freq
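    # Worked example with the defaults (model_max_length=128, mem_freq=63):
    # 128 - 128 // 63 = 126 and 126 // 63 * 63 = 126, so the data pipeline
    # emits 126-token rows; one landmark per 63 tokens brings a row back up
    # to the full 128-token context.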
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=model_max_length,
        padding_side="right",
        use_fast=False,
    )

    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    mem_token = "<landmark>"
    special_tokens_dict["additional_special_tokens"] = [mem_token]
    model = RWForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        mem_freq=training_args.mem_freq,
        torch_dtype=torch.bfloat16,
    )
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_args.model_name_or_path,
    #     cache_dir=training_args.cache_dir,
    #     torch_dtype=torch.bfloat16,
    #     trust_remote_code=True,
    # )

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    mem_id = tokenizer.convert_tokens_to_ids(mem_token)
    model.set_mem_id(mem_id)
    print(f"Landmark token: {mem_token}: {mem_id}")
    rank = int(os.environ.get('RANK', -1))
    # Let rank 0 download and tokenize the dataset first; the other ranks wait
    # at the barrier and then read it from the cache.
    if rank > 0:
        barrier()

    # dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir, split='train[:100]')
    dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir, split='train')
    dataset = dataset.map(partial(tokenize_fn, tokenizer), batched=True, num_proc=32, remove_columns=["text", "meta"])

    # NOTE: landmark insertion and flash attention are currently enabled
    # unconditionally; the use_flash flag is not consulted (see the
    # commented-out branch below).
    model.enable_landmark_insertion()
    model.enable_flash()
    # if training_args.use_flash:
    #     model.enable_landmark_insertion()
    #     model.enable_flash()
    # else:
    #     dataset = dataset.map(
    #         partial(
    #             add_mem_tokens,
    #             mem_freq=training_args.mem_freq,
    #             mem_id=mem_id
    #         ), batched=False, num_proc=32)

    if rank == 0:
        barrier()

    print(dataset)
    # Causal LM objective: the collator copies input_ids to labels (no masking).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = TrainerCosine(
        model=model, tokenizer=tokenizer, args=training_args,
        train_dataset=dataset,  # dataset["train"]
        eval_dataset=None,
        data_collator=data_collator)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":
    train()
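# Example launch (a sketch only; the model id, paths, and cluster size below
# are assumptions, not part of this script):
#
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path tiiuae/falcon-7b \
#       --output_dir ./falcon-landmark \
#       --model_max_length 2048 \
#       --mem_freq 63 \
#       --bf16 True \
#       --lr_scheduler_type cosine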