from itertools import zip_longest
from typing import Generator, Iterable, List, Optional

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any iterable data provided.

    Args:
        iterable: data to generate pair combinations from
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]


class ContrastiveDataset(IterableDataset):
    def __init__(
        self,
        examples: List[InputExample],
        multilabel: bool,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates positive and negative text pairs for contrastive learning.

        Args:
            examples (List[InputExample]): text and labels in a sentence-transformers dataclass
            multilabel: set to process "multilabel" label arrays
            num_iterations: if provided, explicitly sets the number of pairs to be generated,
                where n_pairs = n_iterations * n_sentences * 2 (for positive & negative pairs)
            sampling_strategy: "unique", "oversampling", or "undersampling"
            max_pairs: if not -1, stop sampling once at least max_pairs positive and
                max_pairs negative pairs have been generated
        """
        super().__init__()
        self.pos_index = 0
        self.neg_index = 0
        self.pos_pairs = []
        self.neg_pairs = []
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        self.sentence_labels = list(zip(self.sentences, self.labels))
        self.max_pairs = max_pairs

        if multilabel:
            self.generate_multilabel_pairs()
        else:
            self.generate_pairs()

        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
            self.len_neg_pairs = num_iterations * len(self.sentences)

        elif sampling_strategy == "unique":
            self.len_pos_pairs = len(self.pos_pairs)
            self.len_neg_pairs = len(self.neg_pairs)

        elif sampling_strategy == "undersampling":
            self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))

        elif sampling_strategy == "oversampling":
            self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))

        else:
            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")
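
    # Example: with 3 unique positive pairs and 7 unique negative pairs, "unique" draws
    # 3 + 7 pairs per epoch, "undersampling" draws 3 + 3, and "oversampling" draws 7 + 7;
    # the shorter pair list is simply re-cycled by get_positive_pairs/get_negative_pairs below.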

    def generate_pairs(self) -> None:
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if _label == label:
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def generate_multilabel_pairs(self) -> None:
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            # logical_and: the pair is positive if any class is set in both label arrays
            if any(np.logical_and(_label, label)):
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def get_positive_pairs(self) -> List[InputExample]:
        pairs = []
        for _ in range(self.len_pos_pairs):
            if self.pos_index >= len(self.pos_pairs):
                self.pos_index = 0
            pairs.append(self.pos_pairs[self.pos_index])
            self.pos_index += 1
        return pairs

    def get_negative_pairs(self) -> List[InputExample]:
        pairs = []
        for _ in range(self.len_neg_pairs):
            if self.neg_index >= len(self.neg_pairs):
                self.neg_index = 0
            pairs.append(self.neg_pairs[self.neg_index])
            self.neg_index += 1
        return pairs

    def __iter__(self):
        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
            if pos_pair is not None:
                yield pos_pair
            if neg_pair is not None:
                yield neg_pair

    def __len__(self) -> int:
        return self.len_pos_pairs + self.len_neg_pairs


class ContrastiveDistillationDataset(ContrastiveDataset):
    def __init__(
        self,
        examples: List[InputExample],
        cos_sim_matrix: torch.Tensor,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        self.cos_sim_matrix = cos_sim_matrix
        super().__init__(
            examples,
            multilabel=False,
            num_iterations=num_iterations,
            sampling_strategy=sampling_strategy,
            max_pairs=max_pairs,
        )
        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
        # After all, without labels, there isn't much of a strategy.
        self.sentence_labels = list(enumerate(self.sentences))

        self.len_neg_pairs = 0
        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
        else:
            self.len_pos_pairs = len(self.pos_pairs)

    def generate_pairs(self) -> None:
        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
                break
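

# Usage sketch (illustrative): the example texts and labels below are hypothetical and only
# show how the dataset is typically constructed and iterated, e.g. before handing the pairs
# to a torch DataLoader for sentence-transformers training.
if __name__ == "__main__":
    demo_examples = [
        InputExample(texts=["great movie"], label=1),
        InputExample(texts=["loved every minute"], label=1),
        InputExample(texts=["terrible plot"], label=0),
        InputExample(texts=["fell asleep halfway"], label=0),
    ]
    dataset = ContrastiveDataset(demo_examples, multilabel=False, sampling_strategy="oversampling")
    logger.info(f"Generated {len(dataset)} contrastive pairs")
    for pair in dataset:
        logger.info(f"{pair.texts} -> label {pair.label}")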