MultiHop-RAG / app.py
yixuantt's picture
first_commit
e55642f
import streamlit as st
import pandas as pd
import json
from note import SUBMISSION
from st_aggrid import JsCode
from st_aggrid import AgGrid, GridOptionsBuilder
def load_data(path="data.jsonl"):
    """Load the results table from a JSON Lines file.

    Each non-blank line of *path* is parsed with ``json.loads`` and the parsed
    records become the rows of the returned DataFrame.

    Args:
        path: File to read. Defaults to ``data.jsonl`` (relative to the
            working directory), the file the app has always read.

    Returns:
        A ``pandas.DataFrame`` built from the parsed records, one row each.

    Raises:
        FileNotFoundError: If *path* does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as file:
        # Skip blank lines (e.g. an extra newline at the end of the file):
        # json.loads("") / json.loads("\n") would raise and break the page.
        records = [json.loads(line) for line in file if line.strip()]
    return pd.DataFrame(records)
def case_insensitive_search(data, query, column):
    """Return the rows of *data* whose *column* contains *query*, ignoring case.

    The match is a plain substring test: characters such as ``.``, ``(`` or
    ``+`` in *query* are matched literally, not as regular-expression syntax.
    Cells are compared through their string form, so numeric columns can be
    searched as well; missing values never match.

    Args:
        data: DataFrame to filter.
        query: Substring to look for. A falsy value (``""`` or ``None``)
            disables filtering.
        column: Name of the column to search.

    Returns:
        The matching subset of *data*, or *data* itself when *query* is falsy.
    """
    if not query:
        return data
    # The nullable "string" dtype turns non-string cells (e.g. ints) into text
    # while keeping missing values as <NA>; na=False then makes them False in
    # the mask instead of NaN, which boolean indexing would reject.
    haystack = data[column].astype("string").str.lower()
    # regex=False: user queries like "gpt-3.5" or "(70B)" are literal text.
    mask = haystack.str.contains(query.lower(), regex=False, na=False)
    return data[mask]
def colored_note(text, background_color='#fcfced', text_color='black'):
    """Render *text* inside a rounded, shadowed box using ``st.markdown``.

    *text* is interpolated into the HTML unescaped and rendered with
    ``unsafe_allow_html=True``, so it may contain markup but must come from a
    trusted source.

    Args:
        text: Content to show inside the box (raw HTML/markdown allowed).
        background_color: CSS color for the box background.
        text_color: CSS color for the text.
    """
    # The closing </div> keeps the emitted HTML well-formed; without it the
    # box element is left open.
    st.markdown(
        f"""
<div style='background-color: {background_color}; color: {text_color};
border-radius: 8px; padding: 10px; margin: 8px 0; box-shadow: 2px 2px 5px grey;'>
{text}
</div>
""",
        unsafe_allow_html=True,
    )
# AG Grid cell renderer (passed to AgGrid as JsCode) for link-bearing columns.
# If a cell value contains an HTML anchor with an href attribute
# (e.g. '<a href="...">name</a>', single- or double-quoted), the cell shows a
# clickable link that opens in a new tab; otherwise the value is shown as
# plain text.
# - The value is coerced to a string first, so non-string cells cannot break
#   the `.includes` check; null/undefined render as an empty cell.
# - Only anchors that actually carry an href are turned into links.
# - rel="noopener noreferrer" accompanies target="_blank" so the opened tab
#   cannot reach back into this page through window.opener.
html_render = JsCode(
    """
class UrlCellRenderer {
    init(params) {
        const value = (params.value === null || params.value === undefined)
            ? '' : String(params.value);
        this.eGui = document.createElement('span');
        this.eGui.innerText = value;
        if (!value.includes('href=')) {
            return;
        }
        const parsedHtml = new DOMParser().parseFromString(value, 'text/html');
        const link = parsedHtml.querySelector('a[href]');
        if (!link) {
            return;
        }
        this.eGui = document.createElement('a');
        this.eGui.setAttribute('href', link.getAttribute('href'));
        this.eGui.innerText = link.innerText;
        this.eGui.setAttribute('style', 'text-decoration:none');
        this.eGui.setAttribute('target', '_blank');
        this.eGui.setAttribute('rel', 'noopener noreferrer');
    }
    getGui() {
        return this.eGui;
    }
}
"""
)
def display_table(data, rows_per_page=12):
    """Show *data* in an AgGrid table beside the submission note.

    Args:
        data: DataFrame of results to display.
        rows_per_page: Used only to size the grid: the height is
            ``40 + rows_per_page * 38`` pixels, capped at 800.
    """
    st.markdown("""
<style>
.centered {
display: flex;
justify-content: center;
}
.css-1l02zno {
flex: 1;
}
</style>
""", unsafe_allow_html=True)
    container = st.container()
    builder = GridOptionsBuilder.from_dataframe(data)
    # Model/framework columns use the JS link renderer so cells holding
    # <a href=...> markup display as clickable links.
    # ``filter`` is AG Grid's column-definition key for the column filter;
    # ``filterable`` is not a colDef property and is ignored by the grid.
    builder.configure_columns(
        ['Chat Model', 'Embedding Model', 'Reranker Model', 'Framework'],
        cellRenderer=html_render,
        sortable=True,
        filter=True,
        resizable=True,
    )
    # Initial ordering: highest Accuracy first.
    builder.configure_column("Accuracy", sort='desc')
    grid_options = builder.build()
    with container:
        height = min(40 + rows_per_page * 38, 800)
        grid_col, note_col = st.columns([5, 3])
        with grid_col:
            st.markdown("""
<style>
.ag-theme-balham {
height: 500px;
width: 50%;
margin: auto;
}
</style>
""", unsafe_allow_html=True)
            # allow_unsafe_jscode lets AgGrid execute the JsCode cell renderer.
            AgGrid(data, height=height, gridOptions=grid_options, allow_unsafe_jscode=True)
        with note_col:
            colored_note(SUBMISSION)
def main():
    """Render the MultiHop-RAG results page.

    Lays out the title, four free-text filters (chat model, embedding model,
    chunk size, framework), the filtered results grid and a citation note.
    """
    st.set_page_config(layout="wide")
    st.title("Multihop-RAG 💡")
    st.write("Displaying results across different frameworks, embedding models, chat models, and chunks.")
    data = load_data()
    # Stretch the Search button to its column and push it down (margin-top),
    # presumably so it lines up with the labelled text inputs beside it.
    st.markdown("""
<style>
div.stButton > button:first-child {
height: 2em;
width: 100%;
margin-top: 1.8em;
}
</style>
""", unsafe_allow_html=True)
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        chat_model_query = st.text_input("Chat Model", key="chat_model_query")
    with col2:
        embedding_model_query = st.text_input("Embedding Model", key="embedding_model_query")
    with col3:
        chunk_query = st.text_input("Chunk", key="chunk_query")
    with col4:
        frame_query = st.text_input("Framework", key="frame_query")
    with col5:
        # st.button returns True only on the rerun its click triggers, so the
        # filters below are driven by the inputs on every run instead of being
        # gated on the button (which would drop them on the next rerun, e.g.
        # after pressing Enter in another input). Clicking it simply reruns.
        st.button("🔍 Search")
    # (query, column) pairs; empty queries leave the data unfiltered.
    filters = (
        (chat_model_query, 'Chat Model'),
        (embedding_model_query, 'Embedding Model'),
        (chunk_query, 'Chunk Size'),
        (frame_query, 'Framework'),
    )
    for query, column in filters:
        if query:
            data = case_insensitive_search(data, query, column)
    st.info("Retrieval Stage: MRR@10 and Hit@10; Response Stage: Accuracy ")
    display_table(data)
    st.markdown("---")
    st.caption("For citation, please use: 'Tang, Yixuan, and Yi Yang. MultiHop-RAG: Benchmarking Retrieval-Augmented Generation for Multi-Hop Queries. ArXiv, 2024, https://arxiv.org/abs/2401.15391. '")
    # Self-reporting contact note, currently disabled:
    # st.markdown("---")
    # st.caption("For results self-reporting, please send an email to ytangch@connect.ust.hk")


if __name__ == "__main__":
    main()