File size: 2,384 Bytes
66ae8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
"""

import torch
import torch.distributed as dist


def get_rank():
    """Return the rank of the current process in the default ``torch.distributed`` group.

    This is a thin wrapper around ``dist.get_rank()`` with no group argument.
    ``torch.distributed`` raises if no default process group has been initialized.
    """
    current_rank = dist.get_rank()
    return current_rank


def get_world_size():
    """Return the number of processes in the default ``torch.distributed`` group.

    This is a thin wrapper around ``dist.get_world_size()`` with no group argument.
    ``torch.distributed`` raises if no default process group has been initialized.
    """
    world_size = dist.get_world_size()
    return world_size


def get_default_group():
    """Return ``dist.group.WORLD``, the handle for the default (world) process group."""
    world_group = dist.group.WORLD
    return world_group


# TODO: The masking below might also be applicable in the kNN part
def colbert_score_reduce(scores_padded, D_mask):
    """Reduce padded token-level scores to one late-interaction score per document.

    Positions where ``D_mask`` is zero/False are treated as padding and set to
    -9999, so they never win the max over document tokens (dim 1). The per-token
    maxima are then summed over the last dimension.

    Args:
        scores_padded: Score tensor whose first two dims are (num_docs, doc_len).
            As produced by ``colbert_score`` it is (num_docs, doc_len, q_len).
            NOTE: modified in place (padded positions are overwritten).
        D_mask: Document mask with ``num_docs * doc_len`` elements, e.g.
            (num_docs, doc_len) or (num_docs, doc_len, 1). Non-zero / True
            marks a real (non-padding) token.

    Returns:
        ``scores_padded.max(1).values.sum(-1)`` after masking. For 3-D scores
        this has shape (num_docs,).
    """
    # reshape (rather than view) also accepts masks whose memory layout cannot
    # be viewed into (num_docs, doc_len).
    D_padding = ~D_mask.reshape(scores_padded.size(0), scores_padded.size(1)).bool()
    # Boolean indexing requires the mask to live on the same device as the scores.
    D_padding = D_padding.to(scores_padded.device)
    # Fill in place to avoid allocating a second scores-sized tensor.
    scores_padded[D_padding] = -9999
    scores = scores_padded.max(1).values

    return scores.sum(-1)


def colbert_score(Q, D_padded, D_mask, use_gpu=False):
    """
    Compute late-interaction (MaxSim) scores between query and document token embeddings.

    Expected sizes: Q = (1 | num_docs, *, dim) and D_padded = (num_docs, *, dim).
    When Q has a leading size of 1, that single query is scored against every
    passage. Otherwise query i is scored only against the aligned passage i.

    If ``use_gpu`` is set, all three inputs are moved to CUDA first. The token
    similarity matrix is reduced by ``colbert_score_reduce`` using ``D_mask``.

    EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
    """
    if use_gpu:
        Q = Q.cuda()
        D_padded = D_padded.cuda()
        D_mask = D_mask.cuda()

    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Cast Q to the documents' dtype and swap its last two dims so the batched
    # product is (num_docs, doc_len, dim) @ (1 | num_docs, dim, q_len).
    query_t = Q.to(dtype=D_padded.dtype).transpose(1, 2)
    token_scores = torch.matmul(D_padded, query_t)

    return colbert_score_reduce(token_scores, D_mask)


def _sort_by_length(ids, mask, bsize, *args):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return_array = [ids[indices], mask[indices]]
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return_array.append(arg[indices])
        else:
            # arg is a list, and we want to sort the list according to indices
            return_array.append([arg[i] for i in indices])

    return *return_array, reverse_indices


def _split_into_batches(ids, mask, bsize, *args):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
        for arg in args:
            batch.append(arg[offset : offset + bsize])
        batches.append(batch)
    return batches