# NetsPresso_QA / search_online.py
# geonmin-kim's picture
# Upload folder using huggingface_hub
# d6585f5
import argparse
import json
from pyserini.search.faiss import (
AutoQueryEncoder,
AnceQueryEncoder,
DprQueryEncoder,
TctColBertQueryEncoder,
)
def _init_encoder_from_str(encoder, device="cpu"):
    """Pick and build a query encoder based on substrings of ``encoder``.

    The encoder name/path is lower-cased and checked, in order, for the
    markers ``dpr``, ``tct_colbert``, ``ance`` and ``sentence``; the first
    marker found decides which encoder class is instantiated. Names that
    contain none of the markers fall back to a plain ``AutoQueryEncoder``.

    Args:
        encoder: Encoder name or checkpoint path, forwarded as ``encoder_dir``.
        device: Device string forwarded to the encoder constructor.

    Returns:
        The constructed query encoder instance.
    """
    name = encoder.lower()
    # Order matters: it reproduces the priority of the marker checks.
    marker_table = (
        ("dpr", DprQueryEncoder, {}),
        ("tct_colbert", TctColBertQueryEncoder, {}),
        ("ance", AnceQueryEncoder, {}),
        # "sentence"-style checkpoints get mean pooling plus L2 normalisation.
        ("sentence", AutoQueryEncoder, {"pooling": "mean", "l2_norm": True}),
    )
    for marker, encoder_cls, extra_kwargs in marker_table:
        if marker in name:
            return encoder_cls(encoder_dir=encoder, device=device, **extra_kwargs)
    return AutoQueryEncoder(encoder_dir=encoder, device=device)
def load_index(searcher_class, index_dir, query_encoder=None):
    """Instantiate ``searcher_class`` for the index located at ``index_dir``.

    Args:
        searcher_class: Callable accepting ``index_dir`` and, optionally,
            ``query_encoder`` as keyword arguments.
        index_dir: Index location, passed through as ``index_dir``.
        query_encoder: Passed through as ``query_encoder`` only when it is
            not ``None``; otherwise the keyword is omitted entirely.

    Returns:
        Whatever ``searcher_class`` returns.
    """
    init_kwargs = {"index_dir": index_dir}
    if query_encoder is not None:
        init_kwargs["query_encoder"] = query_encoder
    return searcher_class(**init_kwargs)
class OnlineSearcher(object):
    """Run single queries against a sparse, dense, or hybrid pyserini index.

    The concrete searcher is chosen from ``args.index_type`` at construction
    time and stored on ``self.searcher``. In hybrid mode the underlying
    sparse and dense searchers are additionally kept on ``self.ssearcher``
    and ``self.dsearcher``.
    """

    def __init__(self, args):
        """Load the index (or indexes) described by ``args``.

        Args:
            args: Namespace-like object. This constructor reads
                ``index_type`` ("sparse", "dense" or "hybrid"), ``index``,
                ``lang_abbr`` (sparse and hybrid), and ``encoder`` /
                ``device`` (dense and hybrid). The object is kept on
                ``self.args``; ``search`` and ``print_result`` later read
                ``alpha``, ``normalization`` and ``hide_text`` from it.
                In hybrid mode a comma-delimited ``args.index`` string is
                replaced in place by the two-element list of paths.

        Raises:
            ValueError: If ``index_type`` is not one of the supported values,
                or if hybrid mode is not given exactly two indexes.
        """
        self.args = args

        if args.index_type == "sparse":
            query_encoder = None
        elif args.index_type in ("dense", "hybrid"):
            query_encoder = _init_encoder_from_str(
                encoder=args.encoder, device=args.device
            )
        else:
            raise ValueError(
                f"index_type {args.index_type} should be chosen among sparse, dense, or hybrid"
            )

        # load index
        if args.index_type == "hybrid":
            # Split only while the value is still the raw "sparse,dense"
            # string, so the same args object can be reused to build another
            # searcher after it has already been split once.
            if isinstance(args.index, str):
                args.index = args.index.split(",")
            # Explicit raise instead of assert: asserts vanish under ``-O``.
            if len(args.index) != 2:
                raise ValueError(
                    "require both sparse and dense index delimited by comma"
                )

            from pyserini.search.lucene import LuceneSearcher

            self.ssearcher = load_index(
                searcher_class=LuceneSearcher, index_dir=args.index[0]
            )
            self.ssearcher.set_language(args.lang_abbr)

            from pyserini.search.faiss import FaissSearcher

            self.dsearcher = load_index(
                searcher_class=FaissSearcher,
                index_dir=args.index[1],
                query_encoder=query_encoder,
            )

            from pyserini.search.hybrid import HybridSearcher

            self.searcher = HybridSearcher(self.dsearcher, self.ssearcher)
            print(f"load {self.ssearcher.num_docs} documents from {args.index}")
        else:
            # index_type was validated above, so anything not sparse is dense.
            if args.index_type == "sparse":
                from pyserini.search.lucene import LuceneSearcher as Searcher
            else:
                from pyserini.search.faiss import FaissSearcher as Searcher

            self.searcher = load_index(
                searcher_class=Searcher,
                index_dir=args.index,
                query_encoder=query_encoder,
            )
            if args.index_type == "sparse":
                self.searcher.set_language(args.lang_abbr)
            print(f"load {self.searcher.num_docs} documents from {args.index}")

    def search(self, query, k=10):
        """Retrieve hits for ``query`` from the configured searcher.

        Args:
            query: Query text.
            k: Number of hits requested from the underlying searcher. It is
                forwarded for every index type, not just hybrid.

        Returns:
            The hits object returned by the underlying searcher's ``search``.
        """
        if self.args.index_type == "hybrid":
            return self.searcher.search(
                query,
                alpha=self.args.alpha,
                normalization=self.args.normalization,
                k=k,
            )
        return self.searcher.search(query, k=k)

    def print_result(self, hits, k):
        """Print the first ``k`` hits and return their contents as one string.

        For every printed hit a ``rank docid score`` line is written. Raw
        document text is looked up only for sparse and hybrid indexes (the
        faiss searcher does not store document raw text). When a document is
        found and ``args.hide_text`` is false, its raw JSON is printed and its
        ``"contents"`` field is collected.

        Args:
            hits: Indexable sequence of hits exposing ``docid`` and ``score``.
            k: Maximum number of hits to print.

        Returns:
            The collected documents joined by blank lines, each prefixed with
            a ``문서 N`` header (Korean for "Document N"). Empty string when
            nothing was collected, e.g. for dense indexes or with
            ``hide_text`` set.
        """
        index_type = self.args.index_type
        collected = []
        # Print the first k hits:
        for i in range(min(k, len(hits))):
            hit = hits[i]
            print(f"{i+1:2} {hit.docid:15} {hit.score:.5f}")

            if index_type == "sparse":
                doc = self.searcher.doc(hit.docid)
            elif index_type == "hybrid":
                doc = self.searcher.sparse_searcher.doc(hit.docid)
            else:
                # faiss searcher does not store document raw text
                doc = None

            if doc is not None and not self.args.hide_text:
                doc_raw = doc.raw()
                collected.append(json.loads(doc_raw))
                print(doc_raw)

        return "\n\n".join(
            f'문서 {idx+1}\n{doc["contents"]}' for idx, doc in enumerate(collected)
        )
if __name__ == "__main__":
    # Command-line entry point: parse options, build an OnlineSearcher, run
    # one query and print the top-k hits.
    parser = argparse.ArgumentParser(description="Search interactively")
    parser.add_argument(
        "--index_type",
        type=str,
        required=True,
        help="choose indexing type",
        choices=["sparse", "dense", "hybrid"],
    )
    parser.add_argument(
        "--index",
        type=str,
        required=True,
        help="Path to index or name of prebuilt index.",
    )
    parser.add_argument("--query", type=str, required=True, help="Query text")
    parser.add_argument(
        "--lang_abbr",
        type=str,
        required=False,
        default="ko",
        help="for language specific algorithms for sparse retrieveal)",
    )
    parser.add_argument(
        "--encoder", type=str, required=False, help="encoder name or checkpoint path"
    )
    parser.add_argument(
        "--device",
        type=str,
        required=False,
        default="cpu",
        help="device to use for encoding queries (cf. pyserini does not support faiss-gpu)",
    )

    # for hybrid search
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="weight for hybrid search: alpha*score(sparse) + score(dense)",
    )
    parser.add_argument(
        "--normalization",
        action="store_true",
        help="normalize sparse & dens score before fusion",
    )

    # search range
    parser.add_argument(
        "--k",
        type=int,
        default=10,
        help="the number of passages to return (default: 10)",
    )

    # print option
    parser.add_argument(
        "--hide_text", action="store_true", help="do not print if this is true"
    )

    args = parser.parse_args()

    # Dense and hybrid modes build a query encoder from --encoder; without it
    # the encoder lookup would fail on None, so report a usage error up front.
    if args.index_type in ("dense", "hybrid") and args.encoder is None:
        parser.error("--encoder is required when --index_type is dense or hybrid")

    # make searcher
    searcher = OnlineSearcher(args)
    print(f"given query: {args.query}")

    # search: forward --k so that it is requested from the searcher rather
    # than only being used to truncate a default-sized result list
    hits = searcher.search(args.query, k=args.k)

    # print results
    searcher.print_result(hits, args.k)