|
import joblib |
|
import gradio as gr |
|
from collections import Counter |
|
from typing import TypedDict |
|
from abc import ABC, abstractmethod |
|
from typing import Any, Dict, Type |
|
from scipy.sparse._csc import csc_matrix |
|
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar |
|
import pickle |
|
from dataclasses import dataclass |
|
import tqdm |
|
import re |
|
import os |
|
import nltk |
|
nltk.download("stopwords", quiet=True) |
|
from nltk.corpus import stopwords as nltk_stopwords |
|
import math |
|
from dataclasses import dataclass |
|
from typing import Optional |
|
from datasets import load_dataset |
|
from enum import Enum |
|
import numpy as np |
|
|
|
@dataclass
class Document:
    """A corpus entry: an identifier plus its raw text."""

    # Identifier of this entry within the collection.
    collection_id: str
    # Text content of the entry.
    text: str
|
|
|
|
|
@dataclass
class Query:
    """A search query: an identifier plus the query text."""

    # Identifier of the query.
    query_id: str
    # Text of the query.
    text: str
|
|
|
|
|
@dataclass
class QRel:
    """A relevance judgement linking one query to one document."""

    # Identifier of the judged query.
    query_id: str
    # Identifier of the judged document.
    collection_id: str
    # Relevance label of the document for the query.
    relevance: int
    # Optional answer string attached to the judgement.
    answer: Optional[str] = None
|
|
|
class Split(str, Enum):
    """Names of the dataset partitions; members are also plain strings."""

    train = "train"
    dev = "dev"
    test = "test"
|
|
|
@dataclass
class IRDataset:
    """Container for a retrieval dataset: corpus, queries and judgements."""

    # All documents of the collection.
    corpus: List[Document]
    # All queries, across every split.
    queries: List[Query]
    # Relevance judgements grouped by split.
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return the sizes of the corpus, the queries and each split's qrels.

        Keys have the form ``"|corpus|"``, ``"|queries|"`` and
        ``"|qrels-<split>|"``.
        """
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            # Formatting a ``(str, Enum)`` member directly yields "train" on
            # older Python versions and "Split.train" on newer ones; use the
            # member value so the key is the same everywhere.  Plain string
            # keys are passed through unchanged.
            split_name = getattr(split, "value", split)
            stats[f"|qrels-{split_name}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return ``{query_id: {collection_id: relevance}}`` for ``split``.

        Raises:
            KeyError: If ``split`` has no entry in ``split2qrels``.
        """
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for qrel in self.split2qrels[split]:
            judged = qrels_dict.setdefault(qrel.query_id, {})
            judged[qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return the queries that have at least one judgement in ``split``.

        The relative order of ``self.queries`` is preserved.

        Raises:
            KeyError: If ``split`` has no entry in ``split2qrels``.
        """
        qids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in qids]
|
|
|
|
|
|
|
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    """Load ``allenai/sciq`` and convert it into an :class:`IRDataset`.

    Every kept example contributes one document (its ``support`` text), one
    query (its ``question``) and one relevance judgement with relevance 1
    linking the two; all three share the id ``"<split>-<row index>"``.
    Examples are skipped when the support is blank, when the support was
    already seen, or when the question was already seen.  The result is
    memoized on disk under ``.cache`` through joblib.

    Args:
        verbose: When true, print the raw number of rows of each split.

    Returns:
        The assembled corpus, queries and per-split relevance judgements.
    """
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check kept exactly as in the original implementation.
    # NOTE(review): `+` on DataFrames adds them element-wise (index-aligned);
    # it does not stack the three splits.  If a check over all rows was
    # intended this should presumably be a concatenation -- confirm before
    # changing, since the assertions may not hold on the concatenated data.
    df = train.to_pandas() + validation.to_pandas() + test.to_pandas()
    for _, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == len(group)
        assert len(set(group["correct_answer"].tolist())) == len(group)

    corpus: List[Document] = []
    queries: List[Query] = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id: Dict[str, str] = {}
    support2id: Dict[str, str] = {}
    for split, rows in data.items():
        # Use the member value so ids read "train-0" on every Python version;
        # formatting the enum member itself is version dependent.
        split_name = split.value
        if verbose:
            print(f"|raw_{split_name}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split_name}-{i}"
            support: str = row["support"]
            if not support.strip():
                continue
            question = row["question"]
            # A support is registered as seen even if the example is later
            # dropped because its question is a duplicate.
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id

            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split2qrels[split].append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
|
|
|
# Language whose NLTK stopword list is loaded below.
LANGUAGE = "english"

# Matches runs of two or more word characters delimited by word boundaries.
_word_pattern = re.compile(r"(?u)\b\w\w+\b")
# Callable returning every match of the pattern in a string.
word_splitter = _word_pattern.findall

# Stopwords removed during tokenization.
stopwords = set(nltk_stopwords.words(LANGUAGE))
|
|
|
def word_splitting(text: str) -> List[str]:
    """Lower-case ``text`` and return the words found by ``word_splitter``."""
    lowered = text.lower()
    return word_splitter(lowered)
|
|
|
def lemmatization(words: List[str]) -> List[str]:
    """Return ``words`` unchanged; no lemmatization is currently applied."""
    return words
|
|
|
def simple_tokenize(text: str) -> List[str]:
    """Split ``text`` into words, drop stopwords, then apply lemmatization."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
|
|
|
# Type variable for alternate constructors that return the calling index class.
T = TypeVar("T", bound="InvertedIndex")
|
|
|
@dataclass
class PostingList:
    """Postings of one term, stored as two parallel lists."""

    # The term these postings belong to.
    term: str
    # Internal ids of the documents containing the term.
    docid_postings: List[int]
    # Weight of the term in the document at the same position.
    tweight_postings: List[float]
|
|
|
@dataclass
class InvertedIndex:
    """Inverted index held as one posting list per vocabulary term."""

    # Posting lists, indexed by term id.
    posting_lists: List[PostingList]
    # Term -> term id.
    vocab: Dict[str, int]
    # Collection id -> internal document id.
    cid2docid: Dict[str, int]
    # Internal document id -> collection id.
    collection_ids: List[str]
    # Raw document texts by internal document id, when stored.
    doc_texts: Optional[List[str]] = None

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        The directory is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index pickled at ``<saved_dir>/index.pkl``.

        The unpickled object is returned as is; it is not checked to be an
        instance of ``cls``.
        """
        # NOTE: unpickling can execute arbitrary code -- only load index
        # directories that come from a trusted source.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
|
|
|
class BaseRetriever(ABC):
    """Abstract interface shared by the retrievers."""

    @property
    @abstractmethod
    def index_class(self) -> Type[Any]:
        """Class of the index this retriever works with."""

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Per-term weights of ``query`` for document ``cid``.

        Not implemented here; always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, query: str, cid: str) -> float:
        """Score of document ``cid`` for ``query``."""

    @abstractmethod
    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Mapping from collection id to score for ``query``."""
|
|
|
@dataclass
class Counting:
    """Raw statistics collected over a corpus by ``run_counting``."""

    # Posting lists by term id; weights hold term frequencies.
    posting_lists: List[PostingList]
    # Term -> term id.
    vocab: Dict[str, int]
    # Collection id -> internal document id.
    cid2docid: Dict[str, int]
    # Internal document id -> collection id.
    collection_ids: List[str]
    # Document frequency by term id.
    dfs: List[int]
    # Document length (in tokens) by internal document id.
    dls: List[int]
    # Average document length.
    avgdl: float
    # Total number of tokens counted.
    nterms: int
    # Raw document texts by internal document id, when stored.
    doc_texts: Optional[List[str]] = None
|
|
|
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count term frequencies, document frequencies and document lengths.

    Documents whose ``collection_id`` was already seen are skipped.  Internal
    document ids and term ids are assigned in order of first appearance.

    Args:
        documents: Documents to index.
        tokenize_fn: Maps a document text to its list of tokens.
        store_raw: Keep the raw document texts in the result.
        ndocs: Expected number of documents; only used as the progress-bar
            total.
        show_progress_bar: Whether to display the progress bar.

    Returns:
        A :class:`Counting` whose posting-list weights are raw term
        frequencies.  ``avgdl`` is ``0.0`` when no document was counted, and
        ``doc_texts`` is ``None`` when ``store_raw`` is false.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []
    dls: List[int] = []
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None

    progress = tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    )
    for doc in progress:
        if doc.collection_id in cid2docid:
            continue
        docid = len(cid2docid)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        doc_length = sum(tok2tf.values())
        dls.append(doc_length)
        nterms += doc_length

        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First time this term is seen: allocate its id, posting list
                # and document-frequency slot together.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Count every document containing the term, including the first
            # one (previously the first occurrence was recorded as 0).
            dfs[tid] += 1

        if doc_texts is not None:
            doc_texts.append(doc.text)

    # Guard against an empty corpus, which would otherwise divide by zero.
    avgdl = sum(dls) / len(dls) if dls else 0.0
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=avgdl,
        nterms=nterms,
        doc_texts=doc_texts,
    )
|
|
|
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting weights are precomputed BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Replace raw term frequencies by BM25 term weights, in place.

        Args:
            posting_lists: Posting lists by term id; their
                ``tweight_postings`` hold term frequencies on entry and
                ``regularized_tf * idf`` on return.
            total_docs: Number of documents in the collection.
            avgdl: Average document length.
            dfs: Document frequency by term id.
            dls: Document length by internal document id.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.
        """
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=total_docs)
            # Slice assignment keeps the same list object, so the update is
            # visible through every existing reference to it.
            posting_list.tweight_postings[:] = [
                BM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                * idf
                for docid, tf in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                )
            ]

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return ``tf / (tf + k1 * (1 - b + b * dl / avgdl))``."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type["BM25Index"],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> "BM25Index":
        """Build an index from ``documents``.

        Args:
            documents: Documents to index.
            store_raw: Keep the raw document texts in the index.
            output_dir: When given, the built index is also saved there
                (previously this argument was accepted but ignored).
            ndocs: Expected number of documents, used as progress-bar total.
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Returns:
            The built index, an instance of ``cls``.
        """
        # Counting TFs, DFs, document lengths, etc.
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Turn the raw term frequencies into cached BM25 weights.
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble, and persist when a target directory was requested.
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
|
|
|
|
|
@dataclass
class CSCInvertedIndex:
    """Inverted index held as a sparse CSC matrix with one column per term."""

    # Term weights; rows are internal document ids, columns are term ids.
    posting_lists_matrix: csc_matrix
    # Term -> term id.
    vocab: Dict[str, int]
    # Collection id -> internal document id.
    cid2docid: Dict[str, int]
    # Internal document id -> collection id.
    collection_ids: List[str]
    # Raw document texts by internal document id, when stored.
    doc_texts: Optional[List[str]] = None

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        The directory is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type["CSCInvertedIndex"], saved_dir: str) -> "CSCInvertedIndex":
        """Load the index pickled at ``<saved_dir>/index.pkl``.

        The unpickled object is returned as is; it is not checked to be an
        instance of ``cls``.
        """
        # NOTE: unpickling can execute arbitrary code -- only load index
        # directories that come from a trusted source.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
|
|
|
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """CSC-matrix index whose entries are precomputed BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute BM25 term weights and pack them into a CSC matrix.

        Args:
            posting_lists: Posting lists by term id whose
                ``tweight_postings`` hold raw term frequencies.  They are
                not modified.
            total_docs: Number of documents in the collection.
            avgdl: Average document length.
            dfs: Document frequency by term id.
            dls: Document length by internal document id.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Returns:
            A ``(total_docs, len(posting_lists))`` matrix with float32 data
            where entry ``[docid, tid]`` is ``regularized_tf * idf``.
        """
        # CSC layout: `data`/`indices` hold the weights and row (document)
        # ids column by column; `indptr[t]:indptr[t + 1]` delimits column t.
        data: List[float] = []
        indices: List[int] = []
        indptr: List[int] = [0]
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            # Use this class's own helpers rather than reaching into BM25Index.
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                data.append(regularized_tf * idf)
                indices.append(docid)
            indptr.append(len(data))

        return csc_matrix(
            (
                np.array(data, dtype=np.float32),
                np.array(indices, dtype=np.int32),
                np.array(indptr, dtype=np.int32),
            ),
            shape=(total_docs, len(posting_lists)),
        )

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return ``tf / (tf + k1 * (1 - b + b * dl / avgdl))``."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type["CSCBM25Index"],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> "CSCBM25Index":
        """Build an index from ``documents``.

        Args:
            documents: Documents to index.
            store_raw: Keep the raw document texts in the index.
            output_dir: When given, the built index is also saved there
                (previously this argument was accepted but ignored).
            ndocs: Expected number of documents, used as progress-bar total.
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Returns:
            The built index, an instance of ``cls``.
        """
        # Counting TFs, DFs, document lengths, etc.
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute the BM25 weights and pack them into the sparse matrix.
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble, and persist when a target directory was requested.
        index = cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
|
|
|
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a saved CSC-matrix index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""

    def __init__(self, index_dir: str) -> None:
        """Load the index stored in ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the non-zero weight of each query term in document ``cid``.

        Query terms missing from the vocabulary are ignored.

        Raises:
            KeyError: If ``cid`` is not a known collection id.
        """
        index = self.index
        query_tokens = index.tokenize(query)
        row = index.cid2docid[cid]
        matrix = index.posting_lists_matrix

        weights = {}
        for token in query_tokens:
            if token in index.vocab:
                value = matrix[row, index.vocab[token]]
                # Zero means the term does not occur in the document.
                if value != 0:
                    weights[token] = value
        return weights

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of ``query`` in document ``cid``."""
        term_weights = self.get_term_weights(query=query, cid=cid)
        return sum(term_weights.values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents for ``query``.

        The result maps collection id to accumulated score, ordered by
        decreasing score.  Query terms missing from the vocabulary are
        ignored.
        """
        index = self.index
        matrix = index.posting_lists_matrix

        accumulated: Dict[int, float] = {}
        for token in index.tokenize(query):
            if token not in index.vocab:
                continue
            # One matrix column holds the postings of one term.
            column = matrix[:, index.vocab[token]]
            for docid, weight in zip(column.indices, column.data):
                accumulated[docid] = accumulated.get(docid, 0) + weight

        ranked = sorted(
            accumulated.items(), key=lambda pair: pair[1], reverse=True
        )
        return {
            index.collection_ids[docid]: total for docid, total in ranked[:topk]
        }
|
|
|
|
|
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever backed by a saved ``CSCBM25Index``."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Return the index class used to load the saved index."""
        return CSCBM25Index
|
|
|
class Hit(TypedDict):
    """One search result as returned to the user interface."""

    # Collection id of the matched document.
    cid: str
    # Retrieval score of the document.
    score: float
    # Text of the document.
    text: str
|
|
|
# Gradio interface handle; assigned further down once the callback exists.
demo: Optional[gr.Interface] = None
# Type of the value produced by the search callback.
return_type = List[Hit]
|
|
|
|
|
|
|
# Directory the index is written to and then loaded back from.
INDEX_DIR = "output/csc_bm25_index_default"

# Build the BM25 index over the SciQ corpus, persist it, and load it back
# through the retriever.
sciq = load_sciq()
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # Progress-bar total; derived from the corpus rather than hard-coded so
    # it stays correct if the dataset changes.
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
csc_bm25_index.save(INDEX_DIR)
csc_bm25_retriever = CSCBM25Retriever(index_dir=INDEX_DIR)

# Collection id -> document text, used to attach texts to search hits.
doc2text = {doc.collection_id: doc.text for doc in sciq.corpus}
|
|
|
def retrieve(query: str) -> List[Hit]:
    """Search the index for ``query`` and return hits by decreasing score.

    Each hit carries the collection id, the retrieval score and the text of
    the matched document.
    """
    scored = csc_bm25_retriever.retrieve(query)
    hits: List[Hit] = [
        {"cid": cid, "score": score, "text": doc2text[cid]}
        for cid, score in scored.items()
    ]
    hits.sort(key=lambda hit: hit["score"], reverse=True)
    return hits
|
|
|
def format_hits(hits: List[Hit]):
    """Render ``hits`` as a numbered plain-text listing.

    Each hit shows its rank (starting at 1), score with three decimals, id
    and text, followed by a line of 80 dashes.  An empty list yields an
    empty string.
    """
    separator = "-" * 80
    chunks = [
        f"\n\n{rank}. Score: {hit['score']:.3f}\n"
        f"ID: {hit['cid']}\n"
        f"Text: {hit['text']}\n"
        f"{separator}"
        for rank, hit in enumerate(hits, 1)
    ]
    return "".join(chunks)
|
|
|
# Web UI: a single text box for the query, results rendered as JSON.
demo = gr.Interface(
    fn=retrieve,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
    title="Document Search",
    description="Search documents using BM25 retrieval",
)

# Start serving the interface as soon as the module is executed.
demo.launch()