# Copyright 2023 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import click
import numpy as np
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)
from .consts import (
    DEFAULT_INPUT_MODEL,
    DEFAULT_SEED,
    PROMPT_WITH_INPUT_FORMAT,
    PROMPT_NO_INPUT_FORMAT,
    END_KEY,
    INSTRUCTION_KEY,
    RESPONSE_KEY_NL,
)

logger = logging.getLogger(__name__)

ROOT_PATH = Path(__file__).parent.parent
DATABRICKS_DOLLY_15K_PATH = ROOT_PATH / "data" / "databricks-dolly-15k.jsonl"


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
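    """Data collator that trains on the completion only.

    On top of the standard language-modeling collation, labels for every token up to and including
    the response key are set to -100, so the loss is computed only on the response portion.
    """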
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline.  We encode this and then try to find it in the
        # sequence of tokens.  This should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make pytorch loss function ignore all tokens up through the end of the response key
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch


def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int) -> dict:
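    """Tokenizes the formatted prompt text for a batch of records, truncating to max_length."""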
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )


def load_training_dataset() -> Dataset:
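    """Loads databricks-dolly-15k from the local JSONL file and formats each record into a training prompt."""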
logger.info(f"Loading dataset from {DATABRICKS_DOLLY_15K_PATH}") | |
dataset = load_dataset("json", data_files=str(DATABRICKS_DOLLY_15K_PATH))["train"] | |
logger.info("Found %d rows", dataset.num_rows) | |
def _add_text(rec): | |
instruction = rec["instruction"] | |
response = rec["response"] | |
context = rec.get("context") | |
if not instruction: | |
raise ValueError(f"Expected an instruction in: {rec}") | |
if not response: | |
raise ValueError(f"Expected a response in: {rec}") | |
# For some instructions there is an input that goes along with the instruction, providing context for the | |
# instruction. For example, the input might be a passage from Wikipedia and the instruction says to extract | |
# some piece of information from it. The response is that information to extract. In other cases there is | |
# no input. For example, the instruction might be open QA such as asking what year some historic figure was | |
# born. | |
if context: | |
rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context) | |
else: | |
rec["text"] = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response) | |
return rec | |
dataset = dataset.map(_add_text) | |
return dataset | |


def load_tokenizer(pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL) -> PreTrainedTokenizer:
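    """Loads the tokenizer, sets the pad token, and registers the instruction, response, and end markers as special tokens."""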
logger.info(f"Loading tokenizer for {pretrained_model_name_or_path}") | |
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) | |
tokenizer.pad_token = tokenizer.eos_token | |
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]}) | |
return tokenizer | |


def load_model(
    pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
) -> AutoModelForCausalLM:
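    """Loads the causal LM, disabling the KV cache when gradient checkpointing is enabled, since the cache is not used during checkpointed training."""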
logger.info(f"Loading model for {pretrained_model_name_or_path}") | |
model = AutoModelForCausalLM.from_pretrained( | |
pretrained_model_name_or_path, trust_remote_code=True, use_cache=False if gradient_checkpointing else True | |
) | |
return model | |


def get_model_tokenizer(
    pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
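    """Loads the model and tokenizer together and resizes the embedding matrix to account for the added special tokens."""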
    tokenizer = load_tokenizer(pretrained_model_name_or_path)
    model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer


def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed: int = DEFAULT_SEED) -> Dataset:
"""Loads the training dataset and tokenizes it so it is ready for training. | |
Args: | |
tokenizer (AutoTokenizer): Tokenizer tied to the model. | |
max_length (int): Maximum number of tokens to emit from tokenizer. | |
Returns: | |
Dataset: HuggingFace dataset | |
""" | |
    dataset = load_training_dataset()

    logger.info("Preprocessing dataset")
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "context", "response", "text", "category"],
    )

    # Make sure we don't have any truncated records, as this would mean the end keyword is missing.
    logger.info("Processed dataset has %d rows", dataset.num_rows)
    dataset = dataset.filter(lambda rec: len(rec["input_ids"]) < max_length)
    logger.info("Processed dataset has %d rows after filtering for truncated records", dataset.num_rows)

    logger.info("Shuffling dataset")
    dataset = dataset.shuffle(seed=seed)

    logger.info("Done preprocessing")

    return dataset


def train(
    *,
    input_model: str,
    local_output_dir: str,
    dbfs_output_dir: str,
    epochs: int,
    per_device_train_batch_size: int,
    per_device_eval_batch_size: int,
    lr: float,
    seed: int,
    deepspeed: str,
    gradient_checkpointing: bool,
    local_rank: str,
    bf16: bool,
    logging_steps: int,
    save_steps: int,
    eval_steps: int,
    test_size: Union[float, int],
    save_total_limit: int,
    warmup_steps: int,
):
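    """Fine-tunes the input model on databricks-dolly-15k and saves the result locally and, optionally, to DBFS."""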
    set_seed(seed)

    model, tokenizer = get_model_tokenizer(
        pretrained_model_name_or_path=input_model, gradient_checkpointing=gradient_checkpointing
    )
    # Use the same max length that the model supports.  Fall back to 1024 if the setting can't be found.
    # The configuration for the length can be stored under different names depending on the model.  Here we attempt
    # a few possible names we've encountered.
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(conf, length_setting, None)
        if max_length:
            logger.info(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        logger.info(f"Using default max length: {max_length}")
    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)

    split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)

    logger.info("Train data size: %d", split_dataset["train"].num_rows)
    logger.info("Test data size: %d", split_dataset["test"].num_rows)

    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
    )
    if not dbfs_output_dir:
        logger.warning("Will NOT save to DBFS")
    training_args = TrainingArguments(
        output_dir=local_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        fp16=False,
        bf16=bf16,
        learning_rate=lr,
        num_train_epochs=epochs,
        deepspeed=deepspeed,
        gradient_checkpointing=gradient_checkpointing,
        logging_dir=f"{local_output_dir}/runs",
        logging_strategy="steps",
        logging_steps=logging_steps,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        load_best_model_at_end=False,
        report_to="tensorboard",
        disable_tqdm=True,
        remove_unused_columns=False,
        local_rank=local_rank,
        warmup_steps=warmup_steps,
    )
logger.info("Instantiating Trainer") | |
trainer = Trainer( | |
model=model, | |
tokenizer=tokenizer, | |
args=training_args, | |
train_dataset=split_dataset["train"], | |
eval_dataset=split_dataset["test"], | |
data_collator=data_collator, | |
) | |
logger.info("Training") | |
trainer.train() | |
logger.info(f"Saving Model to {local_output_dir}") | |
trainer.save_model(output_dir=local_output_dir) | |
if dbfs_output_dir: | |
logger.info(f"Saving Model to {dbfs_output_dir}") | |
trainer.save_model(output_dir=dbfs_output_dir) | |
logger.info("Done.") | |
def main(**kwargs):
    train(**kwargs)
if __name__ == "__main__": | |
logging.basicConfig( | |
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" | |
) | |
try: | |
main() | |
except Exception: | |
logger.exception("main failed") | |
raise | |