File size: 4,255 Bytes
2cca64b
 
 
 
 
 
 
 
 
 
 
 
 
b58c0e3
2cca64b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b58c0e3
2cca64b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85c3441
2cca64b
 
 
 
 
85c3441
2cca64b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from copy import deepcopy
from typing import List, Union

import pandas as pd
import numpy as np

from loguru import logger
from tqdm import tqdm

from rank_bm25 import BM25Okapi

from model.search.base import BaseSearchClient
from model.utils.tokenizer import MeCabTokenizer
from model.utils.timer import stop_watch


class BM25Wrapper(BM25Okapi):
    def __init__(self, dataset: pd.DataFrame, target, tokenizer=None, k1=1.5, b=0.75, epsilon=0.25):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        self.dataset = dataset
        corpus = dataset[target].values.tolist()
        super().__init__(corpus, tokenizer)

    def get_top_n(self, query, documents, n=5):
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"

        scores = self.get_scores(query)
        top_n = np.argsort(scores)[::-1][:n]

        result = deepcopy(self.dataset.iloc[top_n])
        result["score"] = scores[top_n]
        return result


class BM25SearchClient(BaseSearchClient):
    def __init__(self, _model: BM25Okapi, _corpus: List[List[str]]):
        """

        Parameters
        ----------
        _model:
            BM25Okapi
        _corpus:
            List[List[str]], 検索対象の分かち書き後のフィールド
        """
        self.model = _model
        self.corpus = _corpus

    @staticmethod
    def tokenize_ja(_text: List[str]):
        """MeCab日本語分かち書きによるコーパス作成

        Args:
            _text (List[str]): コーパス文のリスト

        Returns:
            List[List[str]]: 分かち書きされたテキストのリスト
        """

        # MeCabで分かち書き
        parser = MeCabTokenizer.from_tagger("-Owakati")

        corpus = []
        with tqdm(_text) as pbar:
            for i, t in enumerate(pbar):
                try:
                    # 分かち書きをする
                    corpus.append(parser.parse(t).split())
                except TypeError as e:
                    if not isinstance(t, str):
                        logger.info(f"🚦 [BM25SearchClient] Corpus index of {i} is not instance of String.")
                        corpus.append(["[UNKNOWN]"])
                    else:
                        raise e
        return corpus

    @classmethod
    def from_dataframe(cls, _data: pd.DataFrame, _target: str):
        """
        検索ドキュメントのpd.DataFrameから初期化する

        Parameters
        ----------
        _data:
            pd.DataFrame, 検索対象のDataFrame

        _target:
            str, 検索対象のカラム名

        Returns
        -------

        """

        logger.info("🚦 [BM25SearchClient] Initialize from DataFrame")

        search_field = _data[_target]
        corpus = search_field.values.tolist()

        # 分かち書きをする
        corpus_tokenized = cls.tokenize_ja(corpus)
        _data["tokenized"] = corpus_tokenized

        bm25 = BM25Wrapper(_data, "tokenized")
        return cls(bm25, corpus_tokenized)

    @stop_watch
    def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
        """
        クエリに対する検索結果をtop-n個取得する

        Parameters
        ----------
        _query:
            Union[List[str], str], 検索クエリ
        n:
            int, top-nの個数. デフォルト 10.

        Returns
        -------
        results:
            List[pd.DataFrame], ランキング結果
        """

        logger.info(f"🚦 [BM25SearchClient] Search top {n} | {_query}")

        # 型チェック
        if isinstance(_query, str):
            _query = [_query]

        # クエリを分かち書き
        query_tokens = self.tokenize_ja(_query)

        # ランキングtop-nをクエリ毎に取得
        result = []
        for query in tqdm(query_tokens):
            df_res = self.model.get_top_n(query, self.corpus, n)
            # ランク
            df_res["rank"] = deepcopy(df_res.reset_index()).index
            df_res = df_res.drop(columns=["tokenized"])
            result.append(df_res)

        logger.success(f"🚦 [BM25SearchClient] Executed")

        return result