import warnings
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
from datasets import Dataset
from sentence_transformers import InputExample, losses, util
from torch import nn
from torch.utils.data import DataLoader
from . import logging
from .sampler import ContrastiveDistillationDataset
from .trainer import Trainer
from .training_args import TrainingArguments
if TYPE_CHECKING:
from .modeling import SetFitModel
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
class DistillationTrainer(Trainer):
"""Trainer to compress a SetFit model with knowledge distillation.
Args:
teacher_model (`SetFitModel`):
The teacher model to mimic.
student_model (`SetFitModel`, *optional*):
The model to train. If not provided, a `model_init` must be passed.
args (`TrainingArguments`, *optional*):
The training arguments to use.
train_dataset (`Dataset`):
The training dataset.
eval_dataset (`Dataset`, *optional*):
The evaluation dataset.
model_init (`Callable[[], SetFitModel]`, *optional*):
A function that instantiates the model to be used. If provided, each call to
[`~DistillationTrainer.train`] will start from a new instance of the model as given by this
function when a `trial` is passed.
metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
The metric to use for evaluation. If a string is provided, we treat it as the metric
name and load it with default settings.
If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
column_mapping (`Dict[str, str]`, *optional*):
A mapping from the column names in the dataset to the column names expected by the model.
The expected format is a dictionary with the following format:
`{"text_column_name": "text", "label_column_name: "label"}`.
"""
_REQUIRED_COLUMNS = {"text"}
def __init__(
self,
teacher_model: "SetFitModel",
student_model: Optional["SetFitModel"] = None,
args: TrainingArguments = None,
train_dataset: Optional["Dataset"] = None,
eval_dataset: Optional["Dataset"] = None,
model_init: Optional[Callable[[], "SetFitModel"]] = None,
metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
column_mapping: Optional[Dict[str, str]] = None,
) -> None:
super().__init__(
model=student_model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model_init=model_init,
metric=metric,
column_mapping=column_mapping,
)
self.teacher_model = teacher_model
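# The base Trainer stores the student as self.model; self.student_model is kept as a second reference to the same object.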
self.student_model = self.model
def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
return [dataset["text"]]
def get_dataloader(
self,
x: List[str],
y: Optional[Union[List[int], List[List[int]]]],
args: TrainingArguments,
max_pairs: int = -1,
) -> Tuple[DataLoader, nn.Module, int]:
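# Data flow sketch: the teacher body embeds the raw sentences, a cosine-similarity matrix over
# those embeddings provides the pairwise targets, and ContrastiveDistillationDataset samples
# sentence pairs for the student to train against.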
x_embd_teacher = self.teacher_model.model_body.encode(
x, convert_to_tensor=self.teacher_model.has_differentiable_head
)
cos_sim_matrix = util.cos_sim(x_embd_teacher, x_embd_teacher)
input_data = [InputExample(texts=[text]) for text in x]
data_sampler = ContrastiveDistillationDataset(
input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
)
# Shuffling is enabled only for the "unique" sampling strategy; forcing shuffle_sampler = True for
# the other strategies can be used for further randomising.
shuffle_sampler = args.sampling_strategy == "unique"
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
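# Instantiate the configured pair loss (args.loss) over the student's sentence-transformer body.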
loss = args.loss(self.model.model_body)
return dataloader, loss, batch_size
def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None:
"""
Method to perform the classifier phase: fitting the student classifier head.
Args:
x_train (`List[str]`): A list of training sentences.
args (`TrainingArguments`, *optional*):
Temporarily change the training arguments for this training call.
"""
y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
return super().train_classifier(x_train, y_train, args)
class DistillationSetFitTrainer(DistillationTrainer):
"""
`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
Please use `DistillationTrainer` instead.
"""
def __init__(
self,
teacher_model: "SetFitModel",
student_model: Optional["SetFitModel"] = None,
train_dataset: Optional["Dataset"] = None,
eval_dataset: Optional["Dataset"] = None,
model_init: Optional[Callable[[], "SetFitModel"]] = None,
metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
loss_class: torch.nn.Module = losses.CosineSimilarityLoss,
num_iterations: int = 20,
num_epochs: int = 1,
learning_rate: float = 2e-5,
batch_size: int = 16,
seed: int = 42,
column_mapping: Optional[Dict[str, str]] = None,
use_amp: bool = False,
warmup_proportion: float = 0.1,
) -> None:
warnings.warn(
"`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
"Please use `DistillationTrainer` instead.",
DeprecationWarning,
stacklevel=2,
)
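# Translate the deprecated constructor arguments into a TrainingArguments object for the new API.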
args = TrainingArguments(
num_iterations=num_iterations,
num_epochs=num_epochs,
body_learning_rate=learning_rate,
head_learning_rate=learning_rate,
batch_size=batch_size,
seed=seed,
use_amp=use_amp,
warmup_proportion=warmup_proportion,
loss=loss_class,
)
super().__init__(
teacher_model=teacher_model,
student_model=student_model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model_init=model_init,
metric=metric,
column_mapping=column_mapping,
)