import os
import re
import math
import pickle
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypedDict, TypeVar

import joblib
import gradio as gr
import numpy as np
import pandas as pd
import tqdm
import nltk
from datasets import load_dataset
from scipy.sparse import csc_matrix

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return list(filter(lambda query: query.query_id in qids, self.queries))


@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check on records that share the same question:
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == len(group)
        assert len(set(group["correct_answer"].tolist())) == len(group)

    # Build:
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if len(question.strip()) == 0:
                continue
            if support in support2id:
                continue
            else:
                support2id[support] = example_id
            if question in question2id:
                continue
            else:
                question2id[question] = example_id
            doc = {"collection_id": example_id, "text": support}
            query = {"query_id": example_id, "text": row["question"]}
            qrel = {
                "query_id": example_id,
                "collection_id": example_id,
                "relevance": 1,
                "answer": row["correct_answer"],
            }
            corpus.append(Document(**doc))
            queries.append(Query(**query))
            split2qrels[split].append(QRel(**qrel))

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)


LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized
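# A quick illustration (a sketch, not part of the pipeline): simple_tokenize
# lowercases the text, keeps regex word tokens of length >= 2, and drops NLTK
# English stopwords, e.g.
#   simple_tokenize("The force of gravity acts on the mass")
#   -> ["force", "gravity", "acts", "mass"]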
T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


class BaseRetriever(ABC):

    @property
    @abstractmethod
    def index_class(self) -> Type[Any]:
        pass

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        raise NotImplementedError

    @abstractmethod
    def score(self, query: str, cid: str) -> float:
        pass

    @abstractmethod
    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        pass


@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)  # first document containing this term
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )
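# A minimal sketch (not part of the assignment) of what run_counting produces
# for a tiny toy corpus. The helper below is only defined for illustration and
# is never called by the pipeline.
def _counting_example() -> Counting:
    toy_docs = [
        Document(collection_id="d1", text="Gravity pulls objects toward the ground."),
        Document(collection_id="d2", text="Planets orbit the sun because of gravity."),
    ]
    counting = run_counting(documents=toy_docs, ndocs=2, show_progress_bar=False)
    # counting.vocab maps each term to a term id (tid); for the tid of "gravity",
    # counting.posting_lists[tid] holds docid_postings=[0, 1] and
    # tweight_postings=[1, 1] (raw term frequencies before BM25 weighting),
    # and counting.dfs[tid] == 2 (its document frequency).
    return counting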
@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching"""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type["BM25Index"],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> "BM25Index":
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


@dataclass
class CSCInvertedIndex:
    posting_lists_matrix: csc_matrix  # (docid, tid) -> term weight
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists_matrix=None, vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
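# A worked example (illustrative only, not part of the assignment) of the BM25
# term weight that cache_term_weights stores per (term, document) posting,
# using the default hyperparameters k1=0.9 and b=0.4:
#   idf = BM25Index.calc_idf(df=1, N=100)
#       = log(1 + (100 - 1 + 0.5) / (1 + 0.5)) ≈ 4.21
#   tf' = BM25Index.calc_regularized_tf(tf=3, dl=20, avgdl=10.0, k1=0.9, b=0.4)
#       = 3 / (3 + 0.9 * (1 - 0.4 + 0.4 * 20 / 10)) ≈ 0.70
#   weight = tf' * idf ≈ 2.96
# The CSCBM25Index below stores exactly these per-posting weights in a sparse
# CSC matrix with one row per document and one column per term.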
@dataclass
class CSCBM25Index(CSCInvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching"""
        ## YOUR_CODE_STARTS_HERE
        data = []
        indices = []
        indptr = [0]
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                weight = regularized_tf * idf
                data.append(weight)
                indices.append(docid)
            indptr.append(len(data))
        data = np.array(data, dtype=np.float32)
        indices = np.array(indices, dtype=np.int32)
        indptr = np.array(indptr, dtype=np.int32)
        posting_lists_matrix = csc_matrix(
            (data, indices, indptr), shape=(total_docs, len(posting_lists))
        )
        return posting_lists_matrix
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type["CSCBM25Index"],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> "CSCBM25Index":
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


class BaseCSCInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            weight = self.index.posting_lists_matrix[target_docid, tid]
            if weight == 0:
                continue
            term_weights[tok] = weight
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            col = self.index.posting_lists_matrix[:, tid]
            rows, data = col.indices, col.data
            for docid, tweight in zip(rows, data):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Use default b, k1
sciq = load_sciq()
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True
)
csc_bm25_index.save("output/csc_bm25_index_default")
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index_default")
doc2text = {doc.collection_id: doc.text for doc in sciq.corpus}
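# Illustrative usage (a sketch; actual ids and scores depend on the built index):
#   csc_bm25_retriever.retrieve("What force pulls objects toward the ground?", topk=3)
# returns a dict mapping collection_id -> BM25 score, highest score first, while
#   csc_bm25_retriever.get_term_weights(query, cid)
# exposes the per-term contributions that sum up to score(query, cid).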
x["score"], reverse=True) return hits def format_hits(hits: List[Hit]): output = "" for i, hit in enumerate(hits, 1): output += f"\n\n{i}. Score: {hit['score']:.3f}\n" output += f"ID: {hit['cid']}\n" output += f"Text: {hit['text']}\n" output += "-" * 80 return output demo = gr.Interface( fn=retrieve, inputs=gr.Textbox(label="Query"), outputs=gr.JSON(label="Results"), title="Document Search", description="Search documents using BM25 retrieval" ) ## YOUR_CODE_ENDS_HERE demo.launch()