File size: 6,911 Bytes
0fdb130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from itertools import zip_longest
from typing import Generator, Iterable, List, Optional

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging


# Import-time side effect: switch the package-local `logging` helper (imported
# above via `from . import logging`) to its info verbosity, then fetch the
# logger for this module.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any iterable data provided.

    Pairs are built from the upper-triangular index combinations ``(i, j)``
    with ``i <= j`` (``i < j`` when ``replacement`` is False) and are yielded
    in a permuted order. The permutation comes from a ``RandomState`` seeded
    with 42, so the order is the same on every call for a given input length.

    Args:
        iterable: data to generate pair combinations from. Inputs that do not
            support both ``len()`` and indexing (generators, sets, ...) are
            copied into a list first; sequences are used as-is.
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    # The index-based sampling below needs a sized, indexable container, while
    # the signature promises to accept any iterable.
    if not (hasattr(iterable, "__len__") and hasattr(iterable, "__getitem__")):
        iterable = list(iterable)

    n = len(iterable)
    # k=0 keeps the diagonal (a sample paired with itself), k=1 drops it.
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]


class ContrastiveDataset(IterableDataset):
    """Iterable dataset of positive and negative sentence pairs.

    All pairs are generated once in ``__init__`` and stored in ``pos_pairs``
    and ``neg_pairs``. Iterating the dataset interleaves positive and negative
    pairs, cycling through each stored list as often as needed to reach the
    configured number of pairs.
    """

    def __init__(
        self,
        examples: List[InputExample],
        multilabel: bool,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates positive and negative text pairs for contrastive learning.

        Args:
            examples (InputExample): text and labels in a text transformer dataclass
            multilabel: set to process "multilabel" labels array
            sampling_strategy: "unique", "oversampling", or "undersampling".
                Ignored (and not validated) when ``num_iterations`` is a
                positive number.
            num_iterations: if provided explicitly sets the number of pairs to be generated
                where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs)
            max_pairs: If not -1, then we only sample pairs until we have certainly reached
                max_pairs pairs, i.e. until both the positive and the negative
                list hold more than ``max_pairs`` entries.

        Raises:
            ValueError: if ``num_iterations`` is not a positive number and
                ``sampling_strategy`` is not one of the supported names.
        """
        super().__init__()
        # Cursors into pos_pairs / neg_pairs; they persist across iterations so
        # that successive passes continue where the previous one stopped.
        self.pos_index = 0
        self.neg_index = 0
        self.pos_pairs = []
        self.neg_pairs = []
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        self.sentence_labels = list(zip(self.sentences, self.labels))
        self.max_pairs = max_pairs

        if multilabel:
            self.generate_multilabel_pairs()
        else:
            self.generate_pairs()

        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
            self.len_neg_pairs = num_iterations * len(self.sentences)

        elif sampling_strategy == "unique":
            self.len_pos_pairs = len(self.pos_pairs)
            self.len_neg_pairs = len(self.neg_pairs)

        elif sampling_strategy == "undersampling":
            self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))

        elif sampling_strategy == "oversampling":
            self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))

        else:
            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

        # An empty list cannot be cycled to reach a non-zero target (e.g. no
        # negative pairs when every example shares one label). Clamp the target
        # so __len__ matches what __iter__ can actually yield.
        if not self.pos_pairs:
            self.len_pos_pairs = 0
        if not self.neg_pairs:
            self.len_neg_pairs = 0

    def generate_pairs(self) -> None:
        """Fill ``pos_pairs`` / ``neg_pairs`` by comparing single labels for equality."""
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if _label == label:
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            # Stop early only once BOTH lists exceed the cap.
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def generate_multilabel_pairs(self) -> None:
        """Fill ``pos_pairs`` / ``neg_pairs``; a pair is positive if the two label arrays share a set class."""
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if any(np.logical_and(_label, label)):
                # logical_and checks if labels are both set for each class
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            # Stop early only once BOTH lists exceed the cap.
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def get_positive_pairs(self) -> List[InputExample]:
        """Return ``len_pos_pairs`` positive pairs, wrapping around ``pos_pairs`` as needed.

        Returns an empty list when no positive pairs exist.
        """
        if not self.pos_pairs:
            # Nothing to cycle through; indexing would raise IndexError.
            return []
        pairs = []
        for _ in range(self.len_pos_pairs):
            if self.pos_index >= len(self.pos_pairs):
                self.pos_index = 0
            pairs.append(self.pos_pairs[self.pos_index])
            self.pos_index += 1
        return pairs

    def get_negative_pairs(self) -> List[InputExample]:
        """Return ``len_neg_pairs`` negative pairs, wrapping around ``neg_pairs`` as needed.

        Returns an empty list when no negative pairs exist.
        """
        if not self.neg_pairs:
            # Nothing to cycle through; indexing would raise IndexError.
            return []
        pairs = []
        for _ in range(self.len_neg_pairs):
            if self.neg_index >= len(self.neg_pairs):
                self.neg_index = 0
            pairs.append(self.neg_pairs[self.neg_index])
            self.neg_index += 1
        return pairs

    def __iter__(self):
        """Yield pairs alternating positive / negative; the longer list's tail is yielded on its own."""
        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
            if pos_pair is not None:
                yield pos_pair
            if neg_pair is not None:
                yield neg_pair

    def __len__(self) -> int:
        """Total number of pairs produced by one pass over the dataset."""
        return self.len_pos_pairs + self.len_neg_pairs


class ContrastiveDistillationDataset(ContrastiveDataset):
    """Contrastive dataset whose pair labels are taken from a similarity matrix.

    Every sentence pair is stored in ``pos_pairs`` and labelled with
    ``cos_sim_matrix[i][j]``, where ``i`` and ``j`` are the positions of the
    two sentences in ``examples``. No negative pairs are produced.
    """

    def __init__(
        self,
        examples: List[InputExample],
        cos_sim_matrix: torch.Tensor,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates text pairs labelled with their entry in ``cos_sim_matrix``.

        Args:
            examples (InputExample): texts to pair up; only ``texts[0]`` of each
                example is used for the pairs, the example labels are not used
                to index the matrix.
            cos_sim_matrix: indexed as ``cos_sim_matrix[i][j]`` with sentence
                positions. NOTE(review): assumes rows/columns follow the order
                of ``examples`` -- confirm against the caller.
            num_iterations: if a positive number, one pass yields
                ``num_iterations * n_sentences`` pairs; otherwise every stored
                pair is yielded once.
            sampling_strategy: forwarded to the parent constructor, which
                validates it when ``num_iterations`` is not a positive number;
                it does not affect the final pair count here.
            max_pairs: If not -1, pair generation stops once more than
                ``max_pairs`` pairs have been stored.
        """
        # Must be set before super().__init__(), which calls generate_pairs().
        self.cos_sim_matrix = cos_sim_matrix
        super().__init__(
            examples,
            multilabel=False,
            num_iterations=num_iterations,
            sampling_strategy=sampling_strategy,
            max_pairs=max_pairs,
        )
        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
        # After all, without labels, there isn't much of a strategy.
        self.len_neg_pairs = 0
        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
        else:
            self.len_pos_pairs = len(self.pos_pairs)

    def generate_pairs(self) -> None:
        """Fill ``pos_pairs`` with every sentence pair and its similarity value."""
        # Pair each sentence with its position so the ids below are valid
        # row/column indices of cos_sim_matrix. This has to happen here, not
        # after super().__init__(), because the parent constructor calls this
        # method while sentence_labels still holds (sentence, example label).
        self.sentence_labels = [(text, idx) for idx, text in enumerate(self.sentences)]
        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
                break