File size: 15,941 Bytes
0fdb130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from transformers.trainer_callback import TrainerCallback

from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel
from setfit.training_args import TrainingArguments

from .. import logging
from ..trainer import ColumnMappingMixin, Trainer


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


class AbsaTrainer(ColumnMappingMixin):
    """Trainer to train a SetFit ABSA model.



    Args:

        model (`AbsaModel`):

            The AbsaModel model to train.

        args (`TrainingArguments`, *optional*):

            The training arguments to use. If `polarity_args` is not defined, then `args` is used for both

            the aspect and the polarity model.

        polarity_args (`TrainingArguments`, *optional*):

            The training arguments to use for the polarity model. If not defined, `args` is used for both

            the aspect and the polarity model.

        train_dataset (`Dataset`):

            The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns.

        eval_dataset (`Dataset`, *optional*):

            The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns.

        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):

            The metric to use for evaluation. If a string is provided, we treat it as the metric

            name and load it with default settings.

            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).

        metric_kwargs (`Dict[str, Any]`, *optional*):

            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".

            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.

        callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*):

            A list of callbacks to customize the training loop. Will add those to the list of default callbacks

            detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.

        column_mapping (`Dict[str, str]`, *optional*):

            A mapping from the column names in the dataset to the column names expected by the model.

            The expected format is a dictionary with the following format:

            `{"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}`.

    """

    _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}

    def __init__(
        self,
        model: AbsaModel,
        args: Optional[TrainingArguments] = None,
        polarity_args: Optional[TrainingArguments] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        metric_kwargs: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        column_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        """Set up one `Trainer` per sub-model (aspect and polarity) of the ABSA model."""
        self.model = model
        self.aspect_extractor = model.aspect_extractor

        def prepare(dataset: Optional["Dataset"]):
            # Remap user-provided column names first (if requested), then split the
            # gold dataset into the aspect-level and polarity-level training views.
            if dataset is not None and column_mapping:
                dataset = self._apply_column_mapping(dataset, column_mapping)
            return self.preprocess_dataset(model.aspect_model, model.polarity_model, dataset)

        aspect_train_dataset, polarity_train_dataset = prepare(train_dataset)
        aspect_eval_dataset, polarity_eval_dataset = prepare(eval_dataset)

        def build_trainer(sub_model, sub_args, sub_train, sub_eval, logs_mapper) -> Trainer:
            # Both sub-trainers share metric/callback configuration; only the model,
            # the training arguments, the datasets and the log renaming differ.
            trainer = Trainer(
                sub_model,
                args=sub_args,
                train_dataset=sub_train,
                eval_dataset=sub_eval,
                metric=metric,
                metric_kwargs=metric_kwargs,
                callbacks=callbacks,
            )
            trainer._set_logs_mapper(logs_mapper)
            return trainer

        self.aspect_trainer = build_trainer(
            model.aspect_model,
            args,
            aspect_train_dataset,
            aspect_eval_dataset,
            {
                "eval_embedding_loss": "eval_aspect_embedding_loss",
                "embedding_loss": "aspect_embedding_loss",
            },
        )
        # The polarity trainer falls back to `args` when no dedicated
        # `polarity_args` were provided.
        self.polarity_trainer = build_trainer(
            model.polarity_model,
            polarity_args or args,
            polarity_train_dataset,
            polarity_eval_dataset,
            {
                "eval_embedding_loss": "eval_polarity_embedding_loss",
                "embedding_loss": "polarity_embedding_loss",
            },
        )

    def preprocess_dataset(
        self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Optional[Dataset]
    ) -> Tuple[Optional[Dataset], Optional[Dataset]]:
        """Convert a gold ABSA dataset into one training dataset per sub-model.

        The aspect dataset labels every gold aspect span as `True` and every
        extracted (predicted) aspect that does not overlap a gold span as `False`.
        The polarity dataset keeps only the gold spans, paired with their gold
        polarity labels. Both datasets contain "text" entries built by prepending
        each aspect to its sentence via the sub-model's `prepend_aspects`.

        Args:
            aspect_model (`AspectModel`): Provides `prepend_aspects` for the aspect texts.
            polarity_model (`PolarityModel`): Provides `prepend_aspects` for the polarity texts.
            dataset (`Dataset`, *optional*): Gold dataset with "text", "span", "label"
                and "ordinal" columns, or `None`.

        Returns:
            `Tuple[Optional[Dataset], Optional[Dataset]]`: The (aspect, polarity)
                datasets, or `(None, None)` if `dataset` is `None`.
        """
        if dataset is None:
            return dataset, dataset

        # Group the annotations by sentence: one "text" may carry several
        # (span, label, ordinal) annotations.
        grouped_data = defaultdict(list)
        for sample in dataset:
            text = sample.pop("text")
            grouped_data[text].append(sample)

        def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]:
            # Character offsets of the `ordinal`-th occurrence of `target` in `text`.
            # Propagates ValueError (from str.index) when `target` does not occur
            # often enough.
            find_from = 0
            for _ in range(ordinal + 1):
                start_idx = text.index(target, find_from)
                find_from = start_idx + 1
            return start_idx, start_idx + len(target)

        def overlaps(aspect: slice, aspects: List[slice]) -> bool:
            # Inclusive-interval intersection test. Equivalent to intersecting the
            # `set(range(start, stop + 1))` sets of the original implementation,
            # but O(1) per comparison instead of O(span length).
            return any(
                aspect.start <= test_aspect.stop and test_aspect.start <= aspect.stop
                for test_aspect in aspects
            )

        docs, aspects_list = self.aspect_extractor(grouped_data.keys())
        aspect_aspect_list = []
        aspect_labels = []
        polarity_aspect_list = []
        polarity_labels = []
        for doc, aspects, text in zip(docs, aspects_list, grouped_data):
            # Collect all of the gold aspects for this sentence.
            gold_aspects = []
            gold_polarity_labels = []
            for annotation in grouped_data[text]:
                try:
                    start, end = index_ordinal(text, annotation["span"], annotation["ordinal"])
                except ValueError:
                    logger.info(
                        f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. "
                        "Skipping this sample."
                    )
                    continue

                # char_span returns None when the character offsets do not align
                # with token boundaries; such annotations are unusable.
                gold_aspect_span = doc.char_span(start, end)
                if gold_aspect_span is None:
                    continue
                gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end))
                gold_polarity_labels.append(annotation["label"])

            # The Aspect model uses all gold aspects as "True", and all non-overlapping
            # predicted aspects as "False".
            aspect_labels.extend([True] * len(gold_aspects))
            aspect_aspect_list.append(gold_aspects[:])
            for aspect in aspects:
                if not overlaps(aspect, gold_aspects):
                    aspect_labels.append(False)
                    aspect_aspect_list[-1].append(aspect)

            # The Polarity model uses only the gold aspects and labels.
            polarity_labels.extend(gold_polarity_labels)
            polarity_aspect_list.append(gold_aspects)

        aspect_texts = list(aspect_model.prepend_aspects(docs, aspect_aspect_list))
        polarity_texts = list(polarity_model.prepend_aspects(docs, polarity_aspect_list))
        return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict(
            {"text": polarity_texts, "label": polarity_labels}
        )

    def train(
        self,
        args: Optional[TrainingArguments] = None,
        polarity_args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Main training entry point.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the aspect training arguments for this training call.
            polarity_args (`TrainingArguments`, *optional*):
                Temporarily change the polarity training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # Train the two sub-models independently: aspect first, then polarity.
        for train_step, step_args in ((self.train_aspect, args), (self.train_polarity, polarity_args)):
            train_step(args=step_args, trial=trial, **kwargs)

    def train_aspect(
        self,
        args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Train only the aspect sub-model.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the aspect training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # Pure delegation to the dedicated aspect trainer.
        self.aspect_trainer.train(args=args, trial=trial, **kwargs)

    def train_polarity(
        self,
        args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Train the polarity model only.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the polarity training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # Pure delegation to the dedicated polarity trainer.
        self.polarity_trainer.train(args=args, trial=trial, **kwargs)

    def add_callback(self, callback: Union[type, TrainerCallback]) -> None:
        """Add a callback to the current list of [`~transformers.TrainerCallback`]
        of both the aspect and the polarity trainer.

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a
                [`~transformers.TrainerCallback`]. In the first case, will
                instantiate a member of that class.
        """
        for trainer in (self.aspect_trainer, self.polarity_trainer):
            trainer.add_callback(callback)

    def pop_callback(self, callback: Union[type, TrainerCallback]) -> Tuple[TrainerCallback, TrainerCallback]:
        """Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a
                [`~transformers.TrainerCallback`]. In the first case, will pop the
                first member of that class found in the list of callbacks.

        Returns:
            `Tuple[`[`~transformers.TrainerCallback`], [`~transformers.TrainerCallback`]`]`: The callbacks removed from the
                aspect and polarity trainers, if found.
        """
        aspect_callback = self.aspect_trainer.pop_callback(callback)
        polarity_callback = self.polarity_trainer.pop_callback(callback)
        return aspect_callback, polarity_callback

    def remove_callback(self, callback: Union[type, TrainerCallback]) -> None:
        """Remove a callback from the current list of [`~transformers.TrainerCallback`]
        of both the aspect and the polarity trainer.

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a
                [`~transformers.TrainerCallback`]. In the first case, will remove
                the first member of that class found in the list of callbacks.
        """
        for trainer in (self.aspect_trainer, self.polarity_trainer):
            trainer.remove_callback(callback)

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Upload model checkpoint to the Hub using `huggingface_hub`.

        See the full list of parameters for your `huggingface_hub` version in the\
        [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).

        Args:
            repo_id (`str`):
                The full repository ID to push to, e.g. `"tomaarsen/setfit-aspect"`.
            polarity_repo_id (`str`, *optional*):
                The full repository ID to push the polarity model to, e.g. `"tomaarsen/setfit-sst2"`.
            config (`dict`, *optional*):
                Configuration object to be saved alongside the model weights.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            api_endpoint (`str`, *optional*):
                The API endpoint to use when pushing the model to the hub.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files.
                If not set, will use the token set when logging in with
                `transformers-cli login` (stored in `~/.huggingface`).
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to
                the default branch as specified in your repository, which
                defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit.
                Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
        """
        # Both sub-models are uploaded by the AbsaModel itself.
        return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs)

    def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]:
        """Computes the metrics for a given classifier.

        Args:
            dataset (`Dataset`, *optional*):
                The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
                the `eval_dataset` argument at `Trainer` initialization.

        Returns:
            `Dict[str, Dict[str, float]]`: The evaluation metrics, keyed by
                `"aspect"` and `"polarity"`.
        """
        aspect_eval_dataset = polarity_eval_dataset = None
        # Explicit `is not None` check: a truthiness test would treat an empty
        # `Dataset` (len 0) as "not provided" and silently fall back to the
        # init-time evaluation dataset.
        if dataset is not None:
            aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
                self.model.aspect_model, self.model.polarity_model, dataset
            )
        return {
            "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset),
            "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset),
        }