miki5799 committed
Commit 5270fb9
1 Parent(s): 54d1112

Extract Jupyter notebook and nlp4web-codebase contents to HF Space repo
app.py CHANGED
@@ -1,3 +1,321 @@
+ from __future__ import annotations  # fixed: moved up from mid-file; Python requires this before any other statement
+ from dataclasses import dataclass
+ import pickle
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from collections import Counter
+ import tqdm
+ import re
+ import nltk
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ LANGUAGE = "english"
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity
+
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+
+ @dataclass
+ class PostingList:
+     term: str  # The term
+     docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th associated posting
+     tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th associated posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting_list (fixed comment: indexed by term id, not docid)
+     vocab: Dict[str, int]  # term -> tid
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(1)  # fixed: the first document containing a term counts toward its df (was 0)
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ sciq = load_sciq()
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
+
+ # Imports from the second notebook cell (the `from __future__` line was moved to the top of the file):
+ from dataclasses import asdict, dataclass
+ import math
+ import os
+ from typing import Iterable, List, Optional, Type
+ import tqdm
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+
+
+ @dataclass
+ class BM25Index(InvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> None:
+         """Compute BM25 term weights and cache them in the posting lists."""
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = BM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int):
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[BM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> BM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and cache them:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assemble (saving is left to the caller):
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+ )
+ bm25_index.save("output/bm25_index")
+
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from typing import Type
+ from abc import abstractmethod
+
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[InvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[BM25Index]:
+         return BM25Index
+
+
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
+ bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
+
+ import numpy as np  # fixed: numpy is used below but was never imported
+
+ # Tuning curves: X holds the candidate parameter values, Y the corresponding evaluation scores.
+ plots_b = {'X': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 'Y': [0.694980045351474, 0.8126195011337869, 0.821528798185941, 0.8218562358276644, 0.8222244897959182, 0.8195024943310657, 0.8182163265306123, 0.8174734693877551, 0.8139020408163266, 0.8116893424036281, 0.8083002267573697]}  # TODO: Replace
+ plots_k1 = {'X': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 'Y': [0.7345419501133786, 0.7668607709750567, 0.779508843537415, 0.7900947845804988, 0.8015931972789115, 0.8103560090702948, 0.812374149659864, 0.8156743764172336, 0.8194036281179138, 0.8222244897959182, 0.8221800453514739]}
+
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]
+ best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+     k1=best_k1,
+     b=best_b,
+ )
+
  import gradio as gr
  from typing import TypedDict
  import pandas as pd
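The committed hunk ends with the original Gradio imports, before any UI is defined. As orientation only, here is a minimal sketch (not part of this commit) of how the retriever above could be exposed through a Gradio interface; the search helper and the interface layout are assumptions:

import gradio as gr
import pandas as pd

def search(query: str) -> pd.DataFrame:
    # Hypothetical wiring: rank the SciQ corpus with the BM25 retriever built above.
    ranking = bm25_retriever.retrieve(query)
    return pd.DataFrame([{"collection_id": cid, "score": score} for cid, score in ranking.items()])

demo = gr.Interface(fn=search, inputs=gr.Textbox(label="Query"), outputs=gr.Dataframe(label="Top-10 results"))
demo.launch()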
nlp4web_codebase/__init__.py ADDED
File without changes
nlp4web_codebase/ir/__init__.py ADDED
File without changes
nlp4web_codebase/ir/analysis.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ from typing import Dict, List, Optional, Protocol
+ import pandas as pd
+ import tqdm
+ import ujson
+ from nlp4web_codebase.ir.data_loaders import IRDataset
+
+
+ def round_dict(obj: Dict[str, float], ndigits: int = 4) -> Dict[str, float]:
+     return {k: round(v, ndigits=ndigits) for k, v in obj.items()}
+
+
+ def sort_dict(obj: Dict[str, float], reverse: bool = True) -> Dict[str, float]:
+     return dict(sorted(obj.items(), key=lambda pair: pair[1], reverse=reverse))
+
+
+ def save_ranking_results(
+     output_dir: str,
+     query_ids: List[str],
+     rankings: List[Dict[str, float]],
+     query_performances_lists: List[Dict[str, float]],
+     cid2tweights_lists: Optional[List[Dict[str, Dict[str, float]]]] = None,
+ ):
+     os.makedirs(output_dir, exist_ok=True)
+     output_path = os.path.join(output_dir, "ranking_results.jsonl")
+     rows = []
+     for i, (query_id, ranking, query_performances) in enumerate(
+         zip(query_ids, rankings, query_performances_lists)
+     ):
+         row = {
+             "query_id": query_id,
+             "ranking": round_dict(ranking),
+             "query_performances": round_dict(query_performances),
+             "cid2tweights": {},
+         }
+         if cid2tweights_lists is not None:
+             row["cid2tweights"] = {
+                 cid: round_dict(tws) for cid, tws in cid2tweights_lists[i].items()
+             }
+         rows.append(row)
+     pd.DataFrame(rows).to_json(
+         output_path,
+         orient="records",
+         lines=True,
+     )
+
+
+ class TermWeightingFunction(Protocol):
+     def __call__(self, query: str, cid: str) -> Dict[str, float]: ...
+
+
+ def compare(
+     dataset: IRDataset,
+     results_path1: str,
+     results_path2: str,
+     output_dir: str,
+     main_metric: str = "recip_rank",
+     system1: Optional[str] = None,
+     system2: Optional[str] = None,
+     term_weighting_fn1: Optional[TermWeightingFunction] = None,
+     term_weighting_fn2: Optional[TermWeightingFunction] = None,
+ ) -> None:
+     os.makedirs(output_dir, exist_ok=True)
+     df1 = pd.read_json(results_path1, orient="records", lines=True)
+     df2 = pd.read_json(results_path2, orient="records", lines=True)
+     assert len(df1) == len(df2)
+     all_qrels = {}
+     for split in dataset.split2qrels:
+         all_qrels.update(dataset.get_qrels_dict(split))
+     qid2query = {query.query_id: query for query in dataset.queries}
+     cid2doc = {doc.collection_id: doc for doc in dataset.corpus}
+     diff_col = f"{main_metric}:qp1-qp2"
+     merged = pd.merge(df1, df2, on="query_id", how="outer")
+     rows = []
+     for _, example in tqdm.tqdm(merged.iterrows(), desc="Comparing", total=len(merged)):
+         docs = {cid: cid2doc[cid].text for cid in dict(example["ranking_x"])}
+         docs.update({cid: cid2doc[cid].text for cid in dict(example["ranking_y"])})
+         query_id = example["query_id"]
+         row = {
+             "query_id": query_id,
+             "query": qid2query[query_id].text,
+             diff_col: example["query_performances_x"][main_metric]
+             - example["query_performances_y"][main_metric],
+             "ranking1": ujson.dumps(example["ranking_x"], indent=4),
+             "ranking2": ujson.dumps(example["ranking_y"], indent=4),
+             "docs": ujson.dumps(docs, indent=4),
+             "query_performances1": ujson.dumps(
+                 example["query_performances_x"], indent=4
+             ),
+             "query_performances2": ujson.dumps(
+                 example["query_performances_y"], indent=4
+             ),
+             "qrels": ujson.dumps(all_qrels[query_id], indent=4),
+         }
+         if term_weighting_fn1 is not None and term_weighting_fn2 is not None:
+             all_cids = set(example["ranking_x"]) | set(example["ranking_y"])
+             cid2tweights1 = {}
+             cid2tweights2 = {}
+             ranking1 = {}
+             ranking2 = {}
+             for cid in all_cids:
+                 tweights1 = term_weighting_fn1(query=qid2query[query_id].text, cid=cid)
+                 tweights2 = term_weighting_fn2(query=qid2query[query_id].text, cid=cid)
+                 ranking1[cid] = sum(tweights1.values())
+                 ranking2[cid] = sum(tweights2.values())
+                 cid2tweights1[cid] = tweights1
+                 cid2tweights2[cid] = tweights2
+             ranking1 = sort_dict(ranking1)
+             ranking2 = sort_dict(ranking2)
+             row["ranking1"] = ujson.dumps(ranking1, indent=4)
+             row["ranking2"] = ujson.dumps(ranking2, indent=4)
+             cid2tweights1 = {cid: cid2tweights1[cid] for cid in ranking1}
+             cid2tweights2 = {cid: cid2tweights2[cid] for cid in ranking2}
+             row["cid2tweights1"] = ujson.dumps(cid2tweights1, indent=4)
+             row["cid2tweights2"] = ujson.dumps(cid2tweights2, indent=4)
+         rows.append(row)
+     table = pd.DataFrame(rows).sort_values(by=diff_col, ascending=False)
+     output_path = os.path.join(output_dir, f"compare-{system1}_vs_{system2}.tsv")
+     table.to_csv(output_path, sep="\t", index=False)
+
+
+ # if __name__ == "__main__":
+ #     # python -m lecture2.bm25.analysis
+ #     from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ #     from lecture2.bm25.bm25_retriever import BM25Retriever
+ #     from lecture2.bm25.tfidf_retriever import TFIDFRetriever
+ #     import numpy as np
+
+ #     sciq = load_sciq()
+ #     system1 = "bm25"
+ #     system2 = "tfidf"
+ #     results_path1 = f"output/sciq-{system1}/results/ranking_results.jsonl"
+ #     results_path2 = f"output/sciq-{system2}/results/ranking_results.jsonl"
+ #     index_dir1 = f"output/sciq-{system1}"
+ #     index_dir2 = f"output/sciq-{system2}"
+ #     compare(
+ #         dataset=sciq,
+ #         results_path1=results_path1,
+ #         results_path2=results_path2,
+ #         output_dir=f"output/sciq-{system1}_vs_{system2}",
+ #         system1=system1,
+ #         system2=system2,
+ #         term_weighting_fn1=BM25Retriever(index_dir1).get_term_weights,
+ #         term_weighting_fn2=TFIDFRetriever(index_dir2).get_term_weights,
+ #     )
+
+ #     # bias on #shared_terms of TFIDF:
+ #     df1 = pd.read_json(results_path1, orient="records", lines=True)
+ #     df2 = pd.read_json(results_path2, orient="records", lines=True)
+ #     merged = pd.merge(df1, df2, on="query_id", how="outer")
+ #     nterms1 = []
+ #     nterms2 = []
+ #     for _, row in merged.iterrows():
+ #         nterms1.append(len(list(dict(row["cid2tweights_x"]).values())[0]))
+ #         nterms2.append(len(list(dict(row["cid2tweights_y"]).values())[0]))
+ #     percentiles = (5, 25, 50, 75, 95)
+ #     print(system1, np.percentile(nterms1, percentiles), np.mean(nterms1).round(2))
+ #     print(system2, np.percentile(nterms2, percentiles), np.mean(nterms2).round(2))
+ #     # bm25 [ 3. 4. 5. 7. 11.] 5.64
+ #     # tfidf [1. 2. 3. 5. 9.] 3.58
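For reference, a hypothetical toy call (not in this commit) illustrating the shapes save_ranking_results expects; all ids, scores, and metric values below are made up:

save_ranking_results(
    output_dir="output/demo",
    query_ids=["q1"],
    rankings=[{"d1": 2.31, "d2": 1.07}],             # per query: cid -> score
    query_performances_lists=[{"recip_rank": 1.0}],  # per query: metric -> value
    cid2tweights_lists=[{"d1": {"immune": 1.2}}],    # optional, per query: cid -> term -> weight
)
# Writes output/demo/ranking_results.jsonl with one JSON row per query.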
nlp4web_codebase/ir/data_loaders/__init__.py ADDED
@@ -0,0 +1,35 @@
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Dict, List
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
+
+
+ class Split(str, Enum):
+     train = "train"
+     dev = "dev"
+     test = "test"
+
+
+ @dataclass
+ class IRDataset:
+     corpus: List[Document]
+     queries: List[Query]
+     split2qrels: Dict[Split, List[QRel]]
+
+     def get_stats(self) -> Dict[str, int]:
+         stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
+         for split, qrels in self.split2qrels.items():
+             stats[f"|qrels-{split}|"] = len(qrels)
+         return stats
+
+     def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
+         qrels_dict = {}
+         for qrel in self.split2qrels[split]:
+             qrels_dict.setdefault(qrel.query_id, {})
+             qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
+         return qrels_dict
+
+     def get_split_queries(self, split: Split) -> List[Query]:
+         qrels = self.split2qrels[split]
+         qids = {qrel.query_id for qrel in qrels}
+         return list(filter(lambda query: query.query_id in qids, self.queries))
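A small usage sketch (not part of this commit), assuming the SciQ loader defined in sciq.py below; the stats printed are the ones recorded in that file:

from nlp4web_codebase.ir.data_loaders import Split
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

dataset = load_sciq()
print(dataset.get_stats())                         # e.g. {"|corpus|": 12160, "|queries|": 12160, ...}
dev_qrels = dataset.get_qrels_dict(Split.dev)      # query_id -> {collection_id: relevance}
dev_queries = dataset.get_split_queries(Split.dev)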
nlp4web_codebase/ir/data_loaders/dm.py ADDED
@@ -0,0 +1,22 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class Document:
+     collection_id: str
+     text: str
+
+
+ @dataclass
+ class Query:
+     query_id: str
+     text: str
+
+
+ @dataclass
+ class QRel:
+     query_id: str
+     collection_id: str
+     relevance: int
+     answer: Optional[str] = None
nlp4web_codebase/ir/data_loaders/sciq.py ADDED
@@ -0,0 +1,86 @@
+ from typing import Dict, List
+ from nlp4web_codebase.ir.data_loaders import IRDataset, Split
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
+ from datasets import load_dataset
+ import joblib
+ import pandas as pd  # fixed: needed for pd.concat below
+
+
+ @(joblib.Memory(".cache").cache)
+ def load_sciq(verbose: bool = False) -> IRDataset:
+     train = load_dataset("allenai/sciq", split="train")
+     validation = load_dataset("allenai/sciq", split="validation")
+     test = load_dataset("allenai/sciq", split="test")
+     data = {Split.train: train, Split.dev: validation, Split.test: test}
+
+     # Each duplicated record is the same as the others:
+     # (fixed: `+` adds DataFrames element-wise; concatenation is what is meant here,
+     # and identical duplicates imply exactly one distinct value per group)
+     df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
+     for question, group in df.groupby("question"):
+         assert len(set(group["support"].tolist())) == 1
+         assert len(set(group["correct_answer"].tolist())) == 1
+
+     # Build:
+     corpus = []
+     queries = []
+     split2qrels: Dict[Split, List[QRel]] = {}
+     question2id = {}
+     support2id = {}
+     for split, rows in data.items():
+         if verbose:
+             print(f"|raw_{split}|", len(rows))
+         split2qrels[split] = []
+         for i, row in enumerate(rows):
+             example_id = f"{split}-{i}"
+             support: str = row["support"]
+             if len(support.strip()) == 0:
+                 continue
+             question = row["question"]
+             if support in support2id:
+                 continue
+             else:
+                 support2id[support] = example_id
+             if question in question2id:
+                 continue
+             else:
+                 question2id[question] = example_id
+             doc = {"collection_id": example_id, "text": support}
+             query = {"query_id": example_id, "text": row["question"]}
+             qrel = {
+                 "query_id": example_id,
+                 "collection_id": example_id,
+                 "relevance": 1,
+                 "answer": row["correct_answer"],
+             }
+             corpus.append(Document(**doc))
+             queries.append(Query(**query))
+             split2qrels[split].append(QRel(**qrel))
+
+     # Assemble and return:
+     return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
+
+
+ if __name__ == "__main__":
+     # python -m nlp4web_codebase.ir.data_loaders.sciq
+     import ujson
+     import time
+
+     start = time.time()
+     dataset = load_sciq(verbose=True)
+     print(f"Loading costs: {time.time() - start}s")
+     print(ujson.dumps(dataset.get_stats(), indent=4))
+     # ________________________________________________________________________________
+     # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
+     # load_sciq(verbose=True)
+     # |raw_train| 11679
+     # |raw_dev| 1000
+     # |raw_test| 1000
+     # ________________________________________________________load_sciq - 7.3s, 0.1min
+     # Loading costs: 7.260092735290527s
+     # {
+     #     "|corpus|": 12160,
+     #     "|queries|": 12160,
+     #     "|qrels-train|": 10409,
+     #     "|qrels-dev|": 875,
+     #     "|qrels-test|": 876
+     # }
nlp4web_codebase/ir/models/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Type
+
+
+ class BaseRetriever(ABC):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[Any]:
+         pass
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def score(self, query: str, cid: str) -> float:
+         pass
+
+     @abstractmethod
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         pass
requirements.txt ADDED
@@ -0,0 +1,389 @@
+ absl-py==1.4.0
+ alabaster==0.7.13
+ anyio==4.4.0
+ appnope==0.1.4
+ argcomplete==3.2.3
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ attrs==23.1.0
+ Babel==2.12.1
+ beautifulsoup4==4.12.3
+ black==23.1.0
+ blacken-docs==1.13.0
+ bleach==6.1.0
+ cachetools==5.3.1
+ certifi==2023.5.7
+ cffi==1.15.1
+ cfgv==3.3.1
+ charset-normalizer==3.2.0
+ click==8.1.5
+ cloudpickle==2.2.1
+ coloredlogs==15.0.1
+ comm==0.2.2
+ contourpy==1.1.0
+ coverage==7.2.7
+ cryptography==41.0.2
+ cycler==0.11.0
+ dataclasses==0.6
+ DateTime==5.5
+ debugpy==1.8.5
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ distlib==0.3.7
+ docutils==0.18.1
+ eradicate==2.3.0
+ et-xmlfile==1.1.0
+ # Editable install with no version control (eta-utility==2.2.2b2.dev78+g7a5fece)
+ -e /Users/mikaelhailu/Documents/Office/ETA-Fabrik/eta-utility
+ exceptiongroup==1.1.2
+ executing==2.0.1
+ fastjsonschema==2.20.0
+ filelock==3.12.2
+ flake8==5.0.4
+ flake8-builtins==2.1.0
+ flake8-comprehensions==3.10.1
+ flake8-eradicate==1.4.0
+ flake8-mutable==1.2.0
+ flake8-plugin-utils==1.3.3
+ flake8-print==5.0.0
+ flake8-pytest-style==1.7.2
+ flake8-requirements==1.7.7
+ flake8-rst-docstrings==0.3.0
+ flatbuffers==23.5.26
+ FMPy==0.3.15
+ fonttools==4.41.0
+ fqdn==1.5.1
+ google-auth==2.22.0
+ google-auth-oauthlib==1.0.0
+ grpcio==1.56.0
+ gym @ git+https://github.com/rlberry-py/gym_fix_021@fd62b4bc15dfd5d8a9be42da54b234c5c47fc98b
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ humanfriendly==10.0
+ icalendar==6.0.1
+ identify==2.5.24
+ idna==3.4
+ imagesize==1.4.1
+ importlib-metadata==4.13.0
+ iniconfig==2.0.0
+ ipykernel==6.29.5
+ ipython==8.24.0
+ ipywidgets==8.1.3
+ isoduration==20.11.0
+ isort==5.12.0
+ jedi==0.19.1
+ Jinja2==3.1.2
+ joblib==1.4.2
+ json5==0.9.25
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ jupyter==1.0.0
+ jupyter-console==6.6.3
+ jupyter-events==0.10.0
+ jupyter-lsp==2.2.5
+ jupyter_client==8.6.2
+ jupyter_core==5.7.2
+ jupyter_server==2.14.2
+ jupyter_server_terminals==0.5.3
+ jupyterlab==4.2.4
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.27.3
+ jupyterlab_widgets==3.0.11
+ keyboard==0.13.5
+ kiwisolver==1.4.4
+ lark==1.1.6
+ lxml==4.9.3
+ Markdown==3.4.3
+ MarkupSafe==2.1.5
+ matplotlib==3.7.2
+ matplotlib-inline==0.1.7
+ mccabe==0.7.0
+ mistune==3.0.2
+ MouseInfo==0.1.3
+ mpmath==1.3.0
+ msgpack==1.0.5
+ mushroom-rl==1.10.1
+ mypy==1.0.0
+ mypy-extensions==1.0.0
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.1
+ nodeenv==1.8.0
+ notebook==7.2.1
+ notebook_shim==0.2.4
+ numpy==1.25.2
+ oauthlib==3.2.2
+ onnxruntime==1.15.1
+ opcua==0.98.13
+ opencv-python==4.10.0.84
+ openpyxl==3.1.2
+ overrides==7.7.0
+ packaging==23.1
+ pandas==2.0.3
+ pandocfilters==1.5.1
+ parso==0.8.4
+ pathspec==0.11.1
+ pbr==6.0.0
+ pep8-naming==0.13.3
+ pexpect==4.9.0
+ Pillow==10.0.0
+ pipx==1.4.3
+ platformdirs==3.9.1
+ pluggy==1.2.0
+ ply==3.11
+ pre-commit==3.3.3
+ prometheus_client==0.20.0
+ prompt-toolkit==3.0.43
+ protobuf==4.23.4
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyasn1==0.5.0
+ pyasn1-modules==0.3.0
+ pycodestyle==2.9.1
+ pycparser==2.21
+ pyflakes==2.5.0
+ pygame==2.5.0
+ PyGetWindow==0.0.9
+ pyglet==2.0.8
+ Pygments==2.15.1
+ pyModbusTCP==0.2.0
+ PyMsgBox==1.0.9
+ pyobjc==9.2
+ pyobjc-core==9.2
+ pyobjc-framework-Accessibility==9.2
+ pyobjc-framework-Accounts==9.2
+ pyobjc-framework-AddressBook==9.2
+ pyobjc-framework-AdServices==9.2
+ pyobjc-framework-AdSupport==9.2
+ pyobjc-framework-AppleScriptKit==9.2
+ pyobjc-framework-AppleScriptObjC==9.2
+ pyobjc-framework-ApplicationServices==9.2
+ pyobjc-framework-AppTrackingTransparency==9.2
+ pyobjc-framework-AudioVideoBridging==9.2
+ pyobjc-framework-AuthenticationServices==9.2
+ pyobjc-framework-AutomaticAssessmentConfiguration==9.2
+ pyobjc-framework-Automator==9.2
+ pyobjc-framework-AVFoundation==9.2
+ pyobjc-framework-AVKit==9.2
+ pyobjc-framework-AVRouting==9.2
+ pyobjc-framework-BackgroundAssets==9.2
+ pyobjc-framework-BusinessChat==9.2
+ pyobjc-framework-CalendarStore==9.2
+ pyobjc-framework-CallKit==9.2
+ pyobjc-framework-CFNetwork==9.2
+ pyobjc-framework-ClassKit==9.2
+ pyobjc-framework-CloudKit==9.2
+ pyobjc-framework-Cocoa==9.2
+ pyobjc-framework-Collaboration==9.2
+ pyobjc-framework-ColorSync==9.2
+ pyobjc-framework-Contacts==9.2
+ pyobjc-framework-ContactsUI==9.2
+ pyobjc-framework-CoreAudio==9.2
+ pyobjc-framework-CoreAudioKit==9.2
+ pyobjc-framework-CoreBluetooth==9.2
+ pyobjc-framework-CoreData==9.2
+ pyobjc-framework-CoreHaptics==9.2
+ pyobjc-framework-CoreLocation==9.2
+ pyobjc-framework-CoreMedia==9.2
+ pyobjc-framework-CoreMediaIO==9.2
+ pyobjc-framework-CoreMIDI==9.2
+ pyobjc-framework-CoreML==9.2
+ pyobjc-framework-CoreMotion==9.2
+ pyobjc-framework-CoreServices==9.2
+ pyobjc-framework-CoreSpotlight==9.2
+ pyobjc-framework-CoreText==9.2
+ pyobjc-framework-CoreWLAN==9.2
+ pyobjc-framework-CryptoTokenKit==9.2
+ pyobjc-framework-DataDetection==9.2
+ pyobjc-framework-DeviceCheck==9.2
+ pyobjc-framework-DictionaryServices==9.2
+ pyobjc-framework-DiscRecording==9.2
+ pyobjc-framework-DiscRecordingUI==9.2
+ pyobjc-framework-DiskArbitration==9.2
+ pyobjc-framework-DVDPlayback==9.2
+ pyobjc-framework-EventKit==9.2
+ pyobjc-framework-ExceptionHandling==9.2
+ pyobjc-framework-ExecutionPolicy==9.2
+ pyobjc-framework-ExtensionKit==9.2
+ pyobjc-framework-ExternalAccessory==9.2
+ pyobjc-framework-FileProvider==9.2
+ pyobjc-framework-FileProviderUI==9.2
+ pyobjc-framework-FinderSync==9.2
+ pyobjc-framework-FSEvents==9.2
+ pyobjc-framework-GameCenter==9.2
+ pyobjc-framework-GameController==9.2
+ pyobjc-framework-GameKit==9.2
+ pyobjc-framework-GameplayKit==9.2
+ pyobjc-framework-HealthKit==9.2
+ pyobjc-framework-ImageCaptureCore==9.2
+ pyobjc-framework-IMServicePlugIn==9.2
+ pyobjc-framework-InputMethodKit==9.2
+ pyobjc-framework-InstallerPlugins==9.2
+ pyobjc-framework-InstantMessage==9.2
+ pyobjc-framework-Intents==9.2
+ pyobjc-framework-IntentsUI==9.2
+ pyobjc-framework-IOBluetooth==9.2
+ pyobjc-framework-IOBluetoothUI==9.2
+ pyobjc-framework-IOSurface==9.2
+ pyobjc-framework-iTunesLibrary==9.2
+ pyobjc-framework-KernelManagement==9.2
+ pyobjc-framework-LatentSemanticMapping==9.2
+ pyobjc-framework-LaunchServices==9.2
+ pyobjc-framework-libdispatch==9.2
+ pyobjc-framework-libxpc==9.2
+ pyobjc-framework-LinkPresentation==9.2
+ pyobjc-framework-LocalAuthentication==9.2
+ pyobjc-framework-LocalAuthenticationEmbeddedUI==9.2
+ pyobjc-framework-MailKit==9.2
+ pyobjc-framework-MapKit==9.2
+ pyobjc-framework-MediaAccessibility==9.2
+ pyobjc-framework-MediaLibrary==9.2
+ pyobjc-framework-MediaPlayer==9.2
+ pyobjc-framework-MediaToolbox==9.2
+ pyobjc-framework-Metal==9.2
+ pyobjc-framework-MetalFX==9.2
+ pyobjc-framework-MetalKit==9.2
+ pyobjc-framework-MetalPerformanceShaders==9.2
+ pyobjc-framework-MetalPerformanceShadersGraph==9.2
+ pyobjc-framework-MetricKit==9.2
+ pyobjc-framework-MLCompute==9.2
+ pyobjc-framework-ModelIO==9.2
+ pyobjc-framework-MultipeerConnectivity==9.2
+ pyobjc-framework-NaturalLanguage==9.2
+ pyobjc-framework-NetFS==9.2
+ pyobjc-framework-Network==9.2
+ pyobjc-framework-NetworkExtension==9.2
+ pyobjc-framework-NotificationCenter==9.2
+ pyobjc-framework-OpenDirectory==9.2
+ pyobjc-framework-OSAKit==9.2
+ pyobjc-framework-OSLog==9.2
+ pyobjc-framework-PassKit==9.2
+ pyobjc-framework-PencilKit==9.2
+ pyobjc-framework-PHASE==9.2
+ pyobjc-framework-Photos==9.2
+ pyobjc-framework-PhotosUI==9.2
+ pyobjc-framework-PreferencePanes==9.2
+ pyobjc-framework-PushKit==9.2
+ pyobjc-framework-Quartz==9.2
+ pyobjc-framework-QuickLookThumbnailing==9.2
+ pyobjc-framework-ReplayKit==9.2
+ pyobjc-framework-SafariServices==9.2
+ pyobjc-framework-SafetyKit==9.2
+ pyobjc-framework-SceneKit==9.2
+ pyobjc-framework-ScreenCaptureKit==9.2
+ pyobjc-framework-ScreenSaver==9.2
+ pyobjc-framework-ScreenTime==9.2
+ pyobjc-framework-ScriptingBridge==9.2
+ pyobjc-framework-SearchKit==9.2
+ pyobjc-framework-Security==9.2
+ pyobjc-framework-SecurityFoundation==9.2
+ pyobjc-framework-SecurityInterface==9.2
+ pyobjc-framework-ServiceManagement==9.2
+ pyobjc-framework-SharedWithYou==9.2
+ pyobjc-framework-SharedWithYouCore==9.2
+ pyobjc-framework-ShazamKit==9.2
+ pyobjc-framework-Social==9.2
+ pyobjc-framework-SoundAnalysis==9.2
+ pyobjc-framework-Speech==9.2
+ pyobjc-framework-SpriteKit==9.2
+ pyobjc-framework-StoreKit==9.2
+ pyobjc-framework-SyncServices==9.2
+ pyobjc-framework-SystemConfiguration==9.2
+ pyobjc-framework-SystemExtensions==9.2
+ pyobjc-framework-ThreadNetwork==9.2
+ pyobjc-framework-UniformTypeIdentifiers==9.2
+ pyobjc-framework-UserNotifications==9.2
+ pyobjc-framework-UserNotificationsUI==9.2
+ pyobjc-framework-VideoSubscriberAccount==9.2
+ pyobjc-framework-VideoToolbox==9.2
+ pyobjc-framework-Virtualization==9.2
+ pyobjc-framework-Vision==9.2
+ pyobjc-framework-WebKit==9.2
+ Pyomo==6.6.1
+ pyparsing==3.0.9
+ pyproject-flake8==5.0.4
+ PyRect==0.2.0
+ PyScreeze==0.1.30
+ pytest==7.4.0
+ pytest-cov==4.1.0
+ python-dateutil==2.9.0.post0
+ python-json-logger==2.0.7
+ pytweening==1.2.0
+ pytz==2023.3
+ pyupgrade==3.3.1
+ PyYAML==6.0.2
+ pyzmq==26.1.0
+ qiskit==1.1.1
+ qiskit-aer==0.14.2
+ qtconsole==5.5.2
+ QtPy==2.4.1
+ referencing==0.35.1
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ restructuredtext-lint==1.4.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rpds-py==0.20.0
+ rsa==4.9
+ rubicon-objc==0.4.9
+ rustworkx==0.15.1
+ scikit-learn==1.5.1
+ scipy==1.14.0
+ Send2Trash==1.8.3
+ six==1.16.0
+ sniffio==1.3.1
+ snowballstemmer==2.2.0
+ soupsieve==2.5
+ Sphinx==6.2.1
+ sphinx-rtd-theme==1.2.2
+ sphinxcontrib-applehelp==1.0.4
+ sphinxcontrib-devhelp==1.0.2
+ sphinxcontrib-htmlhelp==2.0.1
+ sphinxcontrib-jquery==4.1
+ sphinxcontrib-jsmath==1.0.1
+ sphinxcontrib-qthelp==1.0.3
+ sphinxcontrib-serializinghtml==1.1.5
+ stable-baselines3==1.8.0
+ stack-data==0.6.3
+ stevedore==5.2.0
+ symengine==0.11.0
+ sympy==1.12
+ TatSu==5.8.3
+ tensorboard==2.13.0
+ tensorboard-data-server==0.7.1
+ terminado==0.18.1
+ threadpoolctl==3.5.0
+ tinycss2==1.3.0
+ tokenize-rt==5.1.0
+ tomli==2.0.1
+ torch==2.0.1
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ types-python-dateutil==2.8.19.13
+ types-requests==2.31.0.1
+ types-urllib3==1.26.25.13
+ typing_extensions==4.11.0
+ tzdata==2024.2
+ uri-template==1.3.0
+ urllib3==1.26.16
+ userpath==1.9.2
+ virtualenv==20.24.0
+ wcwidth==0.2.13
+ webcolors==24.6.0
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ Werkzeug==2.3.6
+ widgetsnbextension==4.0.11
+ xlrd==2.0.1
+ zipp==3.16.2
+ zope.interface==7.1.0