Spaces:
Sleeping
Sleeping
File size: 4,786 Bytes
58e78d3 13ba238 2546424 e55642f 13ba238 883179e 1372f55 13ba238 250af9b feaef97 250af9b 13ba238 e55642f feaef97 e55642f feaef97 e55642f feaef97 13ba238 e55642f 13ba238 e55642f 13ba238 e55642f ce91faa 5c300c5 e55642f feaef97 e55642f feaef97 13ba238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import streamlit as st
import pandas as pd
import json
from note import SUBMISSION
from st_aggrid import JsCode
from st_aggrid import AgGrid, GridOptionsBuilder
def load_data(path="data.jsonl"):
    """Load a JSON Lines results file into a DataFrame.

    Each non-blank line of the file must be one JSON object; each object
    becomes one row, and its keys become the columns.

    Args:
        path: Location of the JSON Lines file. Defaults to ``"data.jsonl"``
            (relative to the working directory), which is what the app uses.

    Returns:
        A ``pandas.DataFrame`` with one row per JSON record, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as file:
        # Skip blank / whitespace-only lines (e.g. a trailing newline at EOF):
        # json.loads("") would otherwise raise JSONDecodeError.
        records = [json.loads(line) for line in file if line.strip()]
    return pd.DataFrame(records)
def case_insensitive_search(data, query, column):
    """Return the rows of *data* whose *column* contains *query*, ignoring case.

    The match is a plain substring test, not a regular expression, so user
    input such as ``"c++"`` or ``"gpt-3.5"`` is matched literally.
    Non-string columns (e.g. 'Chunk Size', if it is stored as numbers) are
    compared on their string form, and missing values never match.

    Args:
        data: DataFrame to filter.
        query: Text to look for. Falsy values (``None``, ``""``) disable
            filtering. Non-string values are converted with ``str``.
        column: Name of the column to search.

    Returns:
        The filtered DataFrame, or *data* itself when *query* is falsy.
    """
    if not query:
        return data
    needle = str(query).lower()
    # The nullable "string" dtype lets .str work on numeric columns and keeps
    # missing values as <NA>. Plain astype(str) would turn them into the
    # literal text "nan", which a query like "na" would then match.
    haystack = data[column].astype("string").str.lower()
    # regex=False: treat user input literally (no re.error on "c++" or "(").
    # na=False: missing cells become False instead of NA, which would make the
    # boolean indexing below raise.
    mask = haystack.str.contains(needle, regex=False, na=False)
    return data[mask]
def colored_note(text, background_color='#fcfced', text_color='black'):
    """Render *text* inside a rounded, shadowed note box via ``st.markdown``.

    Args:
        text: Content placed inside the box. It is interpolated verbatim and
            rendered with ``unsafe_allow_html=True``, so it must be trusted
            input. The visible caller passes the ``SUBMISSION`` constant.
        background_color: CSS color for the box background.
        text_color: CSS color for the text.
    """
    # Close the <div> explicitly instead of relying on browser error recovery
    # for an unterminated element.
    st.markdown(f"""
<div style='background-color: {background_color}; color: {text_color};
border-radius: 8px; padding: 10px; margin: 8px 0; box-shadow: 2px 2px 5px grey;'>
{text}
</div>
""", unsafe_allow_html=True)
# AgGrid cell renderer for the link columns.
# - A cell whose value is a string containing an anchor (`<a href="...">label</a>`)
#   is shown as an <a> element that opens in a new tab without underline.
# - Every other value is shown as plain text in a <span>.
# `rel="noopener"` stops the opened page from reaching back through window.opener.
# The `typeof` check keeps non-string cell values (e.g. numbers) from throwing on
# `.includes`.
# Comments are kept out of the JS itself: JsCode strips JS comments and collapses
# the code onto one line.
html_render = JsCode(
    """
class UrlCellRenderer {
    init(params) {
        const value = params.value;
        if (typeof value === 'string' && value.includes('href="')) {
            const parsedHtml = new DOMParser().parseFromString(value, 'text/html');
            const link = parsedHtml.querySelector('a');
            if (link) {
                this.eGui = document.createElement('a');
                this.eGui.setAttribute('href', link.getAttribute('href'));
                this.eGui.innerText = link.innerText;
                this.eGui.setAttribute('style', "text-decoration:none");
                this.eGui.setAttribute('target', "_blank");
                this.eGui.setAttribute('rel', "noopener");
                return;
            }
        }
        this.eGui = document.createElement('span');
        this.eGui.innerText = value;
    }
    getGui() {
        return this.eGui;
    }
}
"""
)
def display_table(data, rows_per_page=12):
    """Render *data* as an AgGrid table beside the submission note.

    The 'Chat Model', 'Embedding Model', 'Reranker Model' and 'Framework'
    columns use the ``html_render`` link renderer and are sortable, filterable
    and resizable. 'Accuracy' is configured to sort descending. The grid sits
    in a 5:3 column layout, with ``colored_note(SUBMISSION)`` on the right.

    Args:
        data: DataFrame to display.
        rows_per_page: Used only to size the grid height. No pagination is
            configured.
    """
    link_columns = ['Chat Model', 'Embedding Model', 'Reranker Model', 'Framework']

    # Page-level CSS injected before the layout container is created.
    st.markdown("""
<style>
.centered {
display: flex;
justify-content: center;
}
.css-1l02zno {
flex: 1;
}
</style>
""", unsafe_allow_html=True)

    layout = st.container()

    builder = GridOptionsBuilder.from_dataframe(data)
    builder.configure_columns(
        link_columns,
        cellRenderer=html_render,
        sortable=True,
        filterable=True,
        resizable=True,
    )
    builder.configure_column("Accuracy", sort='desc')
    grid_options = builder.build()

    # Height grows by 38 per row on top of a base of 40, capped at 800. This
    # presumably reserves space for a header plus `rows_per_page` rows —
    # confirm against the rendered grid.
    grid_height = min(40 + 38 * rows_per_page, 800)

    with layout:
        table_col, note_col = st.columns([5, 3])
        with table_col:
            # NOTE(review): AgGrid renders inside a component iframe, so this
            # page-level `.ag-theme-balham` rule likely has no effect. Verify
            # before relying on it.
            st.markdown("""
<style>
.ag-theme-balham {
height: 500px;
width: 50%;
margin: auto;
}
</style>
""", unsafe_allow_html=True)
            AgGrid(data, height=grid_height, gridOptions=grid_options, allow_unsafe_jscode=True)
        with note_col:
            colored_note(SUBMISSION)
def main():
    """Build the Multihop-RAG results page.

    Loads the results, shows four free-text filters and a search button,
    then renders the table (filtered only on the run triggered by the
    button) followed by a citation caption.
    """
    st.set_page_config(layout="wide")
    # NOTE(review): the emoji in the title and button labels look mis-encoded
    # ("π‘", "π"). Restore the intended characters once confirmed.
    st.title("Multihop-RAG π‘")
    st.write("Displaying results across different frameworks, embedding models, chat models, and chunks.")

    results = load_data()

    # Stretch the search button across its column and push it down. The top
    # margin presumably lines it up with the labelled text inputs.
    st.markdown("""
<style>
div.stButton > button:first-child {
height: 2em;
width: 100%;
margin-top: 1.8em;
}
</style>
""", unsafe_allow_html=True)

    chat_col, embed_col, chunk_col, frame_col, button_col = st.columns(5)
    with chat_col:
        chat_model_query = st.text_input("Chat Model", key="chat_model_query")
    with embed_col:
        embedding_model_query = st.text_input("Embedding Model", key="embedding_model_query")
    with chunk_col:
        chunk_query = st.text_input("Chunk", key="chunk_query")
    with frame_col:
        frame_query = st.text_input("Framework", key="frame_query")
    with button_col:
        clicked = st.button("π Search")

    # NOTE(review): st.button is True only on the rerun caused by the click.
    # On any later rerun the table goes back to unfiltered while the inputs
    # keep their text.
    if clicked:
        # Each non-empty query narrows the table further, in this order.
        filters = (
            (chat_model_query, 'Chat Model'),
            (embedding_model_query, 'Embedding Model'),
            (chunk_query, 'Chunk Size'),
            (frame_query, 'Framework'),
        )
        for query, column in filters:
            if query:
                results = case_insensitive_search(results, query, column)

    st.info("Retrieval Stage: MRR@10 and Hit@10; Response Stage: Accuracy ")
    display_table(results)
    st.markdown("---")
    st.caption("For citation, please use: 'Tang, Yixuan, and Yi Yang. MultiHop-RAG: Benchmarking Retrieval-Augmented Generation for Multi-Hop Queries. ArXiv, 2024, /abs/2401.15391. '")
    # Self-reporting contact line, currently disabled:
    # st.markdown("---")
    # st.caption("For results self-reporting, please send an email to ytangch@connect.ust.hk")
# `streamlit run <file>` executes the script as __main__, so the page is built here.
# The stray "|" that followed main() was scrape residue and a syntax error.
if __name__ == "__main__":
    main()