Spaces:
Paused
Paused
import math | |
import os | |
import shutil | |
import time | |
import warnings | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union | |
import evaluate | |
import torch | |
from datasets import Dataset, DatasetDict | |
from sentence_transformers import InputExample, SentenceTransformer, losses | |
from sentence_transformers.datasets import SentenceLabelDataset | |
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction | |
from sentence_transformers.util import batch_to_device | |
from sklearn.preprocessing import LabelEncoder | |
from torch import nn | |
from torch.cuda.amp import autocast | |
from torch.utils.data import DataLoader | |
from tqdm.autonotebook import tqdm | |
from transformers.integrations import WandbCallback, get_reporting_integration_callbacks | |
from transformers.trainer_callback import ( | |
CallbackHandler, | |
DefaultFlowCallback, | |
IntervalStrategy, | |
PrinterCallback, | |
ProgressCallback, | |
TrainerCallback, | |
TrainerControl, | |
TrainerState, | |
) | |
from transformers.trainer_utils import ( | |
HPSearchBackend, | |
default_compute_objective, | |
number_of_arguments, | |
set_seed, | |
speed_metrics, | |
) | |
from transformers.utils.import_utils import is_in_notebook | |
from setfit.model_card import ModelCardCallback | |
from . import logging | |
from .integrations import default_hp_search_backend, is_optuna_available, run_hp_search_optuna | |
from .losses import SupConLoss | |
from .sampler import ContrastiveDataset | |
from .training_args import TrainingArguments | |
from .utils import BestRun, default_hp_space_optuna | |
# For Python 3.7 compatibility | |
try: | |
from typing import Literal | |
except ImportError: | |
from typing_extensions import Literal | |
if TYPE_CHECKING: | |
import optuna | |
from .modeling import SetFitModel | |
# Configure the package logger verbosity at import time and create this module's logger.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Callback classes every `Trainer` starts with; reporting integrations are appended in `Trainer.__init__`.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Progress callback used when `args.show_progress_bar` is set; replaced below when running in a notebook.
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
    # Imported lazily so the notebook-only module is loaded just when it is needed.
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
class ColumnMappingMixin:
    """Mixin that validates and applies a dataset column mapping.

    The host class must expose a `column_mapping` attribute that is either `None` or a
    dictionary from dataset column names to the names listed in `_REQUIRED_COLUMNS`.
    """

    _REQUIRED_COLUMNS = {"text", "label"}

    def _validate_column_mapping(self, dataset: "Dataset") -> None:
        """
        Check that `dataset` provides the required columns, directly or via `self.column_mapping`.

        Args:
            dataset (`Dataset`): The dataset whose `column_names` are inspected.

        Raises:
            `ValueError`: If no mapping is set and the required columns are absent (with a dedicated
                message when a `DatasetDict` was passed), if the mapping leaves a required column
                uncovered, or if the mapping refers to columns that the dataset does not have.
        """
        available = set(dataset.column_names)
        mapping = self.column_mapping

        if mapping is None and not self._REQUIRED_COLUMNS.issubset(available):
            got_dataset_dict = isinstance(dataset, DatasetDict)
            # Issue #226: load_dataset will automatically assign points to "train" if no split is specified
            if got_dataset_dict and available == {"train"}:
                raise ValueError(
                    "SetFit expected a Dataset, but it got a DatasetDict with the split ['train']. "
                    "Did you mean to select the training split with dataset['train']?"
                )
            if got_dataset_dict:
                raise ValueError(
                    f"SetFit expected a Dataset, but it got a DatasetDict with the splits {sorted(available)}. "
                    "Did you mean to select one of these splits from the dataset?"
                )
            raise ValueError(
                f"SetFit expected the dataset to have the columns {sorted(self._REQUIRED_COLUMNS)}, "
                f"but only the columns {sorted(available)} were found. "
                "Either make sure these columns are present, or specify which columns to use with column_mapping in Trainer."
            )

        if mapping is None:
            return

        uncovered = set(self._REQUIRED_COLUMNS)
        # Required names produced by the mapping itself are covered
        uncovered.difference_update(set(mapping.values()))
        # ... and so are dataset columns that the mapping leaves under their own name
        uncovered.difference_update(set(dataset.column_names) - set(mapping.keys()))
        if uncovered:
            raise ValueError(
                f"The following columns are missing from the column mapping: {uncovered}. "
                "Please provide a mapping for all required columns."
            )

        if not set(mapping.keys()) <= available:
            raise ValueError(
                f"The column mapping expected the columns {sorted(mapping.keys())} in the dataset, "
                f"but the dataset had the columns {sorted(available)}."
            )

    def _apply_column_mapping(self, dataset: "Dataset", column_mapping: Dict[str, str]) -> "Dataset":
        """
        Rename the columns of `dataset` following `column_mapping`.

        Columns that are neither mapped nor required get a `"feat_"` prefix. The format settings
        read from the renamed dataset are re-applied over its full set of columns.

        Args:
            dataset (`Dataset`): The dataset to rename.
            column_mapping (`Dict[str, str]`): Mapping from current column names to new column names.

        Returns:
            `Dataset`: The renamed and re-formatted dataset.
        """
        renames = dict(column_mapping)
        for column in dataset.column_names:
            if column in column_mapping or column in self._REQUIRED_COLUMNS:
                continue
            renames[column] = f"feat_{column}"

        renamed = dataset.rename_columns(renames)
        previous_format = renamed.format
        return renamed.with_format(
            type=previous_format["type"],
            columns=renamed.column_names,
            output_all_columns=previous_format["output_all_columns"],
            **previous_format["format_kwargs"],
        )
class Trainer(ColumnMappingMixin): | |
"""Trainer to train a SetFit model. | |
Args: | |
model (`SetFitModel`, *optional*): | |
The model to train. If not provided, a `model_init` must be passed. | |
args (`TrainingArguments`, *optional*): | |
The training arguments to use. | |
train_dataset (`Dataset`): | |
The training dataset. | |
eval_dataset (`Dataset`, *optional*): | |
The evaluation dataset. | |
model_init (`Callable[[], SetFitModel]`, *optional*): | |
A function that instantiates the model to be used. If provided, each call to | |
[`Trainer.train`] will start from a new instance of the model as given by this | |
function when a `trial` is passed. | |
metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): | |
The metric to use for evaluation. If a string is provided, we treat it as the metric | |
name and load it with default settings. If a callable is provided, it must take two arguments | |
(`y_pred`, `y_test`) and return a dictionary with metric keys to values. | |
metric_kwargs (`Dict[str, Any]`, *optional*): | |
Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1". | |
For example useful for providing an averaging strategy for computing f1 in a multi-label setting. | |
callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*): | |
A list of callbacks to customize the training loop. Will add those to the list of default callbacks | |
detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback). | |
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. | |
column_mapping (`Dict[str, str]`, *optional*): | |
A mapping from the column names in the dataset to the column names expected by the model. | |
The expected format is a dictionary with the following format: | |
`{"text_column_name": "text", "label_column_name": "label"}`.
""" | |
def __init__(
    self,
    model: Optional["SetFitModel"] = None,
    args: Optional[TrainingArguments] = None,
    train_dataset: Optional["Dataset"] = None,
    eval_dataset: Optional["Dataset"] = None,
    model_init: Optional[Callable[[], "SetFitModel"]] = None,
    metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
    metric_kwargs: Optional[Dict[str, Any]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    column_mapping: Optional[Dict[str, str]] = None,
) -> None:
    """Validate the inputs, resolve the model and wire up the callback machinery.

    Raises:
        `ValueError`: If `args` is given but is not a `TrainingArguments` instance, or if a dataset
            fails the column mapping validation.
        `RuntimeError`: If neither or both of `model` and `model_init` are provided.
    """
    if args is not None and not isinstance(args, TrainingArguments):
        raise ValueError("`args` must be a `TrainingArguments` instance imported from `setfit`.")
    self.args = args or TrainingArguments()
    self.column_mapping = column_mapping

    if train_dataset:
        self._validate_column_mapping(train_dataset)
        if self.column_mapping is not None:
            logger.info("Applying column mapping to the training dataset")
            train_dataset = self._apply_column_mapping(train_dataset, self.column_mapping)
    self.train_dataset = train_dataset

    if eval_dataset:
        self._validate_column_mapping(eval_dataset)
        if self.column_mapping is not None:
            logger.info("Applying column mapping to the evaluation dataset")
            eval_dataset = self._apply_column_mapping(eval_dataset, self.column_mapping)
    self.eval_dataset = eval_dataset

    self.model_init = model_init
    self.metric = metric
    self.metric_kwargs = metric_kwargs
    self.logs_mapper = {}

    # Seed must be set before instantiating the model when using model_init.
    set_seed(12)

    if model is None:
        if model_init is not None:
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument.")
    else:
        if model_init is not None:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument, but not both.")

    self.model = model
    self.hp_search_backend = None

    # Setup the callbacks
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    if WandbCallback in callbacks:
        # Set the W&B project via environment variables if it's not already set
        os.environ.setdefault("WANDB_PROJECT", "setfit")
        # TODO: Observe optimizer and scheduler by wrapping SentenceTransformer._get_scheduler
    self.callback_handler = CallbackHandler(callbacks, self.model, self.model.model_body.tokenizer, None, None)
    self.state = TrainerState()
    self.control = TrainerControl()
    self.add_callback(DEFAULT_PROGRESS_CALLBACK if self.args.show_progress_bar else PrinterCallback)
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    # Add the callback for filling the model card data with hyperparameters
    # and evaluation results
    self.add_callback(ModelCardCallback(self))

    # Dispatch with the resolved `self.args`: the raw `args` parameter is None whenever the
    # caller relied on the default training arguments.
    self.callback_handler.on_init_end(self.args, self.state, self.control)
def add_callback(self, callback: Union[type, TrainerCallback]) -> None:
    """
    Register an additional [`~transformers.TrainerCallback`] with this trainer.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a [`~transformers.TrainerCallback`] subclass or an already constructed instance.
            The callback is forwarded as-is to the callback handler.
    """
    handler = self.callback_handler
    handler.add_callback(callback)
def pop_callback(self, callback: Union[type, TrainerCallback]) -> TrainerCallback:
    """
    Detach a [`~transformers.TrainerCallback`] from this trainer and hand it back.

    The request is delegated to the callback handler's `pop_callback`, whose result is returned unchanged.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a [`~transformers.TrainerCallback`] subclass or an instance of one.

    Returns:
        [`~transformers.TrainerCallback`]: Whatever the callback handler returns for the popped callback.
    """
    handler = self.callback_handler
    popped = handler.pop_callback(callback)
    return popped
def remove_callback(self, callback: Union[type, TrainerCallback]) -> None:
    """
    Detach a [`~transformers.TrainerCallback`] from this trainer without returning it.

    Args:
        callback (`type` or [`~transformers.TrainerCallback`]):
            Either a [`~transformers.TrainerCallback`] subclass or an instance of one.
            The request is forwarded as-is to the callback handler.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)
def apply_hyperparameters(self, params: Dict[str, Any], final_model: bool = False) -> None:
    """Applies a dictionary of hyperparameters to both the trainer and the model

    Args:
        params (`Dict[str, Any]`): The parameters, usually from `BestRun.hyperparameters`
        final_model (`bool`, *optional*, defaults to `False`): If `True`, replace the `model_init()` function with a fixed model based on the parameters.

    Raises:
        `RuntimeError`: If `model_init` takes more than one argument or returns `None`.
    """
    if self.args is not None:
        self.args = self.args.update(params, ignore_extra=True)
    else:
        self.args = TrainingArguments.from_dict(params, ignore_extra=True)

    # Seed must be set before instantiating the model when using model_init.
    set_seed(self.args.seed)
    # Go through `call_model_init` so that a zero-argument `model_init` is supported here
    # exactly as it is in `__init__`, instead of failing on an unexpected positional argument.
    self.model = self.call_model_init(params)
    if final_model:
        self.model_init = None
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]) -> None: | |
"""HP search setup code""" | |
# Heavily inspired by transformers.Trainer._hp_search_setup | |
if self.hp_search_backend is None or trial is None: | |
return | |
if isinstance(trial, Dict): # For passing a Dict to train() -- mostly unused for now | |
params = trial | |
elif self.hp_search_backend == HPSearchBackend.OPTUNA: | |
params = self.hp_space(trial) | |
else: | |
raise ValueError("Invalid trial parameter") | |
logger.info(f"Trial: {params}") | |
self.apply_hyperparameters(params, final_model=False) | |
def call_model_init(self, params: Optional[Dict[str, Any]] = None) -> "SetFitModel":
    """Instantiate a model through `self.model_init`.

    Args:
        params (`Dict[str, Any]`, *optional*): Passed to `model_init` when it accepts one argument.

    Returns:
        `SetFitModel`: The object returned by `model_init`.

    Raises:
        `RuntimeError`: If `model_init` takes more than one argument or returns `None`.
    """
    argcount = number_of_arguments(self.model_init)
    if argcount not in (0, 1):
        raise RuntimeError("`model_init` should have 0 or 1 argument.")

    model = self.model_init() if argcount == 0 else self.model_init(params)
    if model is None:
        raise RuntimeError("`model_init` should not return None.")
    return model
def freeze(self, component: Optional[Literal["body", "head"]] = None) -> None:
    """Deprecated shortcut for `SetFitModel.freeze`.

    Emits a `DeprecationWarning` and forwards `component` to the model's own `freeze` method.

    Args:
        component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to freeze that component.
            If no component is provided, freeze both. Defaults to None.
    """
    owner = self.__class__.__name__
    message = (
        f"`{owner}.freeze` is deprecated and will be removed in v2.0.0 of SetFit. "
        "Please use `SetFitModel.freeze` directly instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.model.freeze(component)
def unfreeze(
    self, component: Optional[Literal["body", "head"]] = None, keep_body_frozen: Optional[bool] = None
) -> None:
    """Deprecated shortcut for `SetFitModel.unfreeze`.

    Emits a `DeprecationWarning` and forwards both arguments to the model's own `unfreeze` method.

    Args:
        component (`Literal["body", "head"]`, *optional*): Either "body" or "head" to unfreeze that component.
            If no component is provided, unfreeze both. Defaults to None.
        keep_body_frozen (`bool`, *optional*): Deprecated argument, use `component` instead.
    """
    owner = self.__class__.__name__
    message = (
        f"`{owner}.unfreeze` is deprecated and will be removed in v2.0.0 of SetFit. "
        "Please use `SetFitModel.unfreeze` directly instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.model.unfreeze(component, keep_body_frozen=keep_body_frozen)
def train(
    self,
    args: Optional[TrainingArguments] = None,
    trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
    **kwargs,
) -> None:
    """
    Run both training phases: embedding finetuning followed by fitting the classifier.

    The embedding phase receives the evaluation data as well when an evaluation dataset is set;
    the classifier phase only ever receives the training data.

    Args:
        args (`TrainingArguments`, *optional*):
            Temporarily change the training arguments for this training call.
        trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
            The trial run or the hyperparameter dictionary for hyperparameter search.

    Raises:
        `ValueError`: If the trainer has no `train_dataset`.
    """
    trainer_name = self.__class__.__name__
    if kwargs:
        warnings.warn(
            f"`{trainer_name}.train` does not accept keyword arguments anymore. "
            f"Please provide training arguments via a `TrainingArguments` instance to the `{trainer_name}` "
            f"initialisation or the `{trainer_name}.train` method.",
            DeprecationWarning,
            stacklevel=2,
        )

    if trial:  # Trial and model initialization
        self._hp_search_setup(trial)  # sets trainer parameters and initializes model

    args = args or self.args or TrainingArguments()

    if self.train_dataset is None:
        raise ValueError(
            f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization."
        )

    train_parameters = self.dataset_to_parameters(self.train_dataset)
    if self.eval_dataset:
        embedding_parameters = train_parameters + self.dataset_to_parameters(self.eval_dataset)
    else:
        embedding_parameters = train_parameters

    self.train_embeddings(*embedding_parameters, args=args)
    self.train_classifier(*train_parameters, args=args)
def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
    """Return the `"text"` and `"label"` columns of `dataset`, in that order, as a two-element list."""
    return [dataset[column] for column in ("text", "label")]
def train_embeddings(
    self,
    x_train: List[str],
    y_train: Optional[Union[List[int], List[List[int]]]] = None,
    x_eval: Optional[List[str]] = None,
    y_eval: Optional[Union[List[int], List[List[int]]]] = None,
    args: Optional[TrainingArguments] = None,
) -> None:
    """
    Run the embedding phase, i.e. finetune the `SentenceTransformer` body of the model.

    Args:
        x_train (`List[str]`): A list of training sentences.
        y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
        x_eval (`List[str]`, *optional*): A list of evaluation sentences. An evaluation dataloader is only
            built when this is given and `args.evaluation_strategy` is not `IntervalStrategy.NO`.
        y_eval (`Union[List[int], List[List[int]]]`, *optional*): A list of labels corresponding to the
            evaluation sentences.
        args (`TrainingArguments`, *optional*):
            Temporarily change the training arguments for this training call.
    """
    args = args or self.args or TrainingArguments()

    state = self.state
    # Since transformers v4.32.0, the log/eval/save steps should be saved on the state instead
    state.logging_steps = args.logging_steps
    state.eval_steps = args.eval_steps
    state.save_steps = args.save_steps
    # Reset the state
    state.global_step = 0
    state.total_flos = 0

    if args.max_steps == -1:
        train_max_pairs = -1
    else:
        train_max_pairs = args.max_steps * args.embedding_batch_size
    train_dataloader, loss_func, batch_size = self.get_dataloader(
        x_train, y_train, args=args, max_pairs=train_max_pairs
    )

    eval_dataloader = None
    if x_eval is not None and args.evaluation_strategy != IntervalStrategy.NO:
        if args.eval_max_steps == -1:
            eval_max_pairs = -1
        else:
            eval_max_pairs = args.eval_max_steps * args.embedding_batch_size
        eval_dataloader, _, _ = self.get_dataloader(x_eval, y_eval, args=args, max_pairs=eval_max_pairs)

    total_train_steps = (
        args.max_steps if args.max_steps > 0 else len(train_dataloader) * args.embedding_num_epochs
    )

    for line in (
        "***** Running training *****",
        f" Num examples = {len(train_dataloader)}",
        f" Num epochs = {args.embedding_num_epochs}",
        f" Total optimization steps = {total_train_steps}",
        f" Total train batch size = {batch_size}",
    ):
        logger.info(line)

    self._train_sentence_transformer(
        self.model.model_body,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        args=args,
        loss_func=loss_func,
        warmup_steps=math.ceil(total_train_steps * args.warmup_proportion),
    )
def get_dataloader(
    self, x: List[str], y: Union[List[int], List[List[int]]], args: TrainingArguments, max_pairs: int = -1
) -> Tuple[DataLoader, nn.Module, int]:
    """Build the dataloader and loss module for the embedding phase.

    The triplet-style losses and `SupConLoss` draw batches from a `SentenceLabelDataset`; every
    other loss draws pairs from a `ContrastiveDataset`.

    Args:
        x (`List[str]`): The sentences.
        y (`Union[List[int], List[List[int]]]`): The labels corresponding to the sentences.
        args (`TrainingArguments`): The training arguments that select the loss and sampling options.
        max_pairs (`int`, defaults to `-1`): Forwarded to `ContrastiveDataset` as `max_pairs`.

    Returns:
        `Tuple[DataLoader, nn.Module, int]`: The dataloader, the instantiated loss and the batch size used.
    """
    # sentence-transformers adaptation
    body = self.model.model_body
    input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)]

    label_batched_losses = (
        losses.BatchAllTripletLoss,
        losses.BatchHardTripletLoss,
        losses.BatchSemiHardTripletLoss,
        losses.BatchHardSoftMarginTripletLoss,
        SupConLoss,
    )

    if args.loss in label_batched_losses:
        data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.samples_per_label)
        batch_size = min(args.embedding_batch_size, len(data_sampler))
        dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True)

        if args.loss is SupConLoss:
            loss = args.loss(model=body)
        elif args.loss is losses.BatchHardSoftMarginTripletLoss:
            loss = args.loss(model=body, distance_metric=args.distance_metric)
        else:
            loss = args.loss(model=body, distance_metric=args.distance_metric, margin=args.margin)
        return dataloader, loss, batch_size

    data_sampler = ContrastiveDataset(
        input_data,
        self.model.multi_target_strategy,
        args.num_iterations,
        args.sampling_strategy,
        max_pairs=max_pairs,
    )
    # shuffle_sampler = True can be dropped in for further 'randomising'
    shuffle_sampler = args.sampling_strategy == "unique"
    batch_size = min(args.embedding_batch_size, len(data_sampler))
    dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
    loss = args.loss(body)
    return dataloader, loss, batch_size
def log(self, args: TrainingArguments, logs: Dict[str, float]) -> TrainerControl:
    """
    Log `logs` on the various objects watching training.

    Subclass and override this method to inject custom behavior.

    Keys are first renamed through `self.logs_mapper`, the current epoch (rounded to two decimals)
    is added when known, and a copy that also carries the global step is appended to
    `self.state.log_history`.

    Args:
        args (`TrainingArguments`):
            The training arguments forwarded to the callbacks.
        logs (`Dict[str, float]`):
            The values to log.

    Returns:
        [`~transformers.TrainerControl`]: The control object returned by the callbacks' `on_log` event.
        Callers in this class assign it back to `self.control`.
    """
    logs = {self.logs_mapper.get(key, key): value for key, value in logs.items()}
    if self.state.epoch is not None:
        logs["epoch"] = round(self.state.epoch, 2)

    # The history entry carries the step; the dictionary handed to the callbacks does not.
    self.state.log_history.append({**logs, "step": self.state.global_step})
    return self.callback_handler.on_log(args, self.state, self.control, logs)
def _set_logs_mapper(self, logs_mapper: Dict[str, str]) -> None: | |
"""Set the logging mapper. | |
Args: | |
logs_mapper (str): The logging mapper, e.g. {"eval_embedding_loss": "eval_aspect_embedding_loss"}. | |
""" | |
self.logs_mapper = logs_mapper | |
def _train_sentence_transformer(
    self,
    model_body: SentenceTransformer,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader],
    args: TrainingArguments,
    loss_func: nn.Module,
    warmup_steps: int = 10000,
) -> None:
    """
    Finetune `model_body` by minimising `loss_func` over the batches of `train_dataloader`.

    Runs `args.embedding_num_epochs` epochs of `len(train_dataloader)` steps each, restarting the
    dataloader iterator whenever it is exhausted. Callback events are dispatched around every
    step and epoch, and `maybe_log_eval_save` is invoked after each of them. The loop exits early
    when the callbacks request it through `self.control`.

    Args:
        model_body (`SentenceTransformer`): The sentence transformer being finetuned.
        train_dataloader (`DataLoader`): Batches used for optimisation.
        eval_dataloader (`DataLoader`, *optional*): Batches used to compute the evaluation loss.
        args (`TrainingArguments`): The training arguments in effect for this call.
        loss_func (`nn.Module`): The loss module; its named parameters are the ones optimised.
        warmup_steps (`int`, defaults to `10000`): Number of warmup steps for the "WarmupLinear" scheduler.
    """
    # TODO: args.gradient_accumulation_steps
    # TODO: fp16/bf16, etc.
    # TODO: Safetensors

    # Hardcoded training arguments
    max_grad_norm = 1
    weight_decay = 5e-3  # 5e-3 best

    self.state.epoch = 0
    start_time = time.time()
    if args.max_steps > 0:
        self.state.max_steps = args.max_steps
    else:
        self.state.max_steps = len(train_dataloader) * args.embedding_num_epochs
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
    steps_per_epoch = len(train_dataloader)

    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    device = model_body._target_device
    model_body.to(device)
    loss_func.to(device)

    # Use smart batching
    train_dataloader.collate_fn = model_body.smart_batching_collate
    if eval_dataloader:
        eval_dataloader.collate_fn = model_body.smart_batching_collate

    # Prepare optimizers: biases and LayerNorm parameters are excluded from weight decay
    param_optimizer = list(loss_func.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **{"lr": args.body_embedding_learning_rate})
    scheduler_obj = model_body._get_scheduler(
        optimizer, scheduler="WarmupLinear", warmup_steps=warmup_steps, t_total=self.state.max_steps
    )
    self.callback_handler.optimizer = optimizer
    self.callback_handler.lr_scheduler = scheduler_obj
    self.callback_handler.train_dataloader = train_dataloader
    self.callback_handler.eval_dataloader = eval_dataloader

    # NOTE(review): `on_train_begin` is dispatched a second time here, now that the handler
    # knows the optimizer, scheduler and dataloaders. Kept as-is; confirm both dispatches are wanted.
    self.callback_handler.on_train_begin(args, self.state, self.control)

    data_iterator = iter(train_dataloader)
    skip_scheduler = False
    for epoch in range(args.embedding_num_epochs):
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        loss_func.zero_grad()
        loss_func.train()

        for step in range(steps_per_epoch):
            self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

            try:
                data = next(data_iterator)
            except StopIteration:
                # Dataloader exhausted: start a fresh pass over it
                data_iterator = iter(train_dataloader)
                data = next(data_iterator)

            features, labels = data
            labels = labels.to(device)
            features = list(map(lambda batch: batch_to_device(batch, device), features))

            if args.use_amp:
                with autocast():
                    loss_value = loss_func(features, labels)

                scale_before_step = scaler.get_scale()
                scaler.scale(loss_value).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()

                # A changed scale means the scheduler step is skipped for this iteration
                skip_scheduler = scaler.get_scale() != scale_before_step
            else:
                loss_value = loss_func(features, labels)
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(loss_func.parameters(), max_grad_norm)
                optimizer.step()

            optimizer.zero_grad()

            if not skip_scheduler:
                scheduler_obj.step()

            self.state.global_step += 1
            self.state.epoch = epoch + (step + 1) / steps_per_epoch
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)

            self.maybe_log_eval_save(model_body, eval_dataloader, args, scheduler_obj, loss_func, loss_value)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)

        self.maybe_log_eval_save(model_body, eval_dataloader, args, scheduler_obj, loss_func, loss_value)

        if self.control.should_training_stop:
            break

    # Honour the per-call `args` here, like everywhere else in this method, rather than `self.args`
    if args.load_best_model_at_end and self.state.best_model_checkpoint:
        dir_name = Path(self.state.best_model_checkpoint).name
        if dir_name.startswith("step_"):
            step_to_load = dir_name[5:]
            logger.info(f"Loading best SentenceTransformer model from step {step_to_load}.")
            self.model.model_card_data.set_best_model_step(int(step_to_load))
        self.model.model_body = SentenceTransformer(self.state.best_model_checkpoint, device=device)
        self.model.model_body.to(device)

    # Ensure logging the speed metrics
    num_train_samples = self.state.max_steps * args.embedding_batch_size  # * args.gradient_accumulation_steps
    metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
    self.control.should_log = True
    self.log(args, metrics)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
def maybe_log_eval_save(
    self,
    model_body: SentenceTransformer,
    eval_dataloader: Optional[DataLoader],
    args: TrainingArguments,
    scheduler_obj,
    loss_func,
    loss_value: torch.Tensor,
) -> None:
    """Log, evaluate and/or checkpoint, depending on the flags currently set on `self.control`.

    Args:
        model_body (`SentenceTransformer`): The sentence transformer being trained.
        eval_dataloader (`DataLoader`, *optional*): Evaluation batches; evaluation is skipped when `None`.
        args (`TrainingArguments`): The training arguments in effect for this call. They are used for
            every step below, including the checkpoint directory and the `on_save` event.
        scheduler_obj: The learning rate scheduler, queried through `get_last_lr()`.
        loss_func: The loss module; it is put back in training mode after an evaluation.
        loss_value (`torch.Tensor`): The most recent training loss.
    """
    if self.control.should_log:
        learning_rate = scheduler_obj.get_last_lr()[0]
        metrics = {"embedding_loss": round(loss_value.item(), 4), "learning_rate": learning_rate}
        self.control = self.log(args, metrics)

    eval_loss = None
    if self.control.should_evaluate and eval_dataloader is not None:
        eval_loss = self._evaluate_with_loss(model_body, eval_dataloader, args, loss_func)
        learning_rate = scheduler_obj.get_last_lr()[0]
        metrics = {"eval_embedding_loss": round(eval_loss, 4), "learning_rate": learning_rate}
        self.control = self.log(args, metrics)

        self.control = self.callback_handler.on_evaluate(args, self.state, self.control, metrics)

        # Evaluation is over: clear gradients and go back to training mode
        loss_func.zero_grad()
        loss_func.train()

    if self.control.should_save:
        # Use the per-call `args` for the output directory as well, so that a temporary override
        # passed to `train` is not mixed with `self.args`.
        checkpoint_dir = self._checkpoint(args.output_dir, args.save_total_limit, self.state.global_step)
        self.control = self.callback_handler.on_save(args, self.state, self.control)

        # Lower evaluation loss is better; remember the checkpoint that achieved it
        if eval_loss is not None and (self.state.best_metric is None or eval_loss < self.state.best_metric):
            self.state.best_metric = eval_loss
            self.state.best_model_checkpoint = checkpoint_dir
def _evaluate_with_loss(
    self,
    model_body: SentenceTransformer,
    eval_dataloader: DataLoader,
    args: TrainingArguments,
    loss_func: nn.Module,
) -> float:
    """Compute the mean of `loss_func` over the evaluation batches.

    At most `args.eval_max_steps` batches are used unless that value is `-1`, in which case the
    whole dataloader is consumed. `model_body` is put in evaluation mode for the duration and
    switched back to training mode afterwards.

    Args:
        model_body (`SentenceTransformer`): The sentence transformer being evaluated.
        eval_dataloader (`DataLoader`): The evaluation batches.
        args (`TrainingArguments`): The training arguments in effect for this call.
        loss_func (`nn.Module`): The loss module to evaluate.

    Returns:
        `float`: The average loss over the evaluated batches.
    """
    model_body.eval()
    device = model_body._target_device
    # Named `batch_losses` so that the imported `losses` module is not shadowed
    batch_losses = []
    num_batches = len(eval_dataloader)
    eval_steps = min(num_batches, args.eval_max_steps) if args.eval_max_steps != -1 else num_batches

    progress = tqdm(iter(eval_dataloader), total=eval_steps, leave=False, disable=not args.show_progress_bar)
    # Only the loss values are needed, so skip building autograd graphs during evaluation
    with torch.no_grad():
        for step, (features, labels) in enumerate(progress, start=1):
            labels = labels.to(device)
            features = [batch_to_device(batch, device) for batch in features]

            if args.use_amp:
                with autocast():
                    loss_value = loss_func(features, labels)
            else:
                loss_value = loss_func(features, labels)
            batch_losses.append(loss_value.item())

            if step >= eval_steps:
                break

    model_body.train()
    return sum(batch_losses) / len(batch_losses)
def _checkpoint(self, checkpoint_path: str, checkpoint_save_total_limit: int, step: int) -> None: | |
# Delete old checkpoints | |
if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0: | |
old_checkpoints = [] | |
for subdir in Path(checkpoint_path).glob("step_*"): | |
if subdir.name[5:].isdigit() and ( | |
self.state.best_model_checkpoint is None or subdir != Path(self.state.best_model_checkpoint) | |
): | |
old_checkpoints.append({"step": int(subdir.name[5:]), "path": str(subdir)}) | |
if len(old_checkpoints) > checkpoint_save_total_limit - 1: | |
old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"]) | |
shutil.rmtree(old_checkpoints[0]["path"]) | |
checkpoint_file_path = str(Path(checkpoint_path) / f"step_{step}") | |
self.model.save_pretrained(checkpoint_file_path) | |
return checkpoint_file_path | |
def train_classifier(
    self, x_train: List[str], y_train: Union[List[int], List[List[int]]], args: Optional[TrainingArguments] = None
) -> None:
    """
    Run the classifier phase by delegating to the model's `fit` method.

    Args:
        x_train (`List[str]`): A list of training sentences.
        y_train (`Union[List[int], List[List[int]]]`): A list of labels corresponding to the training sentences.
        args (`TrainingArguments`, *optional*):
            Temporarily change the training arguments for this training call.
    """
    args = args or self.args or TrainingArguments()

    fit_options = {
        "num_epochs": args.classifier_num_epochs,
        "batch_size": args.classifier_batch_size,
        "body_learning_rate": args.body_classifier_learning_rate,
        "head_learning_rate": args.head_learning_rate,
        "l2_weight": args.l2_weight,
        "max_length": args.max_length,
        "show_progress_bar": args.show_progress_bar,
        "end_to_end": args.end_to_end,
    }
    self.model.fit(x_train, y_train, **fit_options)
def evaluate(self, dataset: Optional[Dataset] = None, metric_key_prefix: str = "test") -> Dict[str, float]:
    """
    Compute the configured metric(s) for the classifier on an evaluation dataset.

    Args:
        dataset (`Dataset`, *optional*):
            The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
            the `eval_dataset` argument at `Trainer` initialization.
        metric_key_prefix (`str`, *optional*, defaults to `"test"`):
            Prefix prepended (as `"<prefix>_<name>"`) to every metric name when the results are handed to
            the model card data. The returned dictionary keeps the unprefixed names.

    Returns:
        `Dict[str, float]`: The evaluation metrics.

    Raises:
        `ValueError`: If no evaluation dataset is available, or if `self.metric` is neither a string
            nor a callable.
    """
    # Resolve which dataset to evaluate on.
    if dataset is None:
        eval_dataset = self.eval_dataset
    else:
        self._validate_column_mapping(dataset)
        if self.column_mapping is None:
            eval_dataset = dataset
        else:
            logger.info("Applying column mapping to the evaluation dataset")
            eval_dataset = self._apply_column_mapping(dataset, self.column_mapping)

    if eval_dataset is None:
        raise ValueError("No evaluation dataset provided to `Trainer.evaluate` nor the `Trainer` initialzation.")

    texts = eval_dataset["text"]
    references = eval_dataset["label"]

    logger.info("***** Running evaluation *****")
    predictions = self.model.predict(texts, use_labels=False)
    if isinstance(predictions, torch.Tensor):
        # Bring tensor predictions back to the CPU before computing metrics.
        predictions = predictions.cpu()

    # Normalize string outputs: encode references and predictions with one shared label encoder.
    if references and isinstance(references[0], str):
        label_encoder = LabelEncoder()
        label_encoder.fit(list(references) + list(predictions))
        references = label_encoder.transform(references)
        predictions = label_encoder.transform(predictions)

    extra_metric_kwargs = self.metric_kwargs or {}
    if isinstance(self.metric, str):
        # A metric name is loaded through `evaluate`, in its multilabel configuration when the model
        # has a multi-target strategy.
        config_name = None if self.model.multi_target_strategy is None else "multilabel"
        loaded_metric = evaluate.load(self.metric, config_name=config_name)
        results = loaded_metric.compute(predictions=predictions, references=references, **extra_metric_kwargs)
    elif callable(self.metric):
        results = self.metric(predictions, references, **extra_metric_kwargs)
    else:
        raise ValueError("metric must be a string or a callable")

    # A bare (non-dict) result is wrapped so that callers always receive a dictionary.
    if not isinstance(results, dict):
        results = {"metric": results}

    prefixed_results = {f"{metric_key_prefix}_{key}": value for key, value in results.items()}
    self.model.model_card_data.post_training_eval_results(prefixed_results)
    return results
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 10,
    direction: str = "maximize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs,
) -> BestRun:
    """
    Launch a hyperparameter search using `optuna`. The optimized quantity is determined
    by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    <Tip warning={true}>

    To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
    reinitialize the model at each new run.

    </Tip>

    Args:
        hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
            A function that defines the hyperparameter search space. Will default to
            `default_hp_space_optuna`.
        compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
            A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
            method. Will default to [`~transformers.trainer_utils.default_compute_objective`] which uses the sum of metrics.
        n_trials (`int`, *optional*, defaults to 10):
            The number of trial runs to test.
        direction (`str`, *optional*, defaults to `"maximize"`):
            Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
            `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
        backend (`str` or [`~transformers.training_utils.HPSearchBackend`], *optional*):
            The backend to use for hyperparameter search. Only optuna is supported for now.
            TODO: add support for ray and sigopt.
        hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
            A function that defines the trial/run name. Will default to None.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to `optuna.create_study`. For more
            information see:

            - the documentation of
              [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)

    Returns:
        [`trainer_utils.BestRun`]: All the information about the best run.

    Raises:
        `RuntimeError`: If no backend is available, if the chosen backend is not optuna or optuna is not
            installed, or if the trainer was created without a `model_init` function.
    """
    if backend is None:
        backend = default_hp_search_backend()
        if backend is None:
            raise RuntimeError("optuna should be installed. To install optuna run `pip install optuna`.")
    backend = HPSearchBackend(backend)
    if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
        raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
    elif backend != HPSearchBackend.OPTUNA:
        raise RuntimeError("Only optuna backend is supported for hyperparameter search.")

    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_space = default_hp_space_optuna if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    backend_dict = {
        HPSearchBackend.OPTUNA: run_hp_search_optuna,
    }

    # Flag the trainer as "searching" only once every precondition has passed, and always clear the flag
    # afterwards, so a failed validation or a crashing trial does not leave it set.
    self.hp_search_backend = backend
    try:
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
    finally:
        self.hp_search_backend = None
    return best_run
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Upload the model checkpoint to the Hub by delegating to `self.model.push_to_hub`.

    All keyword arguments are forwarded unchanged, except that `commit_message` falls back to
    `"Add SetFit model"` when it is not given. See the full list of parameters for your
    `huggingface_hub` version in the
    [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).

    Args:
        repo_id (`str`):
            The full repository ID to push to, e.g. `"tomaarsen/setfit-sst2"`. It must contain a `/`.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        commit_message (`str`, *optional*):
            Message to commit while pushing.
        private (`bool`, *optional*, defaults to `False`):
            Whether the repository created should be private.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files.
            If not set, will use the token set when logging in with
            `transformers-cli login` (stored in `~/.huggingface`).
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to
            the default branch as specified in your repository, which
            defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit.
            Defaults to `False`.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.

    Returns:
        str: The url of the commit of your model in the given repository.

    Raises:
        `ValueError`: If `repo_id` does not contain a `/`.
    """
    if "/" in repo_id:
        message = kwargs.pop("commit_message", "Add SetFit model")
        # Keep `commit_message` first, followed by the remaining keyword arguments in their given order.
        hub_kwargs = {"commit_message": message, **kwargs}
        return self.model.push_to_hub(repo_id, **hub_kwargs)

    raise ValueError(
        '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-sst2".'
    )
class SetFitTrainer(Trainer):
    """
    `SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
    Please use `Trainer` instead.

    This subclass only exists for backwards compatibility: it accepts the legacy keyword arguments,
    packs the training-related ones into a `TrainingArguments` instance and forwards everything to `Trainer`.
    """

    def __init__(
        self,
        model: Optional["SetFitModel"] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        metric_kwargs: Optional[Dict[str, Any]] = None,
        loss_class=losses.CosineSimilarityLoss,
        num_iterations: int = 20,
        num_epochs: int = 1,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
        seed: int = 42,
        column_mapping: Optional[Dict[str, str]] = None,
        use_amp: bool = False,
        warmup_proportion: float = 0.1,
        distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance,
        margin: float = 0.25,
        samples_per_label: int = 2,
    ):
        # `stacklevel=2` attributes the warning to the code that instantiates the trainer.
        deprecation_notice = (
            "`SetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
            "Please use `Trainer` instead."
        )
        warnings.warn(deprecation_notice, DeprecationWarning, stacklevel=2)

        # Legacy keyword arguments mapped onto their `TrainingArguments` counterparts. The single
        # `learning_rate` is used for both the body and the head, and `loss_class` becomes `loss`.
        legacy_settings = {
            "num_iterations": num_iterations,
            "num_epochs": num_epochs,
            "body_learning_rate": learning_rate,
            "head_learning_rate": learning_rate,
            "batch_size": batch_size,
            "seed": seed,
            "use_amp": use_amp,
            "warmup_proportion": warmup_proportion,
            "distance_metric": distance_metric,
            "margin": margin,
            "samples_per_label": samples_per_label,
            "loss": loss_class,
        }
        training_args = TrainingArguments(**legacy_settings)

        super().__init__(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            metric_kwargs=metric_kwargs,
            column_mapping=column_mapping,
        )