import functools
from typing import List, TypedDict

import gradio as gr

from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
# Load the SciQ dataset once at import time; its `corpus` is what the BM25
# search below indexes and looks document texts up in.
sciq = load_sciq()
# A single search result sent to the UI: the document's collection id, its
# retrieval score, and the document text. Functional TypedDict form; at
# runtime a Hit is just a plain dict with these three keys.
Hit = TypedDict("Hit", {"cid": str, "score": float, "text": str})

# Type alias for the list of hits a search produces.
return_type = List[Hit]
# NOTE(review): BM25Index and BM25Retriever are used below but are not imported
# anywhere in this file; they must be imported/defined before the first query,
# otherwise search() raises NameError.
_BM25_INDEX_DIR = "output/bm25_index"


@functools.lru_cache(maxsize=1)
def _get_retriever():
    """Build and save a BM25 index over ``sciq.corpus``, then return a retriever for it.

    Cached so the index is built, written to disk and loaded exactly once (on
    the first query) instead of on every request.
    """
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        # Presumably the number of documents in the SciQ corpus — TODO confirm
        # (or derive it from the corpus) if the dataset changes.
        ndocs=12160,
        show_progress_bar=True,
    )
    bm25_index.save(_BM25_INDEX_DIR)
    return BM25Retriever(index_dir=_BM25_INDEX_DIR)


@functools.lru_cache(maxsize=1)
def _get_doc_texts():
    """Return a ``collection_id -> text`` map over ``sciq.corpus``, built once.

    Replaces a linear corpus scan per retrieved hit with a dict lookup. On
    duplicate ids the first document wins, matching a first-match search.
    """
    texts = {}
    for doc in sciq.corpus:
        texts.setdefault(doc.collection_id, doc.text)
    return texts


def search(query: str) -> List[Hit]:
    """Search the SciQ corpus with BM25.

    Args:
        query: Free-text query typed into the UI.

    Returns:
        One ``Hit`` (collection id, retrieval score, document text) per
        retrieved document, ordered by descending score. Retrieved ids that are
        not present in the corpus are skipped. An empty or whitespace-only
        query returns ``[]`` without touching the index.
    """
    if not query or not query.strip():
        return []

    ranking = _get_retriever().retrieve(query=query)
    doc_texts = _get_doc_texts()
    # Present the most relevant documents first.
    ranked = sorted(ranking.items(), key=lambda item: item[1], reverse=True)
    return [
        Hit(cid=cid, score=score, text=doc_texts[cid])
        for cid, score in ranked
        if cid in doc_texts
    ]
# Gradio UI: the text from a two-line query box is passed to `search`, and the
# returned hits are rendered as JSON.
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=gr.JSON(label="Search Results"),
    title="SciQ Search Engine",
    description="Enter a query to search the SciQ dataset using BM25.",
)

if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so that
    # importing this module does not block on launch().
    demo.launch()