j1503 committed on
Commit e03dd66 · verified · 1 Parent(s): 9d7a728

Create app.py

Files changed (1)
  1. app.py +586 -0
app.py ADDED
@@ -0,0 +1,586 @@
from __future__ import annotations  # allows the forward references used in the classmethod annotations below

import math
import os
import pickle
import re
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from enum import Enum  # required by the Split enum below
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypedDict, TypeVar

import gradio as gr
import joblib
import nltk
import pandas as pd  # used to concatenate the SciQ splits in load_sciq
import tqdm
from datasets import load_dataset
from scipy.sparse import csc_matrix

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return list(filter(lambda query: query.query_id in qids, self.queries))


@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Records sharing the same question are exact duplicates of each other,
    # so keeping only the first occurrence below loses no information:
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build:
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if len(question.strip()) == 0:
                continue
            if support in support2id:
                continue
            else:
                support2id[support] = example_id
            if question in question2id:
                continue
            else:
                question2id[question] = example_id
            doc = {"collection_id": example_id, "text": support}
            query = {"query_id": example_id, "text": row["question"]}
            qrel = {
                "query_id": example_id,
                "collection_id": example_id,
                "relevance": 1,
                "answer": row["correct_answer"],
            }
            corpus.append(Document(**doc))
            queries.append(Query(**query))
            split2qrels[split].append(QRel(**qrel))

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)


LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized


T = TypeVar("T", bound="InvertedIndex")

@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th posting
    tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th posting


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


class BaseRetriever(ABC):

    @property
    @abstractmethod
    def index_class(self) -> Type[Any]:
        pass

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        raise NotImplementedError

    @abstractmethod
    def score(self, query: str, cid: str) -> float:
        pass

    @abstractmethod
    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        pass


@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                # First document containing this term, so its df starts at 1:
                dfs.append(1)
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and cache them in the posting lists."""

        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and cache them:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


@dataclass
class CSCInvertedIndex:
    posting_lists_matrix: csc_matrix  # (ndocs, nterms) matrix of cached term weights
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists_matrix=None, vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


@dataclass
class CSCBM25Index(CSCInvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and cache them in a CSC sparse matrix."""

        ## YOUR_CODE_STARTS_HERE
        data = []
        indices = []
        indptr = [0]
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                weight = regularized_tf * idf
                data.append(weight)
                indices.append(docid)
            # One CSC column per term: close the column after all of its postings.
            indptr.append(len(data))

        posting_lists_matrix = csc_matrix(
            (data, indices, indptr),
            shape=(total_docs, len(posting_lists)),
        )

        return posting_lists_matrix
        ## YOUR_CODE_ENDS_HERE


    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and cache them:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


class BaseCSCInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}

        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            weight = self.index.posting_lists_matrix[target_docid, tid]
            term_weights[tok] = weight
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            col = self.index.posting_lists_matrix[:, tid]
            rows, data = col.indices, col.data

            for docid, tweight in zip(rows, data):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight

        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


class Hit(TypedDict):
    cid: str
    score: float
    text: str


# Best hyperparameters found for this index: b = 0.9, k1 = 0.4
best_b = 0.9
best_k1 = 0.4
sciq = load_sciq()
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")

demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
doc2text = {doc.collection_id: doc.text for doc in sciq.corpus}


def retrieve(query: str) -> List[Hit]:
    results = csc_bm25_retriever.retrieve(query)

    hits: List[Hit] = []
    for cid, score in results.items():
        hit: Hit = {
            "cid": cid,
            "score": score,
            "text": doc2text[cid],
        }
        hits.append(hit)
    hits = sorted(hits, key=lambda x: x["score"], reverse=True)
    return hits


def format_hits(hits: List[Hit]) -> str:
    output = ""
    for i, hit in enumerate(hits, 1):
        output += f"\n\n{i}. Score: {hit['score']:.3f}\n"
        output += f"ID: {hit['cid']}\n"
        output += f"Text: {hit['text']}\n"
        output += "-" * 80
    return output


demo = gr.Interface(
    fn=retrieve,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
    title="Document Search",
    description="Search documents using BM25 retrieval",
)
## YOUR_CODE_ENDS_HERE
demo.launch()