tipperdair committed on
Commit ed452d4 · verified · 1 Parent(s): 93458b4

Upload 12 files

.gitignore ADDED
@@ -0,0 +1,134 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ *.tsv
+ *.jsonl
+ *.zip
+ output/
README.md CHANGED
@@ -1,12 +1,2 @@
- ---
- title: Nlp4web
- emoji: 🦀
- colorFrom: yellow
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.5.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # nlp4web
+ Codebase of teaching materials for NLP4Web.
app.py ADDED
@@ -0,0 +1,733 @@
+ # -*- coding: utf-8 -*-
+ """Copy of HW1 (more instructed).ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1dGoZK5ZufqNgHm3hH8FEXe34rFqvwLOY
+ """
+ from __future__ import annotations
+
+ """## Pre-requisite code
+
+ The code within this section will be used in the tasks. Please do not change these code lines.
+
+ ### SciQ loading and counting
+ """
+
+ from dataclasses import dataclass
+ import pickle
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from collections import Counter
+ import tqdm
+ import re
+ import nltk
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ LANGUAGE = "english"
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+ @dataclass
+ class PostingList:
+     term: str  # The term
+     docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
+     tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting_list
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(0)
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
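As a quick sanity check of what `run_counting` returns, one can poke at the `Counting` fields; a minimal sketch (the term is just an example, and the numbers depend on the corpus):

```python
# counting.vocab maps a term to its tid; dfs and posting_lists are tid-indexed.
print(len(counting.cid2docid), counting.avgdl)  # number of documents, average doc length
tid = counting.vocab.get("cell")                # "cell" is an example term
if tid is not None:
    print(counting.dfs[tid])                               # the recorded document frequency
    print(counting.posting_lists[tid].docid_postings[:5])  # first few docids for the term
```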
+ """### BM25 Index"""
+
+
+ from dataclasses import asdict, dataclass
+ import math
+ import os
+ from typing import Iterable, List, Optional, Type
+ import tqdm
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+
+
+ @dataclass
+ class BM25Index(InvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> None:
+         """Compute the term weights and cache them in the posting lists."""
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = BM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int) -> float:
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[BM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> BM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and cache them:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assemble and return:
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+ )
+ bm25_index.save("output/bm25_index")
+
+ """### BM25 Retriever"""
+
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from typing import Type
+ from abc import abstractmethod
+
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[InvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[BM25Index]:
+         return BM25Index
+
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+ bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
+
+ """# TASK1: tune b and k1 (4 points)
+
+ Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: since the influence of b is larger, first tune b (with k1 fixed to its default value 0.9) and then use the best value of b to tune k1.
+
+ $${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$
+ """
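Note that `calc_regularized_tf` in the code above drops the constant $(k_1+1)$ factor from the numerator of this formula. Since that factor multiplies every term's contribution by the same constant, the ranking is unchanged; the cached weight per posting is therefore

$$w(t,D)={\text{IDF}}(t)\cdot {\frac {f(t,D)}{f(t,D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}$$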
+
+ from nlp4web_codebase.ir.data_loaders import Split
+ import pytrec_eval
+ import numpy as np
+
+
+ def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
+     metric = "map_cut_10"
+     qrels = sciq.get_qrels_dict(split)
+     evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
+     qps = evaluator.evaluate(rankings)
+     return float(np.mean([qp[metric] for qp in qps.values()]))
+
+ """Example of using the pre-requisite code:"""
+
+ # Loading the dataset:
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+ # Building the BM25 index and saving it:
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True
+ )
+ bm25_index.save("output/bm25_index")
+
+ # Loading the index and using the BM25 retriever to retrieve:
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+ print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
+
+ plots_b: Dict[str, List[float]] = {
+     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+     "Y": []
+ }
+ plots_k1: Dict[str, List[float]] = {
+     "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
+     "Y": []
+ }
+
+ ## YOUR_CODE_STARTS_HERE
+ # Two steps are involved:
+ # Step 1. Fix k1 to the default value 0.9,
+ #   go through all the candidate b values (0, 0.1, ..., 1.0),
+ #   and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
+ # Step 2. Fix b to the best value from step 1 and do the same for k1.
+
+ # Hint (on using the pre-requisite code):
+ # - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
+ # - One can build bm25_index with `BM25Index.build_from_documents`;
+ # - One can use BM25Retriever to load the index and perform retrieval on the dev queries
+ #   (dev queries can be obtained via sciq.get_split_queries(Split.dev))
+
+ # Step 1: tune b with k1 fixed to 0.9.
+ for x in plots_b["X"]:
+     bm25_index = BM25Index.build_from_documents(
+         documents=iter(sciq.corpus),
+         ndocs=12160,
+         show_progress_bar=True,
+         k1=0.9,
+         b=x
+     )
+     bm25_index.save("output/bm25_index")
+     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     rankings = {}
+     for query in sciq.get_split_queries(Split.dev):
+         ranking = bm25_retriever.retrieve(query=query.text)
+         rankings[query.query_id] = ranking
+     result = evaluate_map(rankings, split=Split.dev)
+     plots_b["Y"].append(result)
+
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]
+
+ # Step 2: tune k1 with b fixed to the best value found above.
+ for x in plots_k1["X"]:
+     bm25_index = BM25Index.build_from_documents(
+         documents=iter(sciq.corpus),
+         ndocs=12160,
+         show_progress_bar=True,
+         k1=x,
+         b=best_b
+     )
+     bm25_index.save("output/bm25_index")
+     bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+     rankings = {}
+     for query in sciq.get_split_queries(Split.dev):
+         ranking = bm25_retriever.retrieve(query=query.text)
+         rankings[query.query_id] = ranking
+     result = evaluate_map(rankings, split=Split.dev)
+     plots_k1["Y"].append(result)
+
+ """Let's check the effectiveness gain on test after this tuning on dev."""
+
+ default_map = 0.7849
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]
+ best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+     k1=best_k1,
+     b=best_b
+ )
+ bm25_index.save("output/bm25_index")
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+ rankings = {}
+ for query in sciq.get_split_queries(Split.test):  # note this is now on test
+     ranking = bm25_retriever.retrieve(query=query.text)
+     rankings[query.query_id] = ranking
+ optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
+
+ """# TASK2: CSC matrix and `CSCBM25Index` (12 points)
+
+ Recall that we use Python lists to implement posting lists, mapping term IDs to the documents in which they appear. This naive design is inefficient. A [Compressed Sparse Column matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html) is well suited to storing posting lists and can boost efficiency.
+
+ ## TASK2.1: learn about `scipy.sparse.csc_matrix` (2 points)
+
+ Convert the matrix \begin{bmatrix}
+ 0 & 1 & 0 & 3 \\
+ 10 & 2 & 1 & 0 \\
+ 0 & 0 & 0 & 9
+ \end{bmatrix} to a `csc_matrix` by specifying `data`, `indices`, `indptr` and `shape`.
+ """
+
+ from scipy.sparse import csc_matrix
+
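The diff contains no solution code for TASK2.1, so here is a minimal sketch of the conversion. In CSC format, `data` holds the nonzero values in column-major order, `indices` holds their row indices, and `indptr[j]:indptr[j+1]` delimits column `j`:

```python
import numpy as np
from scipy.sparse import csc_matrix

# Nonzeros read column by column: col 0 -> 10; col 1 -> 1, 2; col 2 -> 1; col 3 -> 3, 9.
data = np.array([10, 1, 2, 1, 3, 9])
indices = np.array([1, 0, 1, 1, 0, 2])  # row index of each nonzero
indptr = np.array([0, 1, 3, 4, 6])      # column j spans data[indptr[j]:indptr[j+1]]
matrix = csc_matrix((data, indices, indptr), shape=(3, 4))
assert (matrix.toarray() == np.array(
    [[0, 1, 0, 3], [10, 2, 1, 0], [0, 0, 0, 9]]
)).all()
```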
+ """## TASK2.2: implement `CSCBM25Index` (4 points)
+
+ Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to the `InvertedIndex` we discussed in class. The main difference is that the posting lists are represented by a CSC sparse matrix.
+ """
+
+ @dataclass
+ class CSCInvertedIndex:
+     posting_lists_matrix: csc_matrix  # (docid, tid) -> term weight
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists_matrix=None, vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+ @dataclass
+ class CSCBM25Index(CSCInvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> csc_matrix:
+         """Compute the term weights and cache them in a CSC matrix of shape (ndocs, nterms)."""
+
+         ## YOUR_CODE_STARTS_HERE
+         # Build the three CSC arrays directly, with one column per term id:
+         data = []     # nonzero term weights, column by column
+         indices = []  # row (docid) of each nonzero
+         indptr = [0]  # column boundaries; column tid spans data[indptr[tid]:indptr[tid+1]]
+         count = 0
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = CSCBM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 # The cached term weight is modified TF * modified IDF:
+                 data.append(regularized_tf * idf)
+                 indices.append(docid)
+                 count += 1
+             indptr.append(count)
+         output_matrix = csc_matrix(
+             (data, indices, indptr), shape=(N, len(posting_lists)), dtype=np.float32
+         )
+         return output_matrix
+         ## YOUR_CODE_ENDS_HERE
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int) -> float:
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[CSCBM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> CSCBM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=CSCBM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and cache them:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         posting_lists_matrix = CSCBM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assemble and return:
+         index = CSCBM25Index(
+             posting_lists_matrix=posting_lists_matrix,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+ csc_bm25_index = CSCBM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+     k1=best_k1,
+     b=best_b
+ )
+ csc_bm25_index.save("output/csc_bm25_index")
+
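Because the matrix stores one column per term id, a term's posting list can be read straight from the CSC arrays without densifying the matrix. A minimal sketch, assuming the example term appears in the vocabulary:

```python
# Read one term's postings directly from the CSC arrays of the index built above.
m = csc_bm25_index.posting_lists_matrix
tid = csc_bm25_index.vocab["immune"]   # "immune" is just an example term
start, end = m.indptr[tid], m.indptr[tid + 1]
docids = m.indices[start:end]          # documents containing the term
weights = m.data[start:end]            # their cached BM25 term weights
print(docids[:5], weights[:5])
```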
+
+ class BaseCSCInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[CSCInvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         ## YOUR_CODE_STARTS_HERE
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         matrix = self.index.posting_lists_matrix
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             weight = matrix[target_docid, tid]
+             if weight != 0:
+                 term_weights[tok] = float(weight)
+         return term_weights
+         ## YOUR_CODE_ENDS_HERE
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         ## YOUR_CODE_STARTS_HERE
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         matrix = self.index.posting_lists_matrix
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             # Column tid holds the term's posting list; iterate only its stored
+             # nonzeros instead of looping over all docids:
+             start, end = matrix.indptr[tid], matrix.indptr[tid + 1]
+             for docid, tweight in zip(matrix.indices[start:end], matrix.data[start:end]):
+                 docid2score.setdefault(int(docid), 0.0)
+                 docid2score[int(docid)] += float(tweight)
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+         ## YOUR_CODE_ENDS_HERE
+
+
+ class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[CSCBM25Index]:
+         return CSCBM25Index
+
+ """# TASK3: a search-engine demo based on Huggingface space (4 points)
+
+ ## TASK3.1: create the gradio app (2 points)
+
+ Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with the default k1 and b values.
+
+ Hint: it should use a "search" function with the signature:
+
+ ```python
+ def search(query: str) -> List[Hit]:
+     ...
+ ```
+ """
+
+
+ import gradio as gr
+ from typing import TypedDict
+
+ class Hit(TypedDict):
+     cid: str
+     score: float
+     text: str
+
+ demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
+ return_type = List[Hit]
+
+ ## YOUR_CODE_STARTS_HERE
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+ )
+ bm25_index.save("output/bm25_index")
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+
+ def search(query: str) -> List[Hit]:
+     hits = []
+     for cid, score in bm25_retriever.retrieve(query).items():
+         hit_object: Hit = {
+             "cid": cid,
+             "score": score,
+             # Look up the Document by docid and return its text field:
+             "text": sciq.corpus[bm25_retriever.index.cid2docid[cid]].text,
+         }
+         hits.append(hit_object)
+     return hits
+
+ demo = gr.Interface(
+     fn=search,
+     inputs="text",
+     outputs="text",
+ )
+ ## YOUR_CODE_ENDS_HERE
+ demo.launch()
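A quick way to exercise `search` locally before launching the demo; a sketch with an example SciQ-style query:

```python
# Sanity-check the search function locally (the query is just an example):
hits = search("What type of organism is commonly used in preparation of foods such as cheese and yogurt?")
print(hits[0]["cid"], round(hits[0]["score"], 2))
print(hits[0]["text"][:100])
```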
nlp4web_codebase/__init__.py ADDED
File without changes
nlp4web_codebase/ir/__init__.py ADDED
File without changes
nlp4web_codebase/ir/analysis.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ from typing import Dict, List, Optional, Protocol
+ import pandas as pd
+ import tqdm
+ import ujson
+ from nlp4web_codebase.ir.data_loaders import IRDataset
+
+
+ def round_dict(obj: Dict[str, float], ndigits: int = 4) -> Dict[str, float]:
+     return {k: round(v, ndigits=ndigits) for k, v in obj.items()}
+
+
+ def sort_dict(obj: Dict[str, float], reverse: bool = True) -> Dict[str, float]:
+     return dict(sorted(obj.items(), key=lambda pair: pair[1], reverse=reverse))
+
+
+ def save_ranking_results(
+     output_dir: str,
+     query_ids: List[str],
+     rankings: List[Dict[str, float]],
+     query_performances_lists: List[Dict[str, float]],
+     cid2tweights_lists: Optional[List[Dict[str, Dict[str, float]]]] = None,
+ ):
+     os.makedirs(output_dir, exist_ok=True)
+     output_path = os.path.join(output_dir, "ranking_results.jsonl")
+     rows = []
+     for i, (query_id, ranking, query_performances) in enumerate(
+         zip(query_ids, rankings, query_performances_lists)
+     ):
+         row = {
+             "query_id": query_id,
+             "ranking": round_dict(ranking),
+             "query_performances": round_dict(query_performances),
+             "cid2tweights": {},
+         }
+         if cid2tweights_lists is not None:
+             row["cid2tweights"] = {
+                 cid: round_dict(tws) for cid, tws in cid2tweights_lists[i].items()
+             }
+         rows.append(row)
+     pd.DataFrame(rows).to_json(
+         output_path,
+         orient="records",
+         lines=True,
+     )
+
+
+ class TermWeightingFunction(Protocol):
+     def __call__(self, query: str, cid: str) -> Dict[str, float]: ...
+
+
+ def compare(
+     dataset: IRDataset,
+     results_path1: str,
+     results_path2: str,
+     output_dir: str,
+     main_metric: str = "recip_rank",
+     system1: Optional[str] = None,
+     system2: Optional[str] = None,
+     term_weighting_fn1: Optional[TermWeightingFunction] = None,
+     term_weighting_fn2: Optional[TermWeightingFunction] = None,
+ ) -> None:
+     os.makedirs(output_dir, exist_ok=True)
+     df1 = pd.read_json(results_path1, orient="records", lines=True)
+     df2 = pd.read_json(results_path2, orient="records", lines=True)
+     assert len(df1) == len(df2)
+     all_qrels = {}
+     for split in dataset.split2qrels:
+         all_qrels.update(dataset.get_qrels_dict(split))
+     qid2query = {query.query_id: query for query in dataset.queries}
+     cid2doc = {doc.collection_id: doc for doc in dataset.corpus}
+     diff_col = f"{main_metric}:qp1-qp2"
+     merged = pd.merge(df1, df2, on="query_id", how="outer")
+     rows = []
+     for _, example in tqdm.tqdm(merged.iterrows(), desc="Comparing", total=len(merged)):
+         docs = {cid: cid2doc[cid].text for cid in dict(example["ranking_x"])}
+         docs.update({cid: cid2doc[cid].text for cid in dict(example["ranking_y"])})
+         query_id = example["query_id"]
+         row = {
+             "query_id": query_id,
+             "query": qid2query[query_id].text,
+             diff_col: example["query_performances_x"][main_metric]
+             - example["query_performances_y"][main_metric],
+             "ranking1": ujson.dumps(example["ranking_x"], indent=4),
+             "ranking2": ujson.dumps(example["ranking_y"], indent=4),
+             "docs": ujson.dumps(docs, indent=4),
+             "query_performances1": ujson.dumps(
+                 example["query_performances_x"], indent=4
+             ),
+             "query_performances2": ujson.dumps(
+                 example["query_performances_y"], indent=4
+             ),
+             "qrels": ujson.dumps(all_qrels[query_id], indent=4),
+         }
+         if term_weighting_fn1 is not None and term_weighting_fn2 is not None:
+             all_cids = set(example["ranking_x"]) | set(example["ranking_y"])
+             cid2tweights1 = {}
+             cid2tweights2 = {}
+             ranking1 = {}
+             ranking2 = {}
+             for cid in all_cids:
+                 tweights1 = term_weighting_fn1(query=qid2query[query_id].text, cid=cid)
+                 tweights2 = term_weighting_fn2(query=qid2query[query_id].text, cid=cid)
+                 ranking1[cid] = sum(tweights1.values())
+                 ranking2[cid] = sum(tweights2.values())
+                 cid2tweights1[cid] = tweights1
+                 cid2tweights2[cid] = tweights2
+             ranking1 = sort_dict(ranking1)
+             ranking2 = sort_dict(ranking2)
+             row["ranking1"] = ujson.dumps(ranking1, indent=4)
+             row["ranking2"] = ujson.dumps(ranking2, indent=4)
+             cid2tweights1 = {cid: cid2tweights1[cid] for cid in ranking1}
+             cid2tweights2 = {cid: cid2tweights2[cid] for cid in ranking2}
+             row["cid2tweights1"] = ujson.dumps(cid2tweights1, indent=4)
+             row["cid2tweights2"] = ujson.dumps(cid2tweights2, indent=4)
+         rows.append(row)
+     table = pd.DataFrame(rows).sort_values(by=diff_col, ascending=False)
+     output_path = os.path.join(output_dir, f"compare-{system1}_vs_{system2}.tsv")
+     table.to_csv(output_path, sep="\t", index=False)
+
+
+ # if __name__ == "__main__":
+ #     # python -m lecture2.bm25.analysis
+ #     from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ #     from lecture2.bm25.bm25_retriever import BM25Retriever
+ #     from lecture2.bm25.tfidf_retriever import TFIDFRetriever
+ #     import numpy as np
+
+ #     sciq = load_sciq()
+ #     system1 = "bm25"
+ #     system2 = "tfidf"
+ #     results_path1 = f"output/sciq-{system1}/results/ranking_results.jsonl"
+ #     results_path2 = f"output/sciq-{system2}/results/ranking_results.jsonl"
+ #     index_dir1 = f"output/sciq-{system1}"
+ #     index_dir2 = f"output/sciq-{system2}"
+ #     compare(
+ #         dataset=sciq,
+ #         results_path1=results_path1,
+ #         results_path2=results_path2,
+ #         output_dir=f"output/sciq-{system1}_vs_{system2}",
+ #         system1=system1,
+ #         system2=system2,
+ #         term_weighting_fn1=BM25Retriever(index_dir1).get_term_weights,
+ #         term_weighting_fn2=TFIDFRetriever(index_dir2).get_term_weights,
+ #     )
+
+ #     # bias on #shared_terms of TFIDF:
+ #     df1 = pd.read_json(results_path1, orient="records", lines=True)
+ #     df2 = pd.read_json(results_path2, orient="records", lines=True)
+ #     merged = pd.merge(df1, df2, on="query_id", how="outer")
+ #     nterms1 = []
+ #     nterms2 = []
+ #     for _, row in merged.iterrows():
+ #         nterms1.append(len(list(dict(row["cid2tweights_x"]).values())[0]))
+ #         nterms2.append(len(list(dict(row["cid2tweights_y"]).values())[0]))
+ #     percentiles = (5, 25, 50, 75, 95)
+ #     print(system1, np.percentile(nterms1, percentiles), np.mean(nterms1).round(2))
+ #     print(system2, np.percentile(nterms2, percentiles), np.mean(nterms2).round(2))
+ #     # bm25 [ 3.  4.  5.  7. 11.] 5.64
+ #     # tfidf [1. 2. 3. 5. 9.] 3.58
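`save_ranking_results` writes one JSON line per query; a minimal sketch of the expected, position-aligned inputs (all values hypothetical):

```python
# Hypothetical toy call; query_ids, rankings and query_performances_lists align by position.
save_ranking_results(
    output_dir="output/sciq-bm25/results",
    query_ids=["dev-0"],
    rankings=[{"dev-0": 12.3, "dev-7": 9.8}],
    query_performances_lists=[{"recip_rank": 1.0}],
)
```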
nlp4web_codebase/ir/data_loaders/__init__.py ADDED
@@ -0,0 +1,35 @@
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Dict, List
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
+
+
+ class Split(str, Enum):
+     train = "train"
+     dev = "dev"
+     test = "test"
+
+
+ @dataclass
+ class IRDataset:
+     corpus: List[Document]
+     queries: List[Query]
+     split2qrels: Dict[Split, List[QRel]]
+
+     def get_stats(self) -> Dict[str, int]:
+         stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
+         for split, qrels in self.split2qrels.items():
+             stats[f"|qrels-{split}|"] = len(qrels)
+         return stats
+
+     def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
+         qrels_dict = {}
+         for qrel in self.split2qrels[split]:
+             qrels_dict.setdefault(qrel.query_id, {})
+             qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
+         return qrels_dict
+
+     def get_split_queries(self, split: Split) -> List[Query]:
+         qrels = self.split2qrels[split]
+         qids = {qrel.query_id for qrel in qrels}
+         return list(filter(lambda query: query.query_id in qids, self.queries))
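`get_qrels_dict` returns the nested mapping that `pytrec_eval.RelevanceEvaluator` consumes; a sketch of the shape, with hypothetical ids:

```python
# {query_id: {collection_id: relevance}}; ids below are hypothetical.
qrels = {
    "dev-0": {"dev-0": 1},
    "dev-3": {"dev-3": 1},
}
```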
nlp4web_codebase/ir/data_loaders/dm.py ADDED
@@ -0,0 +1,22 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class Document:
+     collection_id: str
+     text: str
+
+
+ @dataclass
+ class Query:
+     query_id: str
+     text: str
+
+
+ @dataclass
+ class QRel:
+     query_id: str
+     collection_id: str
+     relevance: int
+     answer: Optional[str] = None
nlp4web_codebase/ir/data_loaders/sciq.py ADDED
@@ -0,0 +1,86 @@
+ from typing import Dict, List
+ from nlp4web_codebase.ir.data_loaders import IRDataset, Split
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
+ from datasets import load_dataset
+ import joblib
+
+
+ @(joblib.Memory(".cache").cache)
+ def load_sciq(verbose: bool = False) -> IRDataset:
+     train = load_dataset("allenai/sciq", split="train")
+     validation = load_dataset("allenai/sciq", split="validation")
+     test = load_dataset("allenai/sciq", split="test")
+     data = {Split.train: train, Split.dev: validation, Split.test: test}
+
+     # Each duplicated record is identical to the others:
+     df = train.to_pandas() + validation.to_pandas() + test.to_pandas()
+     for question, group in df.groupby("question"):
+         assert len(set(group["support"].tolist())) == len(group)
+         assert len(set(group["correct_answer"].tolist())) == len(group)
+
+     # Build:
+     corpus = []
+     queries = []
+     split2qrels: Dict[str, List[dict]] = {}
+     question2id = {}
+     support2id = {}
+     for split, rows in data.items():
+         if verbose:
+             print(f"|raw_{split}|", len(rows))
+         split2qrels[split] = []
+         for i, row in enumerate(rows):
+             example_id = f"{split}-{i}"
+             support: str = row["support"]
+             if len(support.strip()) == 0:
+                 continue
+             question = row["question"]
+             if support in support2id:
+                 continue
+             else:
+                 support2id[support] = example_id
+             if question in question2id:
+                 continue
+             else:
+                 question2id[question] = example_id
+             doc = {"collection_id": example_id, "text": support}
+             query = {"query_id": example_id, "text": row["question"]}
+             qrel = {
+                 "query_id": example_id,
+                 "collection_id": example_id,
+                 "relevance": 1,
+                 "answer": row["correct_answer"],
+             }
+             corpus.append(Document(**doc))
+             queries.append(Query(**query))
+             split2qrels[split].append(QRel(**qrel))
+
+     # Assemble and return:
+     return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
+
+
+ if __name__ == "__main__":
+     # python -m nlp4web_codebase.ir.data_loaders.sciq
+     import ujson
+     import time
+
+     start = time.time()
+     dataset = load_sciq(verbose=True)
+     print(f"Loading costs: {time.time() - start}s")
+     print(ujson.dumps(dataset.get_stats(), indent=4))
+     # ________________________________________________________________________________
+     # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
+     # load_sciq(verbose=True)
+     # |raw_train| 11679
+     # |raw_dev| 1000
+     # |raw_test| 1000
+     # ________________________________________________________load_sciq - 7.3s, 0.1min
+     # Loading costs: 7.260092735290527s
+     # {
+     #     "|corpus|": 12160,
+     #     "|queries|": 12160,
+     #     "|qrels-train|": 10409,
+     #     "|qrels-dev|": 875,
+     #     "|qrels-test|": 876
+     # }
nlp4web_codebase/ir/models/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Type
+
+
+ class BaseRetriever(ABC):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[Any]:
+         pass
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def score(self, query: str, cid: str) -> float:
+         pass
+
+     @abstractmethod
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         pass
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ nltk==3.8.1
+ numpy==1.26.4
+ scipy==1.13.1
+ pandas==2.2.2
+ tqdm==4.66.5
+ ujson==5.10.0
+ joblib==1.4.2
+ datasets==3.0.1
+ pytrec_eval==0.5
+ gradio
setup.py ADDED
@@ -0,0 +1,37 @@
+ from setuptools import setup, find_packages
+
+
+ with open("README.md", "r", encoding="utf-8") as fh:
+     readme = fh.read()
+
+ setup(
+     name="nlp4web-codebase",
+     version="0.0.0",
+     author="Kexin Wang",
+     author_email="kexin.wang.2049@gmail.com",
+     description="Codebase of teaching materials for NLP4Web.",
+     long_description=readme,
+     long_description_content_type="text/markdown",
+     url="https://github.com/kwang2049/nlp4web-codebase",
+     project_urls={
+         "Bug Tracker": "https://github.com/kwang2049/nlp4web-codebase/issues",
+     },
+     packages=find_packages(),
+     classifiers=[
+         "Programming Language :: Python :: 3",
+         "License :: OSI Approved :: Apache Software License",
+         "Operating System :: OS Independent",
+     ],
+     python_requires=">=3.10",
+     install_requires=[
+         "nltk==3.8.1",
+         "numpy==1.26.4",
+         "scipy==1.13.1",
+         "pandas==2.2.2",
+         "tqdm==4.66.5",
+         "ujson==5.10.0",
+         "joblib==1.4.2",
+         "datasets==3.0.1",
+         "pytrec_eval==0.5",
+     ],
+ )