Finn committed on
Commit
99a17a5
1 Parent(s): 503a024

Initial commit

Files changed (1)
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
from __future__ import annotations  # must be the very first statement; needed for the forward references in BM25Index below

# Dependency install. This was originally a notebook-style "!pip install ..." line, which is not
# valid Python inside app.py; on Hugging Face Spaces the package would normally go into
# requirements.txt, with an install at startup as a fallback:
import os
os.system("pip install git+https://github.com/kwang2049/nlp4web-codebase.git")

from dataclasses import dataclass
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())

def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity

def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized

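# Quick illustration of the tokenizer above (added example, not part of the original file):
# the regexp keeps only words of at least two characters, NLTK English stopwords are removed,
# and lemmatization is a no-op, so e.g.
#   simple_tokenize("What is the smallest unit of life?") -> ["smallest", "unit", "life"]
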
T = TypeVar("T", bound="InvertedIndex")

@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index

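# Toy example of the index layout (added illustration; the two documents are hypothetical):
# for a corpus {"d1": "cells divide", "d2": "cells grow"} the counting step below would produce
# vocab = {"cells": 0, "divide": 1, "grow": 2}, cid2docid = {"d1": 0, "d2": 1}, and e.g.
# posting_lists[0] = PostingList(term="cells", docid_postings=[0, 1], tweight_postings=[1, 1])
# (raw term frequencies, which BM25Index later replaces with cached term weights).
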

# The output of the counting function:
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None

def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)  # a newly seen term has df = 1 (this document)
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )

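# Continuing the toy example from above (added illustration): for the two hypothetical documents
# "cells divide" and "cells grow", run_counting would return dfs = [2, 1, 1], dls = [2, 2],
# avgdl = 2.0 and nterms = 4.
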
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

from dataclasses import asdict, dataclass
import math
import os
from typing import Iterable, List, Optional, Type
import tqdm
from nlp4web_codebase.ir.data_loaders.dm import Document


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute the BM25 term weights and cache them in the posting lists (in place)."""

        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

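    # For reference (added note): combining the two helpers above, each cached posting weight is
    #   weight(t, d) = idf(t) * tf(t, d) / (tf(t, d) + k1 * (1 - b + b * dl(d) / avgdl))
    # with idf(t) = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5)), so retrieval only has to sum the
    # cached weights of the query terms for each candidate document.
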
    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index

from nlp4web_codebase.ir.models import BaseRetriever
from typing import Type
from abc import abstractmethod


class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index

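# Usage sketch for the retriever above (added illustration; the query and ids are made up):
#   retriever = BM25Retriever(index_dir="output/bm25_index")
#   retriever.retrieve("smallest unit of life")  ->  {"<collection_id>": <BM25 score>, ...}
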
import gradio as gr
from typing import TypedDict

class Hit(TypedDict):
    cid: str
    score: float
    text: str

demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

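# Expected shape of one search result (added example; the cid and text values are hypothetical):
#   {"cid": "train-00001", "score": 12.3, "text": "The cell is the smallest unit of life ..."}
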
## YOUR_CODE_STARTS_HERE
# Build and save the BM25 index once at startup; building it inside search() would redo the
# full indexing pass for every query.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query)
    hits: List[Hit] = []
    for cid, score in results.items():
        hits.append(
            Hit(
                cid=cid,
                score=score,
                text=bm25_index.doc_texts[bm25_index.cid2docid[cid]],
            )
        )
    return hits


demo = gr.Interface(fn=search, inputs="textbox", outputs="textbox")
## YOUR_CODE_ENDS_HERE
demo.launch()
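# Running "python app.py" locally should print a local URL for the demo; each query returns the
# top-10 SciQ passages as Hit dicts, rendered as text in the output box.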