import dataclasses
import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from accelerate.state import PartialState
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollator,
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction

from ..extras.dataset_formatting import get_formatting_func_from_dataset
from ..import_utils import is_peft_available
from .utils import (
    ConstantLengthDataset,
    DataCollatorForCompletionOnlyLM,
    neftune_post_forward_hook,
    peft_module_casting_to_bf16,
    trl_sanitze_kwargs_for_tagging,
)


if is_peft_available():
    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


class SFTTrainer(Trainer):
    r"""
    Class definition of the Supervised Finetuning Trainer (SFT Trainer).

    This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods.
    The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object.

    Args:
        model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]):
            The model to train, which can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name
            to load from cache or download. The model can also be converted to a `PeftModel` if a `PeftConfig` object
            is passed to the `peft_config` argument.
        args (Optional[`transformers.TrainingArguments`]):
            The arguments to tweak for training. Please refer to the official documentation of
            `transformers.TrainingArguments` for more information.
        data_collator (Optional[`transformers.DataCollator`]):
            The data collator to use for training.
        train_dataset (Optional[`datasets.Dataset`]):
            The dataset to use for training. We recommend using `trl.trainer.ConstantLengthDataset` to create the dataset.
        eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]):
            The dataset to use for evaluation. We recommend using `trl.trainer.ConstantLengthDataset` to create the dataset.
        tokenizer (Optional[`transformers.PreTrainedTokenizer`]):
            The tokenizer to use for training. If not specified, the tokenizer associated with the model will be used.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional*, defaults to `None`):
            The function used to compute metrics during evaluation. It should return a dictionary mapping metric names
            to metric values. If not specified, only the loss will be computed during evaluation.
        callbacks (`List[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`Optional[PeftConfig]`):
            The PeftConfig object to use to initialize the PeftModel.
        dataset_text_field (`Optional[str]`):
            The name of the text field of the dataset. If passed, the trainer will automatically create a
            `ConstantLengthDataset` based on the `dataset_text_field` argument.
        formatting_func (`Optional[Callable]`):
            The formatting function to be used for creating the `ConstantLengthDataset`.
        max_seq_length (`Optional[int]`):
            The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the
            dataset. Defaults to `min(tokenizer.model_max_length, 1024)` if not provided.
        infinite (`Optional[bool]`):
            Whether to use an infinite dataset or not. Deprecated; use `TrainingArguments.max_steps` or
            `TrainingArguments.num_train_epochs` instead to control training length.
        num_of_sequences (`Optional[int]`):
            The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
        chars_per_token (`Optional[float]`):
            The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check
            how this is computed in the stack-llama example:
            https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
        packing (`Optional[bool]`):
            Used only when `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack
            the sequences of the dataset.
        dataset_num_proc (`Optional[int]`):
            The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to `None`.
        dataset_batch_size (`int`):
            The number of examples to tokenize per batch. If `batch_size <= 0` or `batch_size is None`, the full
            dataset is tokenized as a single batch. Defaults to `1000`.
        neftune_noise_alpha (`Optional[float]`):
            If not `None`, this will activate NEFTune noise embeddings, which can significantly improve model
            performance for instruction fine-tuning. Check out the original paper: https://arxiv.org/abs/2310.05914
            and the original code: https://github.com/neelsjain/NEFTune
        model_init_kwargs (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when instantiating the model from a string.
        dataset_kwargs (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when creating packed or non-packed datasets.
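
    Example (a minimal usage sketch; the dataset name, model id, and text column below are illustrative placeholders):

    ```python
    from datasets import load_dataset

    from trl import SFTTrainer

    # Any dataset exposing a plain-text column works; "imdb" and its "text" column are used only for illustration.
    dataset = load_dataset("imdb", split="train")

    trainer = SFTTrainer(
        "facebook/opt-350m",  # illustrative model id; any causal LM checkpoint can be passed as a string
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=512,
    )
    trainer.train()
    ```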
    """

    _tag_names = ["trl", "sft"]

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: int = 1000,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
    ):
        if model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed `model_init_kwargs` to the SFTTrainer, but your model is already instantiated.")

        if infinite is not None:
            warnings.warn(
                "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use"
                " `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length."
            )

        if isinstance(model, str):
            warnings.warn(
                "You passed a model_id to the SFTTrainer. This will automatically create an"
                " `AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
            raise ValueError(
                "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
            )
        if is_peft_available() and peft_config is not None:
            if not isinstance(peft_config, PeftConfig):
                raise ValueError(
                    "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer,"
                    f" and you passed a {type(peft_config)}."
                )

            if not isinstance(model, PeftModel):
                # `gradient_checkpointing_kwargs` can only be forwarded when both `TrainingArguments` and the
                # installed peft's `prepare_model_for_kbit_training` support it.
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )
                gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                    prepare_model_kwargs = {"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)}

                    if _support_gc_kwargs:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                    if args is not None:
                        args = dataclasses.replace(args, gradient_checkpointing=False)
                elif getattr(args, "gradient_checkpointing", False) and (
                    "use_reentrant" not in gradient_checkpointing_kwargs
                    or gradient_checkpointing_kwargs["use_reentrant"]
                ):
                    # For backward compatibility with older versions of transformers
                    if hasattr(model, "enable_input_require_grads"):
                        model.enable_input_require_grads()
                    else:

                        def make_inputs_require_grad(module, input, output):
                            output.requires_grad_(True)

                        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

                model = get_peft_model(model, peft_config)
                if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                    peft_module_casting_to_bf16(model)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
            if getattr(tokenizer, "pad_token", None) is None:
                tokenizer.pad_token = tokenizer.eos_token

        if max_seq_length is None:
            # Fall back to a conservative default; some tokenizers report an extremely large `model_max_length`.
            max_seq_length = min(tokenizer.model_max_length, 1024)

            warnings.warn(
                f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}"
            )

        self.dataset_num_proc = dataset_num_proc
        self.dataset_batch_size = dataset_batch_size

        self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")

        if neftune_noise_alpha is not None and self._trainer_supports_neftune:
            args.neftune_noise_alpha = neftune_noise_alpha
            warnings.warn(
                "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`."
            )
        elif not self._trainer_supports_neftune:
            self.neftune_noise_alpha = neftune_noise_alpha

        if formatting_func is None and dataset_text_field is None:
            # Check whether the dataset has a recognized chat/instruction format; stays None if it does not.
            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

        if not packing:
            if dataset_text_field is None and formatting_func is None:
                raise ValueError(
                    "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
                )

            if data_collator is None:
                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

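        # Pre-process the datasets only once per node: the local main process runs first and the
        # remaining processes reuse its cached result.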
        with PartialState().local_main_process_first():
            if dataset_kwargs is None:
                dataset_kwargs = {}
            if train_dataset is not None:
                train_dataset = self._prepare_dataset(
                    train_dataset,
                    tokenizer,
                    packing,
                    dataset_text_field,
                    max_seq_length,
                    formatting_func,
                    num_of_sequences,
                    chars_per_token,
                    remove_unused_columns=args.remove_unused_columns if args is not None else True,
                    **dataset_kwargs,
                )
            if eval_dataset is not None:
                _multiple = isinstance(eval_dataset, dict)
                _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
                for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
                    _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
                        _eval_dataset,
                        tokenizer,
                        packing,
                        dataset_text_field,
                        max_seq_length,
                        formatting_func,
                        num_of_sequences,
                        chars_per_token,
                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
                        **dataset_kwargs,
                    )
                if not _multiple:
                    eval_dataset = _eval_datasets["singleton"]

        if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
            warnings.warn(
                "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
                "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
            )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        if self.args.max_steps > 0 and packing:
            warnings.warn(
                "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached."
            )
            self.train_dataset.infinite = True
        elif self.args.max_steps == -1 and packing:
            self.train_dataset.infinite = False

    @wraps(Trainer.train)
    def train(self, *args, **kwargs):
        # Activate NEFTune right before training if the underlying `Trainer` does not handle it itself.
        if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
            self.model = self._trl_activate_neftune(self.model)

        output = super().train(*args, **kwargs)

        # After training, restore the original forward pass of the embedding layer by removing the NEFTune hook.
        if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
            unwrapped_model = unwrap_model(self.model)
            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
                embeddings = unwrapped_model.base_model.model.get_input_embeddings()
            else:
                embeddings = unwrapped_model.get_input_embeddings()

            self.neftune_hook_handle.remove()
            del embeddings.neftune_noise_alpha

        return output

    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)

        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

    def _prepare_dataset(
        self,
        dataset,
        tokenizer,
        packing,
        dataset_text_field,
        max_seq_length,
        formatting_func,
        num_of_sequences,
        chars_per_token,
        remove_unused_columns=True,
        append_concat_token=True,
        add_special_tokens=True,
    ):
        if dataset is None:
            raise ValueError("The dataset should not be None")

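        # If the dataset is already a torch-style dataset (or a ConstantLengthDataset), leave it untouched.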
        if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
            return dataset

        if not packing:
            return self._prepare_non_packed_dataloader(
                tokenizer,
                dataset,
                dataset_text_field,
                max_seq_length,
                formatting_func,
                add_special_tokens,
                remove_unused_columns,
            )
        else:
            return self._prepare_packed_dataloader(
                tokenizer,
                dataset,
                dataset_text_field,
                max_seq_length,
                num_of_sequences,
                chars_per_token,
                formatting_func,
                append_concat_token,
                add_special_tokens,
            )

    def _prepare_non_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )

            if use_formatting_func and not self._dataset_sanity_checked:
                if not isinstance(formatting_func(element), list):
                    raise ValueError(
                        "The `formatting_func` should return a list of processed strings; returning a single string can lead to silent bugs."
                    )
                else:
                    self._dataset_sanity_checked = True

            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

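        # Columns expected by the default language-modeling collator; anything else counts as an extra column.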
        signature_columns = ["input_ids", "labels", "attention_mask"]

        extra_columns = list(set(dataset.column_names) - set(signature_columns))

        if not remove_unused_columns and len(extra_columns) > 0:
            warnings.warn(
                "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield errors. If you want to "
                f"inspect the other dataset columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` (if you used the default collator) and create your own data collator in order to inspect the unused dataset columns."
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names if remove_unused_columns else None,
            num_proc=self.dataset_num_proc,
            batch_size=self.dataset_batch_size,
        )

        return tokenized_dataset

    def _prepare_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        num_of_sequences,
        chars_per_token,
        formatting_func=None,
        append_concat_token=True,
        add_special_tokens=True,
    ):
        if dataset_text_field is not None or formatting_func is not None:
            if tokenizer is None:
                raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.")

            constant_length_iterator = ConstantLengthDataset(
                tokenizer,
                dataset,
                dataset_text_field=dataset_text_field,
                formatting_func=formatting_func,
                seq_length=max_seq_length,
                infinite=False,
                num_of_sequences=num_of_sequences,
                chars_per_token=chars_per_token,
                eos_token_id=tokenizer.eos_token_id,
                append_concat_token=append_concat_token,
                add_special_tokens=add_special_tokens,
            )

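            # Wrap the iterator in a generator so that `Dataset.from_generator` can materialize the packed
            # examples into a regular `datasets.Dataset`.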
            def data_generator(constant_length_iterator):
                for i in constant_length_iterator:
                    yield i

            try:
                packed_dataset = Dataset.from_generator(
                    data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator}
                )
            except (DatasetGenerationError, SchemaInferenceError):
                raise ValueError(
                    "Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence."
                )
            return packed_dataset
        else:
            raise ValueError(
                "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
            )

    def _trl_activate_neftune(self, model):
        r"""
        Activates NEFTune as presented in the original code (https://github.com/neelsjain/NEFTune) and
        paper (https://arxiv.org/abs/2310.05914).
        Since the transformers Trainer already has an `_activate_neftune` method, this one is renamed to avoid conflicts.
        """
        unwrapped_model = unwrap_model(model)
        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
        else:
            embeddings = unwrapped_model.get_input_embeddings()

        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
        hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
        self.neftune_hook_handle = hook_handle
        return model