miki5799 committed on
Commit
3b10cf9
1 Parent(s): 813e603

Enhance BM25 index creation and retrieval functionality; save index to output directory

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -357,22 +357,29 @@ bm25_index = BM25Index.build_from_documents(
357
 
358
 
359
  class Hit(TypedDict):
360
- cid: str
361
- score: float
362
- text: str
363
-
364
 
365
  demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
366
  return_type = List[Hit]
367
-
 
 
 
 
 
 
 
 
368
 
369
  ## YOUR_CODE_STARTS_HERE
370
  def retrieve(query: str, topk: int = 10) -> return_type:
371
- ranking = bm25_retriever.retrieve(query=query, topk=3)
372
  hits = []
373
  for cid, score in ranking.items():
374
  text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
375
- hits.append({"cid": cid, "score": score, "text": text})
376
  return hits
377
 
378
 
@@ -388,5 +395,5 @@ demo = gr.Interface(
388
  ["What are the symptoms of immunodeficiency?"],
389
  ],
390
  )
391
- ## YOUR_CODE_ENDS_HERE
392
  demo.launch()
 
 
357
 
358
 
359
  class Hit(TypedDict):
360
+ cid: str
361
+ score: float
362
+ text: str
 
363
 
364
  demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
365
  return_type = List[Hit]
366
+ bm25_index = BM25Index.build_from_documents(
367
+ documents=iter(sciq.corpus),
368
+ ndocs=12160,
369
+ show_progress_bar=True,
370
+ k1=0.9,
371
+ b=0.4,
372
+ )
373
+ bm25_index.save("output/bm25_index")
374
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
375
 
376
  ## YOUR_CODE_STARTS_HERE
377
  def retrieve(query: str, topk: int = 10) -> return_type:
378
+ ranking = bm25_retriever.retrieve(query=query, topk=topk)
379
  hits = []
380
  for cid, score in ranking.items():
381
  text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
382
+ hits.append(Hit(cid=cid, score=score, text=text))
383
  return hits
384
 
385
 
 
395
  ["What are the symptoms of immunodeficiency?"],
396
  ],
397
  )
 
398
  demo.launch()
399
+ ## YOUR_CODE_ENDS_HERE