|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cmd |
|
import json |
|
import os |
|
import random |
|
|
|
from pyserini.search.lucene import LuceneSearcher |
|
from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder |
|
from pyserini.search.hybrid import HybridSearcher |
|
from pyserini import search |
|
|
|
|
|
class MsMarcoDemo(cmd.Cmd):
    """Interactive shell for querying MS MARCO passages with sparse, dense or hybrid retrieval.

    Any input that is not a recognised command is treated as a search query
    (see ``default``). Commands may optionally be prefixed with ``/``.
    """

    # These class attributes are evaluated once, when the class body runs (at
    # import time), so the dev topics and the prebuilt BM25 index are shared by
    # every instance.
    dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values())

    ssearcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')
    # Dense and hybrid searchers are only built once a model is chosen via /model.
    dsearcher = None
    hsearcher = None
    # Searcher currently used for queries, and the mode name it corresponds to.
    searcher = ssearcher
    mode = 'sparse'

    k = 10  # number of hits returned per query
    prompt = '>>> '

    def precmd(self, line):
        """Strip an optional leading '/' so '/k 5' and 'k 5' dispatch identically.

        Note: cmd.Cmd dispatches on the first word, so a bare query that starts
        with a command name (e.g. "help ...") runs that command instead.
        """
        # startswith() is safe on the empty string a bare <Enter> produces;
        # indexing line[0] there raised IndexError and terminated the loop.
        if line.startswith('/'):
            line = line[1:]
        return line

    def emptyline(self):
        """Ignore blank input.

        The cmd.Cmd default would instead repeat the last command, which could,
        for example, reload a model.
        """
        return False

    def do_help(self, arg):
        print('/help : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/model [MODEL] : sets encoder to use the model [MODEL] (one of tct, ance)')
        print('/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print('/random : returns results for a random question from dev subset')

    def do_k(self, arg):
        """Set the number of hits returned per query to the positive integer ``arg``."""
        try:
            k = int(arg)
        except ValueError:
            k = None
        # A bad value must not raise: an exception here would end cmdloop().
        if k is None or k < 1:
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        print(f'setting k = {k}')
        self.k = k

    def do_mode(self, arg):
        """Switch the active retriever to ``arg`` (one of sparse, dense, hybrid)."""
        if arg == 'sparse':
            searcher = self.ssearcher
        elif arg == 'dense':
            if self.dsearcher is None:
                print('Specify model through /model before using dense retrieval.')
                return
            searcher = self.dsearcher
        elif arg == 'hybrid':
            if self.hsearcher is None:
                print('Specify model through /model before using hybrid retrieval.')
                return
            searcher = self.hsearcher
        else:
            print(f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        self.searcher = searcher
        # Remember the mode so a later /model can re-point the active searcher.
        self.mode = arg
        print(f'setting retriever = {arg}')

    def do_model(self, arg):
        """Load the dense encoder and index for ``arg`` (tct or ance), then rebuild the hybrid searcher."""
        if arg == 'tct':
            encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
            index = 'msmarco-passage-tct_colbert-hnsw'
        elif arg == 'ance':
            encoder = AnceQueryEncoder('castorini/ance-msmarco-passage')
            index = 'msmarco-passage-ance-bf'
        else:
            print(f'Model "{arg}" is invalid. Model should be one of [tct, ance].')
            return

        self.dsearcher = FaissSearcher.from_prebuilt_index(index, encoder)
        # The hybrid searcher combines the new dense searcher with the shared BM25 one.
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
        # Without this, an already-active dense/hybrid mode would keep
        # querying the previously loaded model.
        if self.mode == 'dense':
            self.searcher = self.dsearcher
        elif self.mode == 'hybrid':
            self.searcher = self.hsearcher
        print(f'setting model = {arg}')

    def do_random(self, arg):
        """Pick a random dev-subset question, echo it, and search for it."""
        q = random.choice(self.dev_topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Stop the command loop on end-of-input (Ctrl-D)."""
        return True

    def default(self, q):
        """Search for ``q`` with the active searcher and print the top-k hits.

        Each printed line holds the 1-based rank, the score and the passage's
        "contents" field from its stored JSON.
        """
        hits = self.searcher.search(q, self.k)
        for rank, hit in enumerate(hits, start=1):
            if isinstance(self.searcher, LuceneSearcher):
                raw_doc = hit.raw
            else:
                # Faiss/hybrid hits carry only a docid, and those searchers have
                # no document store; fetch the passage from the Lucene index.
                doc = self.ssearcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None
            if raw_doc is None:
                # json.loads(None) would raise; report the missing doc instead.
                print(f'{rank:2} {hit.score:.5f} [no stored document for {hit.docid}]')
                continue
            contents = json.loads(raw_doc)['contents']
            print(f'{rank:2} {hit.score:.5f} {contents}')
|
|
|
|
|
if __name__ == '__main__':
    # Run the interactive loop; it returns once do_EOF signals stop (Ctrl-D).
    shell = MsMarcoDemo()
    shell.cmdloop()
|
|