Shakshi3104 commited on
Commit
6d23787
·
1 Parent(s): 02c2acd

[add] Implement hybrid search

Browse files
Files changed (1) hide show
  1. model/search/hybrid.py +146 -0
model/search/hybrid.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+
3
+ import pandas as pd
4
+ from copy import deepcopy
5
+
6
+ from dotenv import load_dotenv
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+
10
+ from model.search.base import BaseSearchClient
11
+ from model.search.surface import BM25SearchClient
12
+ from model.search.vector import RuriVoyagerSearchClient
13
+
14
+ from model.utils.timer import stop_watch
15
+
16
+
17
+ def reciprocal_rank_fusion(sparse: pd.DataFrame, dense: pd.DataFrame, k=60) -> pd.DataFrame:
18
+ """
19
+ Reciprocal Rank Fusionを計算する
20
+
21
+ Notes
22
+ ----------
23
+ RRFの計算は以下の式
24
+
25
+ .. math:: RRF = \sum_{i=1}^n \frac{1}{k+r_i}
26
+
27
+ Parameters
28
+ ----------
29
+ sparse:
30
+ pd.DataFrame, 表層検索の検索結果
31
+ dense:
32
+ pd.DataFrame, ベクトル検索の結果
33
+ k:
34
+ int,
35
+
36
+ Returns
37
+ -------
38
+ rank_results:
39
+ pd.DataFrame, RRFによるリランク結果
40
+
41
+ """
42
+ # カラム名を変更
43
+ sparse = sparse.rename(columns={"rank": "rank_sparse"})
44
+ dense = dense.rename(columns={"rank": "rank_dense"})
45
+ # denseはランク以外を落として結合する
46
+ dense_ = dense["rank_dense"]
47
+
48
+ # 順位を1からスタートするようにする
49
+ sparse["rank_sparse"] += 1
50
+ dense_ += 1
51
+
52
+ # 文書のインデックスをキーに結合する
53
+ rank_results = pd.merge(sparse, dense_, how="left", left_index=True, right_index=True)
54
+
55
+ # RRFスコアの計算
56
+ rank_results["rrf_score"] = 1 / (rank_results["rank_dense"] + k) + 1 / (rank_results["rank_sparse"] + k)
57
+
58
+ # RRFスコアのスコアが大きい順にソート
59
+ rank_results = rank_results.sort_values(["rrf_score"], ascending=False)
60
+ rank_results["rank"] = deepcopy(rank_results.reset_index()).index
61
+
62
+ return rank_results
63
+
64
+
65
+ class HybridSearchClient(BaseSearchClient):
66
+ def __init__(self, dense_model: BaseSearchClient, sparse_model: BaseSearchClient):
67
+ self.dense_model = dense_model
68
+ self.sparse_model = sparse_model
69
+
70
+ @classmethod
71
+ @stop_watch
72
+ def from_dataframe(cls, _data: pd.DataFrame, _target: str):
73
+ """
74
+ 検索ドキュメントのpd.DataFrameから初期化する
75
+
76
+ Parameters
77
+ ----------
78
+ _data:
79
+ pd.DataFrame, 検索対象のDataFrame
80
+
81
+ _target:
82
+ str, 検索対象のカラム名
83
+
84
+ Returns
85
+ -------
86
+
87
+ """
88
+ # 表層検索の初期化
89
+ dense_model = BM25SearchClient.from_dataframe(_data, _target)
90
+ # ベクトル検索の初期化
91
+ sparse_model = RuriVoyagerSearchClient.from_dataframe(_data, _target)
92
+
93
+ return cls(dense_model, sparse_model)
94
+
95
+ @stop_watch
96
+ def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
97
+ """
98
+ クエリに対する検索結果をtop-n個取得する
99
+
100
+ Parameters
101
+ ----------
102
+ _query:
103
+ Union[List[str], str], 検索クエリ
104
+ n:
105
+ int, top-nの個数. デフォルト 10.
106
+
107
+ Returns
108
+ -------
109
+ results:
110
+ List[pd.DataFrame], ランキング結果
111
+ """
112
+
113
+ logger.info(f"🚦 [HybridSearchClient] Search top {n} | {_query}")
114
+
115
+ # 型チェック
116
+ if isinstance(_query, str):
117
+ _query = [_query]
118
+
119
+ # ランキングtop-nをクエリ毎に取得
120
+ result = []
121
+ for query in tqdm(_query):
122
+ assert len(self.sparse_model.corpus) == len(
123
+ self.dense_model.corpus), "The document counts do not match between sparse and dense!"
124
+
125
+ # ドキュメント数
126
+ doc_num = len(self.sparse_model.corpus)
127
+
128
+ # 表層検索
129
+ logger.info(f"🚦 [HybridSearchClient] run surface search ...")
130
+ sparse_res = self.sparse_model.search_top_n(query, n=doc_num)
131
+ # ベクトル検索
132
+ logger.info(f"🚦 [HybridSearchClient] run vector search ...")
133
+ dense_res = self.dense_model.search_top_n(query, n=doc_num)
134
+
135
+ # RRFスコアの計算
136
+ logger.info(f"🚦 [HybridSearchClient] compute RRF scores ...")
137
+ rrf_res = reciprocal_rank_fusion(sparse_res[0], dense_res[0])
138
+
139
+ # 結果をtop Nに絞る
140
+ top_num = 10
141
+ rrf_res = rrf_res.head(top_num)
142
+ logger.info(f"🚦 [HybridSearchClient] return {top_num} results")
143
+
144
+ result.append(rrf_res)
145
+
146
+ return result