import json
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Union

import numpy as np
import transformers as tr
from tqdm import tqdm


class HardNegativesManager:
    """
    Store the hard negatives of each sample and cache their tokenization.

    The manager keeps a ``sample_idx -> passages`` mapping, the inverted
    ``passage -> {sample_idx, ...}`` mapping, and a cache of the tokenizer
    output of every passage seen so far.

    Args:
        tokenizer (`tr.PreTrainedTokenizer`):
            Callable used to tokenize the passages.
        data (`Dict`, `str` or `os.PathLike`, optional):
            Either the ``sample_idx -> passages`` mapping itself, or the path
            of a JSON file that contains it. Defaults to an empty mapping.
        max_length (`int`, optional):
            Truncation length passed to the tokenizer. Defaults to 64.
        batch_size (`int`, optional):
            Number of passages tokenized per tokenizer call when the cache is
            built eagerly. Defaults to 1000.
        lazy (`bool`, optional):
            If `True`, nothing is tokenized up front; passages are tokenized
            the first time `get` needs them. Defaults to `False`.

    Raises:
        ValueError: if `data` is neither a dict nor a path.
    """

    def __init__(
        self,
        tokenizer: tr.PreTrainedTokenizer,
        data: Union[List[Dict], os.PathLike, Dict[int, List]] = None,
        max_length: int = 64,
        batch_size: int = 1000,
        lazy: bool = False,
    ) -> None:
        # keep the tokenizer and its truncation length for lazy tokenization
        self.tokenizer = tokenizer
        self.max_length = max_length

        self._db: dict = None
        if data is None:
            self._db = {}
        elif isinstance(data, dict):
            self._db = data
        elif isinstance(data, (str, os.PathLike)):
            # NOTE(review): JSON object keys are always strings, so a database
            # loaded from file is keyed by `str` sample ids - confirm callers
            # look samples up with matching keys.
            with open(data, encoding="utf-8") as f:
                self._db = json.load(f)
        else:
            raise ValueError(
                f"Data type {type(data)} not supported, only Dict and os.PathLike are supported."
            )

        # invert the db to have a passage -> sample_idx mapping
        self._passage_db = defaultdict(set)
        for sample_idx, passages in self._db.items():
            for passage in passages:
                self._passage_db[passage].add(sample_idx)

        # passage -> tokenizer output cache
        self._passage_hard_negatives = {}
        # An empty database must skip this step entirely: the step of `range`
        # would otherwise be 0, which raises a ValueError.
        if not lazy and self._passage_db:
            unique_passages = list(self._passage_db)
            batch_size = min(batch_size, len(unique_passages))
            for start in tqdm(
                range(0, len(unique_passages), batch_size),
                desc="Tokenizing Hard Negatives",
            ):
                batch = unique_passages[start : start + batch_size]
                tokenized_passages = self.tokenizer(
                    batch,
                    max_length=max_length,
                    truncation=True,
                )
                # split the batched output back into one entry per passage
                for position, passage in enumerate(batch):
                    self._passage_hard_negatives[passage] = {
                        k: tokenized_passages[k][position]
                        for k in tokenized_passages.keys()
                    }

    def __len__(self) -> int:
        """Return the number of samples in the database."""
        return len(self._db)

    def __getitem__(self, idx: int) -> List:
        """Return the raw (non-tokenized) passages stored for sample `idx`."""
        return self._db[idx]

    def __iter__(self):
        """Iterate over the sample indices of the database."""
        for sample in self._db:
            yield sample

    def __contains__(self, idx: int) -> bool:
        """Return whether sample `idx` is in the database."""
        return idx in self._db

    def get(self, idx: int) -> List[Dict]:
        """
        Get the tokenized hard negatives for a given sample index.

        Passages that are not in the cache yet are tokenized and cached.

        Args:
            idx (`int`): sample index.

        Returns:
            `List[Dict]`: tokenizer output of each hard negative of the sample.

        Raises:
            ValueError: if `idx` is not in the database.
        """
        if idx not in self._db:
            raise ValueError(f"Sample index {idx} not in the database.")
        output = []
        for passage in self._db[idx]:
            if passage not in self._passage_hard_negatives:
                self._passage_hard_negatives[passage] = self._tokenize(passage)
            output.append(self._passage_hard_negatives[passage])
        return output

    def _tokenize(self, passage: str) -> Dict:
        """Tokenize a single passage, truncated to `self.max_length`."""
        return self.tokenizer(passage, max_length=self.max_length, truncation=True)


class NegativeSampler:
    """
    Draw indices of negative elements according to per-element weights.

    Args:
        num_elements (`int`):
            Number of elements to sample from. Only used to build random
            weights when `probabilities` is not given.
        probabilities (`List` or `np.ndarray`, optional):
            Weight of each element. If `None`, random weights that sum to 1
            are generated. Defaults to `None`.
    """

    def __init__(
        self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
    ):
        # The `None` check has to come first: `np.array(None)` is a 0-d object
        # array, not `None`, so converting first would skip the default branch.
        if probabilities is None:
            # probabilities should sum to 1
            probabilities = np.random.random(num_elements)
            probabilities /= np.sum(probabilities)
        elif not isinstance(probabilities, np.ndarray):
            probabilities = np.array(probabilities)
        self.probabilities = probabilities

    def __call__(
        self,
        sample_size: int,
        num_samples: int = 1,
        probabilities: np.array = None,
        exclude: List[int] = None,
    ) -> np.array:
        """
        Fast sampling of `sample_size` elements from `num_elements` elements.

        The sampling is done by subtracting the probabilities from random
        shifts and taking the `sample_size` smallest results of each row, so
        elements with a larger probability are more likely to be picked.
        This is a heuristic, not an exact draw from the distribution.

        Neither `self.probabilities` nor the `probabilities` argument is
        modified by the call.

        Args:
            sample_size (`int`):
                number of elements to sample
            num_samples (`int`, optional):
                number of samples to draw. Defaults to 1.
            probabilities (`np.array`, optional):
                probabilities of each element. Defaults to None, in which
                case the probabilities given at construction time are used.
            exclude (`List[int]`, optional):
                indices of elements to exclude. They are never returned as
                long as `sample_size` does not exceed the number of
                non-excluded elements. Defaults to None.

        Returns:
            `np.array`: array of sampled indices, of shape
            `(num_samples, sample_size)`

        Raises:
            ValueError: if `sample_size` is larger than the number of elements.
        """
        if probabilities is None:
            probabilities = self.probabilities

        # replicate probabilities as many times as `num_samples`
        # (np.tile returns a new array, the input is left untouched)
        replicated_probabilities = np.tile(probabilities, (num_samples, 1))
        num_elements = replicated_probabilities.shape[1]
        if sample_size > num_elements:
            raise ValueError(
                f"Cannot sample {sample_size} elements out of {num_elements}."
            )

        # get random shifting numbers & scale them correctly
        random_shifts = np.random.random(replicated_probabilities.shape)
        random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]

        # shift by numbers & find largest (by finding the smallest of the negative)
        shifted_probabilities = random_shifts - replicated_probabilities

        if exclude is not None and len(exclude) > 0:
            # Zeroing the probability is not enough: the shifted score would
            # still be finite and could rank among the smallest. +inf puts the
            # excluded columns last in the partition.
            shifted_probabilities[:, exclude] = np.inf

        # `kth` must be a valid index, so clamp it to allow
        # `sample_size == num_elements`
        kth = min(sample_size, num_elements - 1)
        sampled_indices = np.argpartition(shifted_probabilities, kth, axis=1)[
            :, :sample_size
        ]
        return sampled_indices


def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]:
    """
    Generate batches from samples.

    Every batch is a list of `batch_size` consecutive samples, except the
    last one, which holds whatever is left over.

    Args:
        samples (`Iterable[Any]`): Iterable of samples.
        batch_size (`int`): Batch size.

    Returns:
        `Iterable[Any]`: Iterable of batches.
    """
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []

    # leftover batch
    if len(batch) > 0:
        yield batch