|
""" |
|
This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase. |
|
""" |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
|
|
def get_rank():
    """Return the rank reported by ``torch.distributed.get_rank()``.

    The call is made without arguments, so the answer refers to whatever
    process group ``torch.distributed`` treats as its default.
    """
    rank = dist.get_rank()
    return rank
|
|
|
|
|
def get_world_size():
    """Return the size reported by ``torch.distributed.get_world_size()``.

    The call is made without arguments, so the answer refers to whatever
    process group ``torch.distributed`` treats as its default.
    """
    world_size = dist.get_world_size()
    return world_size
|
|
|
|
|
def get_default_group():
    """Return ``torch.distributed.group.WORLD``, the library's world group handle."""
    world_group = dist.group.WORLD
    return world_group
|
|
|
|
|
|
|
def colbert_score_reduce(scores_padded, D_mask):
    """Reduce token-level scores to one score per document (max, then sum).

    Positions flagged as padding are set to -9999 so they cannot win the
    maximum taken over dimension 1; the remaining maxima are then summed over
    the last dimension.

    Args:
        scores_padded: Score tensor whose first two dimensions are indexed by
            the mask. In ``colbert_score`` this is ``(num_docs, doc_len,
            query_len)``. The tensor is NOT modified by this function.
        D_mask: Mask with ``scores_padded.size(0) * scores_padded.size(1)``
            elements. Entries that convert to ``True`` mark positions to keep;
            everything else is treated as padding.

    Returns:
        ``scores_padded`` with padding masked out, reduced with ``max`` over
        dimension 1 followed by ``sum`` over the last dimension.
    """
    num_docs, doc_len = scores_padded.size(0), scores_padded.size(1)

    # reshape (rather than view) also accepts masks whose memory layout does
    # not allow a zero-copy view.
    D_padding = ~D_mask.reshape(num_docs, doc_len).bool()

    # Work on a copy: writing the sentinel into the caller's tensor would
    # silently corrupt scores that the caller may still want to use.
    masked_scores = scores_padded.clone()
    masked_scores[D_padding] = -9999

    max_over_doc_tokens = masked_scores.max(1).values
    return max_over_doc_tokens.sum(-1)
|
|
|
|
|
def colbert_score(Q, D_padded, D_mask, use_gpu=False):
    """
    Score queries against padded passages and reduce with colbert_score_reduce.

    Expected sizes: Q = (1 | num_docs, *, dim) and D_padded = (num_docs, *, dim).
    With Q.size(0) == 1 the single query matrix is compared with all passages;
    otherwise each query matrix is compared against the *aligned* passage.

    When use_gpu is true, Q, D_padded and D_mask are moved to CUDA first.
    Q is cast to the dtype of D_padded before the matrix product.

    EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
    """
    if use_gpu:
        Q = Q.cuda()
        D_padded = D_padded.cuda()
        D_mask = D_mask.cuda()

    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Swap the last two axes of Q so the product contracts over `dim`.
    query_transposed = Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
    token_scores = D_padded @ query_transposed

    return colbert_score_reduce(token_scores, D_mask)
|
|
|
|
|
def _sort_by_length(ids, mask, bsize, *args):
    """Reorder a batch by ascending ``mask.sum(-1)``.

    Args:
        ids: Tensor indexed along dimension 0 by example.
        mask: Tensor aligned with ``ids``; its sum over the last dimension is
            the sort key (the number of unmasked tokens for a 0/1 mask).
        bsize: Batch size. When the input has at most ``bsize`` rows it is
            returned in its original order.
        *args: Optional extra per-example containers to reorder alongside
            ``ids`` and ``mask``. Tensors are indexed directly; anything else
            is indexed element by element and returned as a new list.

    Returns:
        A tuple ``(ids, mask, *extras, reverse_indices)``. Indexing the
        reordered data with ``reverse_indices`` restores the original order.
        The tuple has the same length on both code paths, so callers can
        unpack it without knowing whether a sort took place.
    """
    if ids.size(0) <= bsize:
        # Nothing to reorder. The extras are still returned (untouched) so
        # the return arity does not depend on the batch size.
        return (ids, mask, *args, torch.arange(ids.size(0)))

    indices = mask.sum(-1).sort().indices
    # Sorting the permutation yields its inverse.
    reverse_indices = indices.sort().indices

    # Plain Python ints for indexing non-tensor containers.
    order = indices.tolist()

    reordered = [ids[indices], mask[indices]]
    for arg in args:
        if isinstance(arg, torch.Tensor):
            reordered.append(arg[indices])
        else:
            # arg is a list-like; rebuild it in the sorted order.
            reordered.append([arg[i] for i in order])

    return (*reordered, reverse_indices)
|
|
|
|
|
def _split_into_batches(ids, mask, bsize, *args):
    """Slice ``ids``, ``mask`` and any extras into consecutive chunks.

    Chunks cover ``bsize`` rows each, stepping along dimension 0 of ``ids``;
    the final chunk may be shorter.

    Returns:
        A list with one entry per chunk. Each entry is a list holding the
        slice of ``ids``, the slice of ``mask``, and then the matching slice
        of every extra argument, in the order given.
    """
    columns = (ids, mask, *args)
    return [
        [column[start : start + bsize] for column in columns]
        for start in range(0, ids.size(0), bsize)
    ]
|
|