antoinelouis committed on
Commit
bd5481a
1 Parent(s): 56d2916

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +249 -0
app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from huggingface_hub import HfApi
4
+
5
+
6
# Dataset names (as they appear in the models' evaluation results) that the
# leaderboard reports on; results for any other dataset are ignored.
DATASETS = [
    "mMARCO-fr",
    "BSARD",
]
# Hub model ids, grouped by retrieval architecture. Each group is tagged with
# its own type code when the results are loaded.
DENSE_SINGLE_BIENCODERS = [
    "antoinelouis/biencoder-camembert-base-mmarcoFR",
    "antoinelouis/biencoder-distilcamembert-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L12-mmarcoFR",
    "antoinelouis/biencoder-camemberta-base-mmarcoFR",
    "antoinelouis/biencoder-electra-base-french-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L6-mmarcoFR",
    "antoinelouis/biencoder-camembert-L10-mmarcoFR",
    "antoinelouis/biencoder-camembert-L8-mmarcoFR",
    "antoinelouis/biencoder-camembert-L6-mmarcoFR",
    "antoinelouis/biencoder-camembert-L4-mmarcoFR",
    "antoinelouis/biencoder-camembert-L2-mmarcoFR",
]
DENSE_MULTI_BIENCODERS = [
    "antoinelouis/colbertv1-camembert-base-mmarcoFR",
    "antoinelouis/colbertv2-camembert-L4-mmarcoFR",
    "antoinelouis/colbert-xm",
]
# No models registered yet for the remaining groups.
SPARSE_SINGLE_BIENCODERS = []
CROSS_ENCODERS = []
LLMS = []
# Leaderboard columns, in display order, mapped to the datatype string passed
# to the Gradio dataframe. Metric columns are filled from evaluation results
# whose metric name equals the column name.
COLUMNS = {
    "Model": "html",
    "#Params (M)": "number",
    "Type": "str",
    "Dataset": "str",
    "Recall@1000": "number",
    "Recall@500": "number",
    "Recall@100": "number",
    "Recall@10": "number",
    "MRR@10": "number",
    "nDCG@10": "number",
    "MAP@10": "number",
}
44
+
45
+
46
def get_model_info(model_id: str, model_type: str) -> pd.DataFrame:
    """Build the leaderboard rows of one Hub model from its model-card results.

    One row is produced per dataset listed in ``DATASETS`` for which the model
    card reports at least one evaluation result. Metric columns that the card
    does not report stay ``None``.

    Args:
        model_id: Repository id of the model on the Hugging Face Hub.
        model_type: Short architecture code stored in the "Type" column.

    Returns:
        A dataframe with exactly the ``COLUMNS`` columns (in that order) and
        zero or more rows. It is empty, but still has all columns, when the
        model card carries no usable evaluation results.
    """
    model_info = HfApi().model_info(model_id)

    # Both the card metadata and its evaluation results may be missing
    # (None); treat either case as "no results" instead of crashing.
    card_data = model_info.card_data
    eval_results = (card_data.eval_results if card_data is not None else None) or []

    link = f'<a href="https://huggingface.co/{model_id}" target="_blank" style="color: blue; text-decoration: none;">{model_id}</a>'
    n_params = round(model_info.safetensors.total/1e6) if model_info.safetensors else None
    descriptive = {"Model": link, "#Params (M)": n_params, "Type": model_type}

    rows = {}
    for result in eval_results:
        dataset = result.dataset_name
        if dataset not in DATASETS:
            continue
        if dataset not in rows:
            rows[dataset] = {key: None for key in COLUMNS}
            rows[dataset].update(descriptive)
            rows[dataset]["Dataset"] = dataset
        # Only metric columns may be filled from the card, so a metric that
        # happens to be named like a descriptive column cannot overwrite it.
        metric = result.metric_name
        if metric in COLUMNS and metric not in descriptive and metric != "Dataset":
            rows[dataset][metric] = result.metric_value

    # Passing the columns explicitly keeps the schema stable even with no rows.
    return pd.DataFrame(list(rows.values()), columns=list(COLUMNS))
62
+
63
def load_all_results() -> pd.DataFrame:
    """Collect the leaderboard rows of every registered model.

    Each model list is tagged with its architecture code (stored in the "Type"
    column) and queried through ``get_model_info``.

    Returns:
        A single dataframe with a fresh 0..n-1 index. When no model yields any
        row, an empty dataframe that still has the ``COLUMNS`` columns is
        returned so downstream column lookups keep working.
    """
    groups = (
        ("DSVBE", DENSE_SINGLE_BIENCODERS),
        ("DMVBE", DENSE_MULTI_BIENCODERS),
        ("SSVBE", SPARSE_SINGLE_BIENCODERS),
        ("CE", CROSS_ENCODERS),
        ("LLM", LLMS),
    )

    frames = []
    for model_type, model_ids in groups:
        for model_id in model_ids:
            frame = get_model_info(model_id, model_type=model_type)
            # Skip models without results: concatenating empty frames is
            # deprecated in recent pandas and may alter column dtypes.
            if not frame.empty:
                frames.append(frame)

    if not frames:
        return pd.DataFrame(columns=list(COLUMNS))

    # One concat at the end instead of one per model; ignore_index avoids the
    # duplicated row labels that per-model frames would otherwise carry over.
    return pd.concat(frames, ignore_index=True)
76
+
77
def filter_dataf_by_dataset(dataf: pd.DataFrame, dataset_name: str, sort_by: str) -> pd.DataFrame:
    """Return the rows of one dataset, best score first.

    Args:
        dataf: Results table containing a "Dataset" column.
        dataset_name: Value of the "Dataset" column to keep.
        sort_by: Column used to order the rows in descending order.

    Returns:
        A new dataframe without the "Dataset" column. The input is not
        modified. A completely empty input (no rows and no "Dataset" column,
        i.e. nothing was loaded) yields an empty copy instead of a KeyError.
    """
    if dataf.empty and "Dataset" not in dataf.columns:
        return dataf.copy()

    return (
        dataf
        .loc[dataf["Dataset"] == dataset_name]
        .drop(columns=["Dataset"])
        .sort_values(by=sort_by, ascending=False)
    )
83
+
84
+
85
def update_table(dataf: pd.DataFrame, query: str, selected_types: list, selected_sizes: list) -> pd.DataFrame:
    """Apply the search box and the checkbox filters to a leaderboard table.

    The three filters are combined with AND: a row must match one of the
    selected types (if any are selected), fall in one of the selected size
    ranges (if any are selected), and contain the query in its model name (if
    a query is given). Within a checkbox group the choices are combined with
    OR. Unknown checkbox labels are ignored.

    Args:
        dataf: Table with "Model", "Type" and "#Params (M)" columns.
        query: Text typed in the search box; matched literally and
            case-insensitively against the visible model name.
        selected_types: Labels ticked in the "Model type" checkbox group.
        selected_sizes: Labels ticked in the "Model size" checkbox group.

    Returns:
        A filtered copy of ``dataf``; the input is left untouched.
    """
    type_codes = {
        'Dense single-vector bi-encoder (DSVBE)': 'DSVBE',
        'Dense multi-vector bi-encoder (DMVBE)': 'DMVBE',
        'Sparse single-vector bi-encoder (SSVBE)': 'SSVBE',
        'Cross-encoder (CE)': 'CE',
        'LLM': 'LLM',
    }
    # Bounds are kept as originally defined (300 belongs to both "Base" and
    # "Large"); each predicate receives the parameter count in millions.
    size_ranges = {
        'Small (< 100M)': lambda p: p < 100,
        'Base (100M-300M)': lambda p: (p >= 100) & (p <= 300),
        'Large (300M-500M)': lambda p: (p >= 300) & (p <= 500),
        'Extra-large (500M+)': lambda p: p > 500,
    }

    filtered_df = dataf.copy()

    wanted_types = [type_codes[val] for val in (selected_types or []) if val in type_codes]
    if wanted_types:
        filtered_df = filtered_df[filtered_df['Type'].isin(wanted_types)]

    wanted_sizes = [size_ranges[val] for val in (selected_sizes or []) if val in size_ranges]
    if wanted_sizes:
        # Unknown sizes (None) become NaN, which never satisfies a comparison
        # and, unlike None in an object column, does not raise either.
        n_params = pd.to_numeric(filtered_df['#Params (M)'], errors='coerce')
        size_mask = wanted_sizes[0](n_params)
        for in_range in wanted_sizes[1:]:
            size_mask = size_mask | in_range(n_params)
        filtered_df = filtered_df[size_mask]

    if query:
        # The "Model" cell is an HTML link: strip the tags so that words such
        # as "href" or "blue" from the markup do not match every row.
        names = filtered_df['Model'].fillna('').astype(str).str.replace(r'<[^>]*>', '', regex=True)
        # regex=False: user input is plain text, so "(" or "+" must not be
        # interpreted as (possibly invalid) regular-expression syntax.
        filtered_df = filtered_df[names.str.contains(query, case=False, regex=False)]

    return filtered_df
118
+
119
+
120
# Leaderboard UI: header, search/filter widgets, one tab per dataset, and a
# citation box. The results are fetched once, while the page is being built.
with gr.Blocks() as demo:
    # Header. The subtitle sits in its own element below the title (the
    # original markup had a literal "\n" escape and an unmatched </h1>).
    gr.HTML("""
        <div style="display: flex; flex-direction: column; align-items: center;">
            <div style="align-self: flex-start;">
                <a href="mailto:antoiloui@gmail.com" target="_blank" style="color: blue; text-decoration: none;">Contact/Submissions</a>
            </div>
            <h1 style="margin: 0;">🥇 DécouvrIR</h1>
            <div>A Benchmark for Evaluating the Robustness of Information Retrieval Models in French</div>
        </div>
    """)

    # Create the Pandas dataframes (one per dataset)
    all_df = load_all_results()
    mmarco_df = filter_dataf_by_dataset(all_df, dataset_name="mMARCO-fr", sort_by="Recall@500")
    bsard_df = filter_dataf_by_dataset(all_df, dataset_name="BSARD", sort_by="Recall@500")

    # Search and filter widgets
    with gr.Column():
        with gr.Row():
            search_bar = gr.Textbox(placeholder=" 🔍 Search for a model...", show_label=False, elem_id="search-bar")

        with gr.Row():
            filter_type = gr.CheckboxGroup(
                label="Model type",
                choices=[
                    'Dense single-vector bi-encoder (DSVBE)',
                    'Dense multi-vector bi-encoder (DMVBE)',
                    'Sparse single-vector bi-encoder (SSVBE)',
                    'Cross-encoder (CE)',
                    'LLM',
                ],
                value=[],
                interactive=True,
                elem_id="filter-type",
            )

        with gr.Row():
            filter_size = gr.CheckboxGroup(
                label="Model size",
                choices=['Small (< 100M)', 'Base (100M-300M)', 'Large (300M-500M)', 'Extra-large (500M+)'],
                value=[],
                interactive=True,
                elem_id="filter-size",
            )

    # Leaderboard tables
    with gr.Tabs():
        with gr.TabItem("🌐 mMARCO-fr"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/unicamp-dl/mmarco" target="_blank" style="color: blue; text-decoration: none;">mMARCO</a> dataset is a machine-translated version of
                the widely popular MS MARCO dataset across 13 languages (including French) for studying <strong> domain-general</strong> passage retrieval.</p>
                <p>The evaluation is performed on <strong>6,980 dev questions</strong> labeled with relevant passages to be retrieved from a corpus of <strong>8,841,823 candidates</strong>.</p>
            """)
            mmarco_table = gr.Dataframe(
                value=mmarco_df,
                datatype=[COLUMNS[col] for col in mmarco_df.columns],
                interactive=False,
                elem_classes="text-sm",
            )

        with gr.TabItem("⚖️ BSARD"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/maastrichtlawtech/bsard" target="_blank" style="color: blue; text-decoration: none;">Belgian Statutory Article Retrieval Dataset (BSARD)</a> is a
                French native dataset for studying <strong>legal</strong> document retrieval.</p>
                <p>The evaluation is performed on <strong>222 test questions</strong> labeled by experienced jurists with relevant Belgian law articles to be retrieved from a corpus of <strong>22,633 candidates</strong>.</p>
                <i>[Coming soon...]</i>
            """)
            # TODO: enable once BSARD results are published, and wire the
            # table to the widgets exactly like the mMARCO table below.
            # bsard_table = gr.Dataframe(
            #     value=bsard_df,
            #     datatype=[COLUMNS[col] for col in bsard_df.columns],
            #     interactive=False,
            #     elem_classes="text-sm",
            # )

    def _refresh_mmarco(query, selected_types, selected_sizes):
        """Recompute the mMARCO table from the current state of all widgets."""
        return update_table(mmarco_df, query, selected_types, selected_sizes)

    # Every widget triggers the same refresh and passes the live values of all
    # three controls through `inputs`. Reading `component.value` inside the
    # callback would only give the value set at construction time, so the
    # search and the two filters would never be combined.
    for widget in (search_bar, filter_type, filter_size):
        widget.change(
            fn=_refresh_mmarco,
            inputs=[search_bar, filter_type, filter_size],
            outputs=mmarco_table,
        )

    # Citation
    with gr.Column():
        with gr.Row():
            gr.HTML("""
                <h2>Citation</h2>
                <p>For attribution in academic contexts, please cite this benchmark and any of the models released by <a href="https://huggingface.co/antoinelouis" target="_blank" style="color: blue; text-decoration: none;">@antoinelouis</a> as follows:</p>
            """)
        with gr.Row():
            # Field values are brace-delimited: single quotes are not valid
            # BibTeX delimiters and make the entry fail to parse when pasted.
            citation_block = (
                "@online{louis2024decouvrir,\n"
                "\tauthor = {Antoine Louis},\n"
                "\ttitle = {DécouvrIR: A Benchmark for Evaluating the Robustness of Information Retrieval Models in French},\n"
                "\tpublisher = {Hugging Face},\n"
                "\tmonth = mar,\n"
                "\tyear = {2024},\n"
                "\turl = {https://huggingface.co/spaces/antoinelouis/decouvrir},\n"
                "}\n"
            )
            gr.Code(citation_block, language=None, show_label=False)

demo.launch()