svystun-taras's picture
created the updated web ui
0fdb130
raw
history blame
6.91 kB
from itertools import zip_longest
from typing import Generator, Iterable, List, Optional, Sequence

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging
# Configure the package-local logging helper (imported via `from . import logging`) and create
# this module's logger.
# NOTE(review): set_verbosity_info() runs at import time, so merely importing this module
# presumably raises the package's log level to INFO for everyone - confirm this is intended.
# `logger` is not referenced in the code visible here.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def shuffle_combinations(iterable: Sequence, replacement: bool = True, seed: Optional[int] = 42) -> Generator:
    """Generates shuffled pair combinations for any sequence of data provided.

    Each unordered pair of positions ``(i, j)`` with ``i <= j`` (``i < j`` when
    ``replacement`` is False) is yielded exactly once, as ``(iterable[i], iterable[j])``,
    in an order given by a seeded random permutation.

    Args:
        iterable: data to generate pair combinations from. Must support ``len()`` and
            integer indexing (list, tuple, range, numpy array, ...); a one-shot iterator
            will not work.
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement
        seed: seed for the ``np.random.RandomState`` that shuffles the pairs. The default
            of 42 keeps the order reproducible across calls; pass None for a
            non-deterministic order.

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    # Diagonal offset for np.triu_indices: 0 keeps the (i, i) self-pairs, 1 starts strictly above.
    offset = 0 if replacement else 1
    # One row per pair of positions, shape (num_pairs, 2). All O(n^2) index pairs are
    # materialised up front so they can be permuted.
    idxs = np.stack(np.triu_indices(n, offset), axis=-1)
    for i in np.random.RandomState(seed=seed).permutation(len(idxs)):
        first, second = idxs[i, :]
        yield iterable[first], iterable[second]
class ContrastiveDataset(IterableDataset):
    """Iterable dataset of positive and negative text pairs for contrastive learning.

    Pairs are built once, in ``__init__``, from every combination of examples (each example
    is also paired with itself) in the order produced by ``shuffle_combinations``. Iterating
    interleaves ``len_pos_pairs`` positive pairs (label 1.0) with ``len_neg_pairs`` negative
    pairs (label 0.0), cycling through the stored pairs when more are requested than exist.
    The read positions (``pos_index`` / ``neg_index``) persist across iterations.
    """

    # Accepted values of ``sampling_strategy`` when ``num_iterations`` is not set.
    _SAMPLING_STRATEGIES = ("unique", "oversampling", "undersampling")

    def __init__(
        self,
        examples: List[InputExample],
        multilabel: bool,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates positive and negative text pairs for contrastive learning.

        Args:
            examples (List[InputExample]): examples whose first text and label are used.
            multilabel: if True, labels are per-class indicator arrays and a pair is positive
                when both examples have at least one class set in common; otherwise a pair is
                positive when the labels are equal.
            num_iterations: if a positive int, explicitly sets the number of pairs to be
                generated, where n_pairs = num_iterations * n_sentences * 2 (for pos & neg
                pairs); ``sampling_strategy`` is then ignored.
            sampling_strategy: "unique" (use every generated pair once), "undersampling"
                (both kinds trimmed to the smaller count) or "oversampling" (both kinds grown
                to the larger count by repeating pairs). Only used without ``num_iterations``.
            max_pairs: if not -1, stop generating once both the positive and the negative
                lists hold more than ``max_pairs`` pairs.

        Raises:
            ValueError: if ``sampling_strategy`` is needed and is not a supported value.
        """
        super().__init__()
        use_iterations = num_iterations is not None and num_iterations > 0
        # Fail fast: reject a bad strategy before spending O(n^2) work generating pairs.
        if not use_iterations and sampling_strategy not in self._SAMPLING_STRATEGIES:
            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

        self.pos_index = 0
        self.neg_index = 0
        self.pos_pairs = []
        self.neg_pairs = []
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        self.sentence_labels = list(zip(self.sentences, self.labels))
        self.max_pairs = max_pairs

        if multilabel:
            self.generate_multilabel_pairs()
        else:
            self.generate_pairs()

        if use_iterations:
            self.len_pos_pairs = num_iterations * len(self.sentences)
            self.len_neg_pairs = num_iterations * len(self.sentences)
        elif sampling_strategy == "unique":
            self.len_pos_pairs = len(self.pos_pairs)
            self.len_neg_pairs = len(self.neg_pairs)
        else:
            # Validated above, so this is "undersampling" or "oversampling".
            balance = min if sampling_strategy == "undersampling" else max
            balanced = balance(len(self.pos_pairs), len(self.neg_pairs))
            self.len_pos_pairs = balanced
            self.len_neg_pairs = balanced

        # A kind of pair that was never generated (e.g. all examples share one label, so there
        # are no negatives) cannot be repeated. Request none of it instead of failing with an
        # IndexError on the first iteration; "unique" already behaves this way.
        if not self.pos_pairs:
            self.len_pos_pairs = 0
        if not self.neg_pairs:
            self.len_neg_pairs = 0

    def _collect_pairs(self, is_positive) -> None:
        """Sort every example pair into pos_pairs / neg_pairs via ``is_positive(label_a, label_b)``.

        Stops early once both lists hold more than ``max_pairs`` pairs (unless max_pairs is -1).
        """
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if is_positive(_label, label):
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def generate_pairs(self) -> None:
        """Generate pairs for single-label data: a pair is positive when both labels are equal."""
        self._collect_pairs(lambda label_a, label_b: label_a == label_b)

    def generate_multilabel_pairs(self) -> None:
        """Generate pairs for multilabel data: a pair is positive when any class is set in both."""
        # logical_and checks if labels are both set for each class
        self._collect_pairs(lambda label_a, label_b: any(np.logical_and(label_a, label_b)))

    @staticmethod
    def _take_cyclic(pairs, count, start):
        """Take ``count`` items from ``pairs`` starting at ``start``, wrapping to the front.

        Returns the taken items and the cursor to resume from next time. An empty ``pairs``
        yields no items instead of raising IndexError.
        """
        if not pairs:
            return [], start
        taken = []
        index = start
        for _ in range(count):
            if index >= len(pairs):
                index = 0
            taken.append(pairs[index])
            index += 1
        return taken, index

    def get_positive_pairs(self) -> List[InputExample]:
        """Return the next ``len_pos_pairs`` positive pairs, advancing ``pos_index``."""
        pairs, self.pos_index = self._take_cyclic(self.pos_pairs, self.len_pos_pairs, self.pos_index)
        return pairs

    def get_negative_pairs(self) -> List[InputExample]:
        """Return the next ``len_neg_pairs`` negative pairs, advancing ``neg_index``."""
        pairs, self.neg_index = self._take_cyclic(self.neg_pairs, self.len_neg_pairs, self.neg_index)
        return pairs

    def __iter__(self):
        """Yield positive and negative pairs alternately; the longer list's remainder comes last."""
        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
            if pos_pair is not None:
                yield pos_pair
            if neg_pair is not None:
                yield neg_pair

    def __len__(self) -> int:
        """Number of pairs requested per iteration (positive + negative)."""
        return self.len_pos_pairs + self.len_neg_pairs
class ContrastiveDistillationDataset(ContrastiveDataset):
    """Contrastive dataset whose pair labels are looked up in a precomputed matrix.

    Every combination of examples (each example is also paired with itself) is stored in
    ``pos_pairs`` with label ``cos_sim_matrix[i][j]``, where ``i`` and ``j`` are the
    positions of the two examples in ``examples``. No negative pairs are generated.
    """

    def __init__(
        self,
        examples: List[InputExample],
        cos_sim_matrix: torch.Tensor,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates text pairs labelled by entries of ``cos_sim_matrix``.

        Args:
            examples (List[InputExample]): examples whose first text is paired; an example's
                position in this list is its row/column in ``cos_sim_matrix``.
            cos_sim_matrix: indexed as ``cos_sim_matrix[i][j]`` for example positions i and j;
                presumably an (n, n) similarity matrix - TODO confirm with the caller.
            num_iterations: if a positive int, iterating yields num_iterations * n_sentences
                pairs; otherwise every generated pair is yielded once per iteration.
            sampling_strategy: forwarded to ContrastiveDataset (which may reject an invalid
                value); it does not change how many pairs this dataset yields.
            max_pairs: if not -1, stop generating once more than ``max_pairs`` pairs are stored.
        """
        # Must be set before super().__init__(), which calls the overridden generate_pairs().
        self.cos_sim_matrix = cos_sim_matrix
        super().__init__(
            examples,
            multilabel=False,
            num_iterations=num_iterations,
            sampling_strategy=sampling_strategy,
            max_pairs=max_pairs,
        )
        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
        # After all, without labels, there isn't much of a strategy.
        # (position, text) per example, kept as a public attribute; generate_pairs() does not read it.
        self.sentence_labels = list(enumerate(self.sentences))
        self.len_neg_pairs = 0
        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
        else:
            self.len_pos_pairs = len(self.pos_pairs)

    def generate_pairs(self) -> None:
        """Store every example pair in pos_pairs, labelled with the matrix entry for their positions."""
        # Pair positions explicitly instead of reading self.sentence_labels: when the parent
        # __init__ calls this method, that attribute still holds (text, label) tuples, so
        # unpacking it would index cos_sim_matrix with the examples' labels, not their positions.
        indexed_sentences = list(enumerate(self.sentences))
        for (id_one, text_one), (id_two, text_two) in shuffle_combinations(indexed_sentences):
            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
                break