File size: 3,744 Bytes
30ffb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import List, Union

import numpy as np
from loguru import logger
from sentence_transformers import CrossEncoder
from torch.nn import Identity, Sigmoid

class ReRanker(CrossEncoder):
    '''
    CrossEncoder wrapper that scores (query, document) pairs and re-orders
    retrieval hits by that score.

    Cross-Encoder models achieve higher performance than Bi-Encoders,
    however, they do not scale well to large datasets. The lack of scalability
    is due to the underlying cross-attention mechanism, which is computationally
    expensive.  Thus a Bi-Encoder is best used for 1st-stage document retrieval and
    a Cross-Encoder is used to re-rank the retrieved documents.

    https://www.sbert.net/examples/applications/cross-encoder/README.html
    '''

    def __init__(self,
                 model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
                 **kwargs
                 ):
        '''
        Loads the underlying CrossEncoder and sets re-ranking defaults.

        Args:
        -----
        model_name : str
            Name or path of the cross-encoder checkpoint, forwarded to
            CrossEncoder.__init__.
        **kwargs
            Any additional keyword arguments, forwarded unchanged to
            CrossEncoder.__init__.
        '''
        super().__init__(model_name=model_name,
                         **kwargs)
        self.model_name = model_name
        # Key under which the cross-encoder score is written into each hit.
        self.score_field = 'cross_score'
        # Activation used when callers ask for sigmoid-squashed scores.
        self.activation_fct = Sigmoid()

    def _cross_encoder_score(self,
                             results: List[dict],
                             query: str,
                             hit_field: str='content',
                             apply_sigmoid: bool=True,
                             return_scores: bool=False
                             ) -> Union[np.ndarray, None]:
        '''
        Given a list of hits from a Retriever:
            1. Scores hits by passing query and results through CrossEncoder model.
            2. Adds cross-score key (self.score_field) to each result dictionary,
               mutating the dictionaries in place.
            3. If desired returns the Cross Encoder scores.

        Args:
        -----
        results : List[dict]
            Hits to score; each must contain the key given by hit_field.
        query : str
            User query paired with every hit.
        hit_field : str='content'
            Key of the hit dictionary holding the text to score.
        apply_sigmoid : bool=True
            If True, scores are passed through a sigmoid. If False, an identity
            activation is used so raw model outputs (logits) are returned.
        return_scores : bool=False
            If True, the scores are returned in the same order as results.

        Returns:
        --------
        The scores produced by self.predict when return_scores is True
        (an empty array when results is empty), otherwise None.
        '''
        # predict() is not guaranteed to accept an empty batch (some releases
        # index the first input), so there is nothing to score or annotate here.
        if not results:
            return np.array([], dtype=np.float32) if return_scores else None

        # Passing None would make predict() fall back to the model's default
        # activation, which may itself be a sigmoid; an explicit Identity
        # guarantees raw scores when the sigmoid is disabled.
        activation_fct = self.activation_fct if apply_sigmoid else Identity()
        # build query/content list
        cross_inp = [[query, hit[hit_field]] for hit in results]
        # get scores
        cross_scores = self.predict(cross_inp, activation_fct=activation_fct)
        for result, score in zip(results, cross_scores):
            result[self.score_field] = score

        if return_scores:
            return cross_scores
        return None

    def rerank(self,
               results: List[dict],
               query: str,
               top_k: int=10,
               apply_sigmoid: bool=True,
               threshold: Union[float, None]=None,
               hit_field: str='content'
               ) -> List[dict]:
        '''
        Given a list of hits from a Retriever:
            1. Scores hits by passing query and results through CrossEncoder model.
            2. Adds cross_score key to results dictionary (in place).
            3. Returns reranked results limited by either a threshold value or top_k.

        Args:
        -----
        results : List[dict]
            List of results from the Weaviate client
        query : str
            User query
        top_k : int=10
            Number of results to return
        apply_sigmoid : bool=True
            Whether to apply sigmoid activation to cross-encoder scores.  If False,
            returns raw cross-encoder scores (logits).
        threshold : float=None
            Minimum cross-encoder score to return. When set, every hit scoring at
            or above the threshold is returned (top_k is not applied). If no hits
            reach the threshold, returns top_k hits.
        hit_field : str='content'
            Key of each result holding the text that is scored against the query.

        Returns:
        --------
        List[dict]
            Hits sorted by descending cross-encoder score; an empty list when
            results is empty.
        '''
        # Nothing to score: avoid calling the model on an empty batch.
        if not results:
            return []

        self._cross_encoder_score(results=results,
                                  query=query,
                                  hit_field=hit_field,
                                  apply_sigmoid=apply_sigmoid)

        # Sort results by the cross-encoder scores, best first.
        sorted_hits = sorted(results, key=lambda x: x[self.score_field], reverse=True)
        if threshold is None:
            return sorted_hits[:top_k]

        filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
        if not filtered_hits:
            logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
            return sorted_hits[:top_k]
        return filtered_hits