import json
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Union

import numpy as np
import transformers as tr
from tqdm import tqdm


class HardNegativesManager:
    def __init__(
        self,
        tokenizer: tr.PreTrainedTokenizer,
        data: Optional[Union[Dict[int, List], os.PathLike]] = None,
        max_length: int = 64,
        batch_size: int = 1000,
        lazy: bool = False,
    ) -> None:
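        """Cache tokenized hard-negative passages shared across samples.

        Args:
            tokenizer: tokenizer used to encode the passages.
            data: either a dict mapping a sample index to its list of
                hard-negative passage strings, or a path to a JSON file
                with the same structure.
            max_length: truncation length for tokenized passages.
            batch_size: number of passages tokenized per batch when eager.
            lazy: if True, skip eager tokenization; passages are tokenized
                on first access in `get` instead.
        """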
        self._db: Optional[dict] = None
        self.tokenizer = tokenizer
        # Stored for lazy tokenization in `_tokenize`.
        self.max_length = max_length

        if data is None:
            self._db = {}
        else:
            if isinstance(data, Dict):
                self._db = data
            elif isinstance(data, os.PathLike):
                with open(data) as f:
                    self._db = json.load(f)
            else:
                raise ValueError(
                    f"Data type {type(data)} not supported, only Dict and os.PathLike are supported."
                )

        # Invert the sample -> passages mapping so each unique passage points
        # to the samples that use it; this lets us tokenize every passage
        # exactly once, no matter how many samples share it.
        self._passage_db = defaultdict(set)
        for sample_idx, passages in self._db.items():
            for passage in passages:
                self._passage_db[passage].add(sample_idx)

        self._passage_hard_negatives = {}
        if not lazy and len(self._passage_db) > 0:
            # Eagerly tokenize all unique passages in batches.
            batch_size = min(batch_size, len(self._passage_db))
            unique_passages = list(self._passage_db.keys())
            for i in tqdm(
                range(0, len(unique_passages), batch_size),
                desc="Tokenizing Hard Negatives",
            ):
                batch = unique_passages[i : i + batch_size]
                tokenized_passages = self.tokenizer(
                    batch,
                    max_length=max_length,
                    truncation=True,
                )
                # Use a separate index for the position inside the batch so it
                # does not shadow the batch offset `i`.
                for j, passage in enumerate(batch):
                    self._passage_hard_negatives[passage] = {
                        k: tokenized_passages[k][j] for k in tokenized_passages.keys()
                    }

    def __len__(self) -> int:
        return len(self._db)

    def __getitem__(self, idx: int) -> Dict:
        return self._db[idx]

    def __iter__(self):
        for sample in self._db:
            yield sample

    def __contains__(self, idx: int) -> bool:
        return idx in self._db

    def get(self, idx: int) -> List[Dict]:
        """Get the tokenized hard negatives for a given sample index."""
        if idx not in self._db:
            raise ValueError(f"Sample index {idx} not in the database.")

        passages = self._db[idx]

        output = []
        for passage in passages:
            # Tokenize lazily on first access if not already cached.
            if passage not in self._passage_hard_negatives:
                self._passage_hard_negatives[passage] = self._tokenize(passage)
            output.append(self._passage_hard_negatives[passage])

        return output

    def _tokenize(self, passage: str) -> Dict:
        return self.tokenizer(passage, max_length=self.max_length, truncation=True)
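

# Illustrative usage sketch (comment only; the checkpoint name and the toy
# data are assumptions, shown just to document the expected input shape):
#
#   tokenizer = tr.AutoTokenizer.from_pretrained("bert-base-uncased")
#   hn_manager = HardNegativesManager(
#       tokenizer, data={0: ["a hard negative passage", "another one"]}
#   )
#   encodings = hn_manager.get(0)  # list of dicts with "input_ids", etc.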


class NegativeSampler:
    def __init__(
        self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
    ):
        # Check for None *before* converting: `np.array(None)` yields a 0-d
        # object array, so the None branch below would otherwise never run.
        if probabilities is None:
            # Default to random probabilities normalized to sum to 1.
            probabilities = np.random.random(num_elements)
            probabilities /= np.sum(probabilities)
        if not isinstance(probabilities, np.ndarray):
            probabilities = np.array(probabilities)
        self.probabilities = probabilities

    def __call__(
        self,
        sample_size: int,
        num_samples: int = 1,
        probabilities: Optional[np.ndarray] = None,
        exclude: Optional[List[int]] = None,
    ) -> np.ndarray:
        """
        Fast sampling of `sample_size` elements from `num_elements` elements.
        The sampling is done by randomly shifting the probabilities and then
        taking the indices of the `sample_size` smallest shifted values. This
        is much faster than sampling from a multinomial distribution.

        Args:
            sample_size (`int`):
                number of elements to sample
            num_samples (`int`, optional):
                number of samples to draw. Defaults to 1.
            probabilities (`np.ndarray`, optional):
                probabilities of each element. Defaults to None.
            exclude (`List[int]`, optional):
                indices of elements to exclude. Defaults to None.

        Returns:
            `np.ndarray`: array of sampled indices
        """
        if probabilities is None:
            probabilities = self.probabilities

        if exclude is not None:
            # Work on a copy so `self.probabilities` is not zeroed in place
            # across calls.
            probabilities = probabilities.copy()
            probabilities[exclude] = 0

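        # Replicate the probability row once per requested sample, subtract it
        # from a row-normalized random shift, and keep the `sample_size`
        # smallest entries per row: higher-probability elements end up more
        # negative and are therefore selected more often.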
        replicated_probabilities = np.tile(probabilities, (num_samples, 1))
        random_shifts = np.random.random(replicated_probabilities.shape)
        random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]

        shifted_probabilities = random_shifts - replicated_probabilities
        sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[
            :, :sample_size
        ]
        return sampled_indices
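

# Illustrative usage sketch (comment only; the numbers are arbitrary):
#
#   sampler = NegativeSampler(num_elements=100)
#   negatives = sampler(sample_size=5, num_samples=2, exclude=[3, 42])
#   negatives.shape  # (2, 5): one row of sampled negative indices per draw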


def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]:
    """
    Generate batches from samples.

    Args:
        samples (`Iterable[Any]`): Iterable of samples.
        batch_size (`int`): Batch size.

    Returns:
        `Iterable[Any]`: Iterable of batches.
    """
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []

    # Flush the final, possibly smaller, batch.
    if len(batch) > 0:
        yield batch
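

# Illustrative usage sketch (comment only):
#
#   for batch in batch_generator(range(10), batch_size=4):
#       print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]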