File size: 4,786 Bytes
58e78d3
13ba238
2546424
e55642f
 
 
13ba238
883179e
 
 
1372f55
13ba238
250af9b
feaef97
250af9b
13ba238
 
e55642f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
feaef97
e55642f
 
 
 
 
 
feaef97
e55642f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
feaef97
13ba238
e55642f
 
 
13ba238
 
 
e55642f
 
 
 
 
 
 
 
 
 
13ba238
e55642f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce91faa
5c300c5
e55642f
feaef97
 
e55642f
 
feaef97
13ba238
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import streamlit as st
import pandas as pd
import json
from note import SUBMISSION
from st_aggrid import JsCode
from st_aggrid import AgGrid, GridOptionsBuilder
def load_data(path="data.jsonl"):
    """Load leaderboard results from a JSON Lines file into a DataFrame.

    Each non-blank line of *path* is parsed with ``json.loads`` and the
    resulting records are passed to ``pandas.DataFrame`` (so, for JSON
    objects, their keys become the columns).

    Args:
        path: JSON Lines file to read. Defaults to ``"data.jsonl"``, resolved
            relative to the current working directory.

    Returns:
        pandas.DataFrame with one row per non-blank line.

    Raises:
        FileNotFoundError: if *path* does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as file:
        # Skip whitespace-only lines (e.g. an extra newline at EOF) instead
        # of letting json.loads fail on them.
        records = [json.loads(line) for line in file if line.strip()]
    return pd.DataFrame(records)

def case_insensitive_search(data, query, column):
    """Return the rows of *data* whose *column* contains *query*, ignoring case.

    The query is matched as a literal substring (not a regular expression),
    so characters such as ``(`` or ``.`` in user input are taken verbatim.
    Surrounding whitespace in the query is ignored. Missing values never
    match, and non-string columns (e.g. numeric chunk sizes) are compared
    via their string form.

    Args:
        data: DataFrame to filter.
        query: Search text; ``None``/empty/whitespace-only means "no filter".
        column: Name of the column to search in.

    Returns:
        The filtered DataFrame, or *data* itself when the query is empty.
    """
    needle = query.strip().lower() if query else ""
    if not needle:
        return data
    # fillna("") keeps missing values out of the mask (NaN would break boolean
    # indexing); astype(str) makes the .str accessor work on numeric columns.
    haystack = data[column].fillna("").astype(str).str.lower()
    mask = haystack.str.contains(needle, regex=False)
    return data[mask]

def colored_note(text, background_color='#fcfced', text_color='black'):
    st.markdown(f"""
        <div style='background-color: {background_color}; color: {text_color}; 
                    border-radius: 8px; padding: 10px; margin: 8px 0; box-shadow: 2px 2px 5px grey;'>
            {text}
            """, unsafe_allow_html=True)

# AG Grid cell renderer (JavaScript, runs in the browser). If a cell's value is
# a string containing an <a href="..."> snippet, it is rendered as a clickable
# link that opens in a new tab; any other value is rendered as plain text.
html_render = JsCode(
    """
    class UrlCellRenderer {
      init(params) {
        const value = params.value;
        this.eGui = document.createElement('span');
        // Only strings can carry a link snippet; calling .includes on numbers
        // or other non-string values would throw.
        if (typeof value === 'string' && value.includes('href=\\"')) {
          const parsedHtml = new DOMParser().parseFromString(value, 'text/html');
          // Require an actual href so we never emit href="null".
          const link = parsedHtml.querySelector('a[href]');
          if (link) {
            this.eGui = document.createElement('a');
            this.eGui.setAttribute('href', link.getAttribute('href'));
            this.eGui.innerText = link.textContent;
            this.eGui.setAttribute('style', "text-decoration:none");
            this.eGui.setAttribute('target', "_blank");
            // Prevent the opened page from accessing window.opener.
            this.eGui.setAttribute('rel', "noopener noreferrer");
            return;
          }
        }
        // Show missing values as an empty cell rather than "undefined".
        this.eGui.innerText = value == null ? '' : String(value);
      }
      getGui() {
        return this.eGui;
      }
    }
    """
)
def display_table(data, rows_per_page=12):
    """Render *data* as an AG Grid table next to the submission note.

    Args:
        data: DataFrame of results; its columns become the grid's columns.
        rows_per_page: Number of rows to size the grid for. It only sets the
            grid height (``40 + rows_per_page * 38`` px, capped at 800 px);
            it does not enable pagination.
    """
    st.markdown("""
        <style>
        .centered {
            display: flex;
            justify-content: center;
        }
        .css-1l02zno {
            flex: 1;
        }
        </style>
        """, unsafe_allow_html=True)
    container = st.container()

    builder = GridOptionsBuilder.from_dataframe(data)
    # Link-aware renderer for the name columns. AG Grid's column property for
    # enabling the column filter is `filter`; configure_columns passes keys
    # through verbatim, so a `filterable` key would be ignored by the grid.
    builder.configure_columns(
        ['Chat Model', 'Embedding Model', 'Reranker Model', 'Framework'],
        cellRenderer=html_render,
        sortable=True,
        filter=True,
        resizable=True,
    )
    # Initial sort: Accuracy, highest first.
    builder.configure_column("Accuracy", sort='desc')
    grid_options = builder.build()

    with container:
        height = min(40 + rows_per_page * 38, 800)
        table_col, note_col = st.columns([5, 3])
        with table_col:
            # NOTE(review): st_aggrid is a Streamlit component and likely renders
            # in an iframe, so this page-level CSS may not reach the grid — verify.
            st.markdown("""
            <style>
                .ag-theme-balham {
                    height: 500px; 
                    width: 50%;
                    margin: auto; 
                }
            </style>
            """, unsafe_allow_html=True)
            # allow_unsafe_jscode is required for the JsCode cell renderer.
            AgGrid(data, height=height, gridOptions=grid_options, allow_unsafe_jscode=True)
        with note_col:
            colored_note(SUBMISSION)

def main():
    """Build the Multihop-RAG results page.

    Loads the results via ``load_data()``, renders four free-text filters
    (chat model, embedding model, chunk size, framework) and displays the
    matching rows with ``display_table``.
    """
    st.set_page_config(layout="wide")
    # NOTE(review): this emoji was mis-encoded (mojibake) in the source and has
    # been decoded byte-wise (U+1F491) — confirm it is the intended glyph.
    st.title("Multihop-RAG 💑")
    st.write("Displaying results across different frameworks, embedding models, chat models, and chunks.")

    data = load_data()

    # Stretch the search button across its column and push it down so it
    # lines up with the labelled text inputs next to it.
    st.markdown("""
    <style>
    div.stButton > button:first-child {
        height: 2em;     
        width: 100%;    
        margin-top: 1.8em;
    }
    </style>
    """, unsafe_allow_html=True)
    col1, col2, col3, col4, col5 = st.columns(5)

    with col1:
        chat_model_query = st.text_input("Chat Model", key="chat_model_query")
    with col2:
        embedding_model_query = st.text_input("Embedding Model", key="embedding_model_query")
    with col3:
        chunk_query = st.text_input("Chunk", key="chunk_query")
    with col4:
        frame_query = st.text_input("Framework", key="frame_query")
    with col5:
        # st.button is True only on the single rerun triggered by the click, so
        # gating the filters on it would drop them on any later rerun. The
        # click just triggers a rerun; the filters below always use the
        # inputs' current values.
        st.button("🔍 Search")

    # (query, column) pairs; an empty query leaves the data unchanged.
    filters = (
        (chat_model_query, 'Chat Model'),
        (embedding_model_query, 'Embedding Model'),
        (chunk_query, 'Chunk Size'),
        (frame_query, 'Framework'),
    )
    for query, column in filters:
        data = case_insensitive_search(data, query, column)

    st.info("Retrieval Stage: MRR@10 and Hit@10; Response Stage: Accuracy ")

    display_table(data)
    st.markdown("---")
    st.caption("For citation, please use: 'Tang, Yixuan, and Yi Yang. MultiHop-RAG: Benchmarking Retrieval-Augmented Generation for Multi-Hop Queries. ArXiv, 2024, https://arxiv.org/abs/2401.15391.'")
    # st.markdown("---")
    # st.caption("For results self-reporting, please send an email to ytangch@connect.ust.hk")


if __name__ == "__main__":
    main()