import warnings
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from sentence_transformers import InputExample, losses, util
from torch import nn
from torch.utils.data import DataLoader

from . import logging
from .sampler import ContrastiveDistillationDataset
from .trainer import Trainer
from .training_args import TrainingArguments


if TYPE_CHECKING:
    from .modeling import SetFitModel

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


class DistillationTrainer(Trainer):
    """Trainer to compress a SetFit model with knowledge distillation.

    Args:
        teacher_model (`SetFitModel`):
            The teacher model to mimic.
        student_model (`SetFitModel`, *optional*):
            The model to train. If not provided, a `model_init` must be passed.
        args (`TrainingArguments`, *optional*):
            The training arguments to use.
        train_dataset (`Dataset`):
            The training dataset.
        eval_dataset (`Dataset`, *optional*):
            The evaluation dataset.
        model_init (`Callable[[], SetFitModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to
            [`~DistillationTrainer.train`] will start from a new instance of the model as given by
            this function when a `trial` is passed.
        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
            The metric to use for evaluation. If a string is provided, we treat it as the metric name
            and load it with default settings.
            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
        column_mapping (`Dict[str, str]`, *optional*):
            A mapping from the column names in the dataset to the column names expected by the model.
            The expected format is a dictionary with the following format:
            `{"text_column_name": "text", "label_column_name": "label"}`.
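
    Example (a minimal sketch; the checkpoints, dataset and hyperparameters below are illustrative,
    and the teacher model is assumed to have already been fine-tuned, e.g. with `Trainer`)::

        >>> from datasets import load_dataset
        >>> from setfit import DistillationTrainer, SetFitModel, TrainingArguments

        >>> teacher_model = SetFitModel.from_pretrained("path/to/finetuned-setfit-teacher")  # placeholder checkpoint
        >>> student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
        >>> unlabeled_dataset = load_dataset("sst2", split="train[:500]")  # only the text column is used
        >>> args = TrainingArguments(batch_size=16, num_epochs=1)
        >>> trainer = DistillationTrainer(
        ...     teacher_model=teacher_model,
        ...     student_model=student_model,
        ...     args=args,
        ...     train_dataset=unlabeled_dataset,
        ...     column_mapping={"sentence": "text"},
        ... )
        >>> trainer.train()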
""" _REQUIRED_COLUMNS = {"text"} def __init__( self, teacher_model: "SetFitModel", student_model: Optional["SetFitModel"] = None, args: TrainingArguments = None, train_dataset: Optional["Dataset"] = None, eval_dataset: Optional["Dataset"] = None, model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", column_mapping: Optional[Dict[str, str]] = None, ) -> None: super().__init__( model=student_model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, model_init=model_init, metric=metric, column_mapping=column_mapping, ) self.teacher_model = teacher_model self.student_model = self.model def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]: return [dataset["text"]] def get_dataloader( self, x: List[str], y: Optional[Union[List[int], List[List[int]]]], args: TrainingArguments, max_pairs: int = -1, ) -> Tuple[DataLoader, nn.Module, int]: x_embd_student = self.teacher_model.model_body.encode( x, convert_to_tensor=self.teacher_model.has_differentiable_head ) cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student) input_data = [InputExample(texts=[text]) for text in x] data_sampler = ContrastiveDistillationDataset( input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs ) # shuffle_sampler = True can be dropped in for further 'randomising' shuffle_sampler = True if args.sampling_strategy == "unique" else False batch_size = min(args.embedding_batch_size, len(data_sampler)) dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False) loss = args.loss(self.model.model_body) return dataloader, loss, batch_size def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None: """ Method to perform the classifier phase: fitting the student classifier head. Args: x_train (`List[str]`): A list of training sentences. args (`TrainingArguments`, *optional*): Temporarily change the training arguments for this training call. """ y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head) return super().train_classifier(x_train, y_train, args) class DistillationSetFitTrainer(DistillationTrainer): """ `DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. Please use `DistillationTrainer` instead. """ def __init__( self, teacher_model: "SetFitModel", student_model: Optional["SetFitModel"] = None, train_dataset: Optional["Dataset"] = None, eval_dataset: Optional["Dataset"] = None, model_init: Optional[Callable[[], "SetFitModel"]] = None, metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", loss_class: torch.nn.Module = losses.CosineSimilarityLoss, num_iterations: int = 20, num_epochs: int = 1, learning_rate: float = 2e-5, batch_size: int = 16, seed: int = 42, column_mapping: Optional[Dict[str, str]] = None, use_amp: bool = False, warmup_proportion: float = 0.1, ) -> None: warnings.warn( "`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. 
" "Please use `DistillationTrainer` instead.", DeprecationWarning, stacklevel=2, ) args = TrainingArguments( num_iterations=num_iterations, num_epochs=num_epochs, body_learning_rate=learning_rate, head_learning_rate=learning_rate, batch_size=batch_size, seed=seed, use_amp=use_amp, warmup_proportion=warmup_proportion, loss=loss_class, ) super().__init__( teacher_model=teacher_model, student_model=student_model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, model_init=model_init, metric=metric, column_mapping=column_mapping, )