# retrieve / app.py
# Sean MacAvaney
# update
# 7e8e133
# raw
# history blame
# 2.38 kB
import pandas as pd
import gradio as gr
import pyterrier as pt
pt.init()
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q
# Retriever built from the 'terrier_stemmed' variant of the 'vaswani' dataset.
# Created once at import time; predict() reconfigures it through `retr.controls`.
retr = pt.TerrierRetrieve.from_dataset('vaswani', 'terrier_stemmed')
# `pt.rewrite.SDM()` transformer instance applied to the queries by predict_sdm().
sdm = pt.rewrite.SDM()
# Both constants are handed to code2md() along with every generated snippet.
COLAB_NAME = 'pyterrier_retrieve.ipynb'
# The surrounding newlines are removed by .strip(), leaving the single pip command.
COLAB_INSTALL = '''
!pip install -q python-terrier
'''.strip()
def predict(input, _, wmodel, num_results, pipe_text):
    """Run the retrieval demo and build the equivalent code snippet.

    Args:
        input: Query frame; passed unchanged to the pipeline and to
            ``df2code`` when rendering the snippet.
        _: Ignored. Presumably the value of the fixed "Index" dropdown,
            which is the first widget registered for this demo.
        wmodel: Weighting model name written to ``retr.controls["wmodel"]``.
        num_results: Number of results requested per query. Coerced to
            ``int`` because the slider that supplies it uses a float step.
        pipe_text: When truthy, ``pt.text.get_text`` is appended to the
            pipeline (and to the snippet).

    Returns:
        A ``(results, markdown)`` tuple: the pipeline output with ``score``
        rounded to two decimals, and the snippet rendered by ``code2md``.
    """
    # A float such as 5.0 would otherwise turn the "end" control into "4.0"
    # and make the generated snippet read ``num_results=5.0``.
    num_results = int(num_results)

    # NOTE: these assignments mutate the module-level retriever, so the
    # settings persist until the next call overwrites them.
    retr.controls["wmodel"] = wmodel
    # "end" is set to num_results - 1 (presumably an inclusive, zero-based rank).
    retr.controls["end"] = str(num_results - 1)
    code = f'''import pandas as pd
import pyterrier as pt ; pt.init()
retr = pt.TerrierRetrieve.from_dataset('vaswani', 'terrier_stemmed', wmodel={repr(wmodel)}, num_results={num_results})
'''
    pipeline = retr
    if pipe_text:
        # Extend both the live pipeline and the displayed snippet with the
        # get_text stage so the two stay in sync.
        pipeline = pipeline >> pt.text.get_text(pt.get_dataset('irds:vaswani'), 'text')
        code += f'''
pipeline = retr >> pt.text.get_text(pt.get_dataset('irds:vaswani'), 'text')
pipeline({df2code(input)})'''
    else:
        code += f'''
retr({df2code(input)})'''
    res = pipeline(input)
    # Round scores for display only.
    res['score'] = res['score'].map(lambda x: round(x, 2))
    return (res, code2md(code, COLAB_INSTALL, COLAB_NAME))
def predict_sdm(input):
    """Apply the module-level ``sdm`` transformer and render its snippet.

    Args:
        input: Query frame; passed to ``sdm`` and to ``df2code``.

    Returns:
        A ``(rewritten, markdown)`` tuple: the output of ``sdm(input)`` and
        the matching code snippet rendered by ``code2md``.
    """
    # Snippet shown to the user; mirrors the call performed just below.
    snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
sdm = pt.rewrite.SDM()
sdm({df2code(input)})
'''
    rewritten = sdm(input)
    rendered = code2md(snippet, COLAB_INSTALL, COLAB_NAME)
    return rewritten, rendered
# Example queries offered by both demos: one row per query, with a string
# `qid` and the free-text `query`.
Q = pd.DataFrame({
    'qid': ['1', '2', '3'],
    'query': [
        'measurement of dielectric constant of liquids by the use of microwave techniques',
        'mathematical analysis and design details of waveguide fed microwave radiations',
        'use of digital computers in the design of band pass filters having given phase and attenuation characteristics',
    ],
})
# Assemble the page from markdown sections interleaved with two demos, then
# launch it. The arguments are given in the order README, retrieval demo,
# SDM notes, SDM demo, wrap-up.
interface(
MarkdownFile('README.md'),
# Retrieval demo backed by predict(). The four widgets below presumably map,
# in order, onto predict()'s `_`, `wmodel`, `num_results` and `pipe_text`
# parameters (after the query frame) -- confirm against pyterrier_gradio.Demo.
Demo(
predict,
Q,
[
# Single-choice, non-interactive dropdown: only one index is offered.
gr.Dropdown(
choices=['vaswani stemmed'],
value='vaswani stemmed',
label='Index',
interactive=False,
), gr.Dropdown(
choices=['TF_IDF', 'BM25', 'PL2', 'DPH'],
value='BM25',
label='Retrieval Model',
# Result count from 1 to 10; note the step is written as a float (1.).
), gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1.,
label='# Results'
), gr.Checkbox(
value=True,
label="Include get_text in pipeline",
)],
scale=2/3
),
MarkdownFile('sdm.md'),
# SDM demo backed by predict_sdm(); it has no extra widgets, only the queries.
Demo(
predict_sdm,
Q,
[],
scale=2/3
),
MarkdownFile('wrapup.md'),
# share=False is passed explicitly to launch().
).launch(share=False)