import warnings
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from sentence_transformers import InputExample, losses, util
from torch import nn
from torch.utils.data import DataLoader

from . import logging
from .sampler import ContrastiveDistillationDataset
from .trainer import Trainer
from .training_args import TrainingArguments


if TYPE_CHECKING:
    from .modeling import SetFitModel

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


class DistillationTrainer(Trainer):
    """Trainer to compress a SetFit model with knowledge distillation.



    Args:

        teacher_model (`SetFitModel`):

            The teacher model to mimic.

        student_model (`SetFitModel`, *optional*):

            The model to train. If not provided, a `model_init` must be passed.

        args (`TrainingArguments`, *optional*):

            The training arguments to use.

        train_dataset (`Dataset`):

            The training dataset.

        eval_dataset (`Dataset`, *optional*):

            The evaluation dataset.

        model_init (`Callable[[], SetFitModel]`, *optional*):

            A function that instantiates the model to be used. If provided, each call to

            [`~DistillationTrainer.train`] will start from a new instance of the model as given by this

            function when a `trial` is passed.

        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):

            The metric to use for evaluation. If a string is provided, we treat it as the metric

            name and load it with default settings.

            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).

        column_mapping (`Dict[str, str]`, *optional*):

            A mapping from the column names in the dataset to the column names expected by the model.

            The expected format is a dictionary with the following format:

            `{"text_column_name": "text", "label_column_name: "label"}`.

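
    Example (a minimal sketch; the teacher checkpoint name and the tiny unlabeled dataset below
    are illustrative placeholders, not part of the library):

    ```python
    from datasets import Dataset
    from setfit import DistillationTrainer, SetFitModel, TrainingArguments

    # A previously trained SetFit teacher (hypothetical checkpoint) and a smaller student body
    teacher_model = SetFitModel.from_pretrained("my-org/setfit-teacher")
    student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

    # Distillation only requires unlabeled text; the teacher supplies the supervision
    unlabeled_dataset = Dataset.from_dict({"text": ["first example sentence", "second example sentence"]})

    args = TrainingArguments(batch_size=16, num_epochs=1)
    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=args,
        train_dataset=unlabeled_dataset,
    )
    trainer.train()
    ```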
    """

    _REQUIRED_COLUMNS = {"text"}

    def __init__(
        self,
        teacher_model: "SetFitModel",
        student_model: Optional["SetFitModel"] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        column_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        super().__init__(
            model=student_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            column_mapping=column_mapping,
        )

        self.teacher_model = teacher_model
        self.student_model = self.model

    def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
        return [dataset["text"]]

    def get_dataloader(
        self,
        x: List[str],
        y: Optional[Union[List[int], List[List[int]]]],
        args: TrainingArguments,
        max_pairs: int = -1,
    ) -> Tuple[DataLoader, nn.Module, int]:
        """Create the contrastive distillation dataloader, loss, and batch size from the training texts.

        Note that `y` is ignored: supervision comes from the teacher embeddings rather than labels.
        """
        # Encode the (unlabeled) texts with the teacher body; the pairwise cosine similarities
        # act as soft targets for the student during distillation.
        x_embd_teacher = self.teacher_model.model_body.encode(
            x, convert_to_tensor=self.teacher_model.has_differentiable_head
        )
        cos_sim_matrix = util.cos_sim(x_embd_teacher, x_embd_teacher)

        input_data = [InputExample(texts=[text]) for text in x]
        data_sampler = ContrastiveDistillationDataset(
            input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
        )
        # Shuffle only with the "unique" sampling strategy; shuffling can also be enabled for
        # other strategies if additional randomisation is desired.
        shuffle_sampler = args.sampling_strategy == "unique"
        batch_size = min(args.embedding_batch_size, len(data_sampler))
        dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
        loss = args.loss(self.model.model_body)
        return dataloader, loss, batch_size

    def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None:
        """

        Method to perform the classifier phase: fitting the student classifier head.



        Args:

            x_train (`List[str]`): A list of training sentences.

            args (`TrainingArguments`, *optional*):

                Temporarily change the training arguments for this training call.
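
        Example (a minimal sketch; the sentences and the `trainer` object are illustrative, and
        this method is normally invoked for you by [`~DistillationTrainer.train`]):

        ```python
        # `trainer` is assumed to be an initialized DistillationTrainer whose student body has
        # already been fine-tuned; the teacher predicts the pseudo-labels for these sentences.
        trainer.train_classifier(["an unlabeled sentence", "another unlabeled sentence"])
        ```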

        """
        # The teacher predicts pseudo-labels for the (unlabeled) training sentences, and the
        # student classifier head is then fit on those predictions.
        y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
        return super().train_classifier(x_train, y_train, args)


class DistillationSetFitTrainer(DistillationTrainer):
    """

    `DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.

    Please use `DistillationTrainer` instead.

    """

    def __init__(
        self,
        teacher_model: "SetFitModel",
        student_model: Optional["SetFitModel"] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        loss_class: torch.nn.Module = losses.CosineSimilarityLoss,
        num_iterations: int = 20,
        num_epochs: int = 1,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
        seed: int = 42,
        column_mapping: Optional[Dict[str, str]] = None,
        use_amp: bool = False,
        warmup_proportion: float = 0.1,
    ) -> None:
        warnings.warn(
            "`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
            "Please use `DistillationTrainer` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        args = TrainingArguments(
            num_iterations=num_iterations,
            num_epochs=num_epochs,
            body_learning_rate=learning_rate,
            head_learning_rate=learning_rate,
            batch_size=batch_size,
            seed=seed,
            use_amp=use_amp,
            warmup_proportion=warmup_proportion,
            loss=loss_class,
        )
        super().__init__(
            teacher_model=teacher_model,
            student_model=student_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            column_mapping=column_mapping,
        )