SeaEval_Leaderboard / app /draw_diagram.py
zhuohan-7's picture
Upload folder using huggingface_hub
863c2e6 verified
raw
history blame
23.7 kB
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
# from streamlit_echarts import JsCode
from streamlit_javascript import st_javascript
# from PIL import Image
# links_dic = {"random": "https://seaeval.github.io/",
# "meta_llama_3_8b": "https://huggingface.co/meta-llama/Meta-Llama-3-8B",
# "mistral_7b_instruct_v0_2": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
# "sailor_0_5b": "https://huggingface.co/sail/Sailor-0.5B",
# "sailor_1_8b": "https://huggingface.co/sail/Sailor-1.8B",
# "sailor_4b": "https://huggingface.co/sail/Sailor-4B",
# "sailor_7b": "https://huggingface.co/sail/Sailor-7B",
# "sailor_0_5b_chat": "https://huggingface.co/sail/Sailor-0.5B-Chat",
# "sailor_1_8b_chat": "https://huggingface.co/sail/Sailor-1.8B-Chat",
# "sailor_4b_chat": "https://huggingface.co/sail/Sailor-4B-Chat",
# "sailor_7b_chat": "https://huggingface.co/sail/Sailor-7B-Chat",
# "sea_mistral_highest_acc_inst_7b": "https://seaeval.github.io/",
# "meta_llama_3_8b_instruct": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
# "flan_t5_base": "https://huggingface.co/google/flan-t5-base",
# "flan_t5_large": "https://huggingface.co/google/flan-t5-large",
# "flan_t5_xl": "https://huggingface.co/google/flan-t5-xl",
# "flan_t5_xxl": "https://huggingface.co/google/flan-t5-xxl",
# "flan_ul2": "https://huggingface.co/google/flan-t5-ul2",
# "flan_t5_small": "https://huggingface.co/google/flan-t5-small",
# "mt0_xxl": "https://huggingface.co/bigscience/mt0-xxl",
# "seallm_7b_v2": "https://huggingface.co/SeaLLMs/SeaLLM-7B-v2",
# "gpt_35_turbo_1106": "https://openai.com/blog/chatgpt",
# "meta_llama_3_70b": "https://huggingface.co/meta-llama/Meta-Llama-3-70B",
# "meta_llama_3_70b_instruct": "https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct",
# "sea_lion_3b": "https://huggingface.co/aisingapore/sea-lion-3b",
# "sea_lion_7b": "https://huggingface.co/aisingapore/sea-lion-7b",
# "qwen1_5_110b": "https://huggingface.co/Qwen/Qwen1.5-110B",
# "qwen1_5_110b_chat": "https://huggingface.co/Qwen/Qwen1.5-110B-Chat",
# "llama_2_7b_chat": "https://huggingface.co/meta-llama/Llama-2-7b-chat-hf",
# "gpt4_1106_preview": "https://openai.com/blog/chatgpt",
# "gemma_2b": "https://huggingface.co/google/gemma-2b",
# "gemma_7b": "https://huggingface.co/google/gemma-7b",
# "gemma_2b_it": "https://huggingface.co/google/gemma-2b-it",
# "gemma_7b_it": "https://huggingface.co/google/gemma-7b-it",
# "qwen_1_5_7b": "https://huggingface.co/Qwen/Qwen1.5-7B",
# "qwen_1_5_7b_chat": "https://huggingface.co/Qwen/Qwen1.5-7B-Chat",
# "sea_lion_7b_instruct": "https://huggingface.co/aisingapore/sea-lion-7b-instruct",
# "sea_lion_7b_instruct_research": "https://huggingface.co/aisingapore/sea-lion-7b-instruct-research",
# "LLaMA_3_Merlion_8B": "https://seaeval.github.io/",
# "LLaMA_3_Merlion_8B_v1_1": "https://seaeval.github.io/"}
# links_dic = {k.lower().replace('_', '-') : v for k, v in links_dic.items()}
# # huggingface_image = Image.open('style/huggingface.jpg')
# def nav_to(value):
# try:
# url = links_dic[str(value).lower()]
# js = f'window.open("{url}", "_blank").then(r => window.parent.location.href);'
# st_javascript(js)
# except:
# pass
# # nav_script = """
# # <meta http-equiv="refresh" content="0; url='%s'">
# # """ % (url)
# # st.write(nav_script, unsafe_allow_html=True)
# def highlight_table_line(model_name):
# st.write(model_name)
def draw_cross_lingual(category_one, category_two, sort, sorted):
    """Draw the cross-lingual bar chart and results table on the Streamlit page.

    The CSV at ``./results/cross_lingual//{category_one}/{category_two}.csv``
    is loaded, columns containing any missing value are dropped, values are
    rounded to 3 decimals and rows are ordered by ``sort``.

    Args:
        category_one: Sub-folder of the cross-lingual results folder.
        category_two: CSV base name. A chart is drawn for ``'cross_mmlu'``,
            ``'cross_logiqa'`` and ``'cross_xquad'``; for any other value only
            the table is rendered.
        sort: Name of the column used to order the rows.
        sorted: ``'Ascending'`` sorts ascending; any other value sorts
            descending. (The name shadows the builtin; it is kept because it
            is part of the existing call interface.)
    """
    folder = "./results/cross_lingual/"
    data_path = f'{folder}/{category_one}/{category_two}.csv'
    chart_data = pd.read_csv(data_path).dropna(axis='columns').round(3)
    chart_data = chart_data.sort_values(by=[sort], ascending=(sorted == 'Ascending'))

    # The axis range is taken from every column after the first one, over the
    # full data, so the scale stays stable while models are toggled on and off.
    min_value = round(chart_data.iloc[:, 1::].min().min() - 0.1, 1)
    max_value = round(chart_data.iloc[:, 1::].max().max() + 0.1, 1)

    multiselect_css = (
        "\n<style>\n"
        ".stMultiSelect [data-baseweb=select] span{\n"
        "max-width: 800px;\n"
        "font-size: 0.9rem;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(multiselect_css, unsafe_allow_html=True)

    all_models = chart_data['Model'].tolist()
    models = st.multiselect("Please choose the models", all_models, default=all_models)

    # Keep only the selected rows, in sorted order, so that the x-axis labels
    # and every series are taken from the same rows and always line up.
    shown = chart_data[chart_data['Model'].isin(models)]

    chart_titles = {
        'cross_mmlu': 'Cross-MMLU',
        'cross_logiqa': 'Cross-LogiQA',
        'cross_xquad': 'Cross-XQUAD',
    }

    if category_two in chart_titles:
        # (series name shown in the chart, CSV column holding its values)
        series_columns = [
            ("Overall Accuracy", "Accuracy"),
            ("Cross-Lingual Consistency", "Cross-Lingual Consistency"),
            ("AC3", "AC3"),
            ("English", "English"),
            ("Chinese", "Chinese"),
            ("Spanish", "Spanish"),
            ("Vietnamese", "Vietnamese"),
        ]
        if category_two != 'cross_xquad':
            # Cross-XQUAD is plotted without these three languages.
            series_columns += [
                ("Indonesian", "Indonesian"),
                ("Malay", "Malay"),
                ("Filipino", "Filipino"),
            ]

        options = {
            "title": {"text": chart_titles[category_two]},
            "tooltip": {
                "trigger": "axis",
                "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
                "triggerOn": 'mousemove',
            },
            # Legend entries mirror the plotted series exactly.
            "legend": {"data": [name for name, _ in series_columns]},
            "toolbox": {"feature": {"saveAsImage": {}}},
            "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
            "xAxis": [
                {
                    "type": "category",
                    "boundaryGap": True,
                    "triggerEvent": True,
                    "data": shown['Model'].tolist(),
                }
            ],
            "yAxis": [
                {
                    "type": "value",
                    "min": min_value,
                    "max": max_value,
                    "boundaryGap": True,
                }
            ],
            "series": [
                {"name": name, "type": "bar", "data": shown[column].tolist()}
                for name, column in series_columns
            ],
        }
        st_echarts(options=options, height="500px")

    # Table with the full (unfiltered) sorted results below the chart.
    st.divider()
    st.dataframe(chart_data, hide_index=True, use_container_width=True)
def draw_only_acc(folder_name, category_one, category_two, sorted):
    """Draw an accuracy-only bar chart and results table on the Streamlit page.

    Used for the result groups that report a single ``Accuracy`` column:
    cultural reasoning, general reasoning, emotion and fundamental NLP tasks.

    Args:
        folder_name: Result group; one of ``'cultural_reasoning'``,
            ``'general_reasoning'``, ``'emotion'`` or
            ``'fundamental_nlp_tasks'``. Any other value leaves the dataset
            lookup table empty, so the lookup below raises ``KeyError``.
        category_one: Sub-folder of ``./results/{folder_name}/``.
        category_two: Display name of the dataset (for example ``'MMLU'``);
            it is mapped to the CSV base name and used as the chart title.
        sorted: ``'Ascending'`` sorts by accuracy ascending; any other value
            sorts descending. (The name shadows the builtin; it is kept
            because it is part of the existing call interface.)

    Raises:
        KeyError: If ``category_two`` is not a known display name for
            ``folder_name``.
    """
    folder = f"./results/{folder_name}/"

    # Display name -> CSV base name, per result group.
    dataset_files = {
        'cultural_reasoning': {
            'SG EVAL': 'sg_eval',
            'SG EVAL V1 Cleaned': 'sg_eval_v1_cleaned',
            'SG EVAL V2 MCQ': 'sg_eval_v2_mcq',
            'SG EVAL V2 Open Ended': 'sg_eval_v2_open',
            'US EVAL': 'us_eval',
            'CN EVAL': 'cn_eval',
            'PH EVAL': 'ph_eval',
        },
        'general_reasoning': {
            'MMLU': 'mmlu',
            'C Eval': 'c_eval',
            'CMMLU': 'cmmlu',
            'ZBench': 'zbench',
            'IndoMMLU': 'indommlu',
        },
        'emotion': {
            'Indonesian Emotion Classification': 'ind_emotion',
            'SST2': 'sst2',
        },
        'fundamental_nlp_tasks': {
            'OCNLI': 'ocnli',
            'C3': 'c3',
            'COLA': 'cola',
            'QQP': 'qqp',
            'MNLI': 'mnli',
            'QNLI': 'qnli',
            'WNLI': 'wnli',
            'RTE': 'rte',
            'MRPC': 'mrpc',
        },
    }
    category_two_dict = dataset_files.get(folder_name, {})

    subtitle = category_two_dict[category_two]
    data_path = f'{folder}/{category_one}/{subtitle}.csv'
    chart_data = pd.read_csv(data_path).round(3)

    multiselect_css = (
        "\n<style>\n"
        ".stMultiSelect [data-baseweb=select] span{\n"
        "max-width: 800px;\n"
        "font-size: 0.9rem;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(multiselect_css, unsafe_allow_html=True)

    # The widget is built from the CSV order (before sorting) so its options
    # do not change when the sort direction changes.
    all_models = chart_data['Model'].tolist()
    models = st.multiselect("Please choose the models", all_models, default=all_models)

    chart_data = chart_data.sort_values(by=['Accuracy'], ascending=(sorted == 'Ascending'))

    # The axis range is taken from every column after the first one, over the
    # full data, so the scale stays stable while models are toggled on and off.
    min_value = round(chart_data.iloc[:, 1::].min().min() - 0.1, 1)
    max_value = round(chart_data.iloc[:, 1::].max().max() + 0.1, 1)

    # Take labels and values from the same sorted, filtered rows. Using the
    # raw widget selection as labels would pair sorted bars with labels in
    # CSV/selection order and ignore deselected models.
    shown = chart_data[chart_data['Model'].isin(models)]

    options = {
        "title": {"text": f"{category_two}"},
        "tooltip": {
            "trigger": "axis",
            "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
            "triggerOn": 'mousemove',
        },
        "legend": {"data": ['Overall Accuracy']},
        "toolbox": {"feature": {"saveAsImage": {}}},
        "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
        "xAxis": [
            {
                "type": "category",
                "boundaryGap": True,
                "triggerEvent": True,
                "data": shown['Model'].tolist(),
            }
        ],
        "yAxis": [
            {
                "type": "value",
                "min": min_value,
                "max": max_value,
                "boundaryGap": True,
            }
        ],
        "series": [
            {
                "name": "Overall Accuracy",
                "type": "bar",
                "data": shown['Accuracy'].tolist(),
            },
        ],
    }
    st_echarts(options=options, height="500px")

    # Table with the full (unfiltered) sorted results below the chart.
    st.divider()
    st.dataframe(chart_data, hide_index=True, use_container_width=True)
def draw_flores_translation(category_one, category_two, sorted):
    """Draw the FLORES translation BLEU chart and results table in Streamlit.

    Args:
        category_one: Sub-folder of ``./results/flores_translation/``.
        category_two: Display name of the translation direction (for example
            ``'Indonesian to English'``); it is mapped to the CSV base name
            and used as the chart title.
        sorted: ``'Ascending'`` sorts by BLEU ascending; any other value
            sorts descending. (The name shadows the builtin; it is kept
            because it is part of the existing call interface.)

    Raises:
        KeyError: If ``category_two`` is not a known translation direction.
    """
    folder = "./results/flores_translation/"
    category_two_dict = {
        'Indonesian to English': 'ind2eng',
        # The misspelled key is what existing callers pass; the correctly
        # spelled alias is accepted as well.
        'Vitenamese to English': 'vie2eng',
        'Vietnamese to English': 'vie2eng',
        'Chinese to English': 'zho2eng',
        'Malay to English': 'zsm2eng',
    }
    subtitle = category_two_dict[category_two]
    data_path = f'{folder}/{category_one}/{subtitle}.csv'
    chart_data = pd.read_csv(data_path).round(3)
    chart_data = chart_data.sort_values(by=['BLEU'], ascending=(sorted == 'Ascending'))

    # The axis range is taken from every column after the first one, over the
    # full data, so the scale stays stable while models are toggled on and off.
    min_value = round(chart_data.iloc[:, 1::].min().min() - 0.1, 1)
    max_value = round(chart_data.iloc[:, 1::].max().max() + 0.1, 1)

    multiselect_css = (
        "\n<style>\n"
        ".stMultiSelect [data-baseweb=select] span{\n"
        "max-width: 800px;\n"
        "font-size: 0.9rem;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(multiselect_css, unsafe_allow_html=True)

    all_models = chart_data['Model'].tolist()
    models = st.multiselect("Please choose the models", all_models, default=all_models)

    # Keep only the selected rows, in sorted order, so that the x-axis labels
    # and the BLEU bars are taken from the same rows and always line up.
    shown = chart_data[chart_data['Model'].isin(models)]

    options = {
        "title": {"text": f"{category_two}"},
        "tooltip": {
            "trigger": "axis",
            "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
            "triggerOn": 'mousemove',
        },
        "legend": {"data": ['BLEU']},
        "toolbox": {"feature": {"saveAsImage": {}}},
        "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
        "xAxis": [
            {
                "type": "category",
                "boundaryGap": True,
                "triggerEvent": True,
                "data": shown['Model'].tolist(),
            }
        ],
        "yAxis": [
            {
                "type": "value",
                "min": min_value,
                "max": max_value,
                "boundaryGap": True,
            }
        ],
        "series": [
            {
                "name": "BLEU",
                "type": "bar",
                "data": shown['BLEU'].tolist(),
            },
        ],
    }
    st_echarts(options=options, height="500px")

    # Table with the full (unfiltered) sorted results below the chart.
    st.divider()
    st.dataframe(chart_data, hide_index=True, use_container_width=True)
def draw_dialogue(category_one, category_two, sort, sorted):
    """Draw the dialogue benchmark chart and results table in Streamlit.

    ``'SAMSum'`` and ``'DialogSum'`` plot the ``Average``, ``ROUGE-1``,
    ``ROUGE-2`` and ``ROUGE-L`` columns; ``'DREAM'`` plots ``Accuracy``.

    Args:
        category_one: Sub-folder of ``./results/dialogue``.
        category_two: Display name of the dataset: ``'DREAM'``, ``'SAMSum'``
            or ``'DialogSum'``. It is mapped to the CSV base name and used as
            the chart title.
        sort: Name of the column used to order the rows.
        sorted: ``'Ascending'`` sorts ascending; any other value sorts
            descending. (The name shadows the builtin; it is kept because it
            is part of the existing call interface.)

    Raises:
        KeyError: If ``category_two`` is not one of the known datasets.
    """
    folder = "./results/dialogue"
    category_two_dict = {
        'DREAM': 'dream',
        'SAMSum': 'samsum',
        'DialogSum': 'dialogsum',
    }
    subtitle = category_two_dict[category_two]
    data_path = f'{folder}/{category_one}/{subtitle}.csv'
    chart_data = pd.read_csv(data_path).round(3)

    multiselect_css = (
        "\n<style>\n"
        ".stMultiSelect [data-baseweb=select] span{\n"
        "max-width: 800px;\n"
        "font-size: 0.9rem;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(multiselect_css, unsafe_allow_html=True)

    # The widget is built from the CSV order (before sorting) so its options
    # do not change when the sort column or direction changes.
    all_models = chart_data['Model'].tolist()
    models = st.multiselect("Please choose the models", all_models, default=all_models)

    chart_data = chart_data.sort_values(by=[sort], ascending=(sorted == 'Ascending'))

    # The axis range is taken from every column after the first one, over the
    # full data, so the scale stays stable while models are toggled on and off.
    min_value = round(chart_data.iloc[:, 1::].min().min() - 0.1, 1)
    max_value = round(chart_data.iloc[:, 1::].max().max() + 0.1, 1)

    # Take labels and values from the same sorted, filtered rows. Using the
    # raw widget selection as labels would pair sorted bars with labels in
    # CSV/selection order and ignore deselected models.
    shown = chart_data[chart_data['Model'].isin(models)]

    # The dictionary lookup above guarantees category_two is one of the three
    # known datasets, so everything that is not DREAM is a summarisation set.
    if category_two == 'DREAM':
        series_columns = ['Accuracy']
    else:
        series_columns = ['Average', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L']

    options = {
        "title": {"text": f"{category_two}"},
        "tooltip": {
            "trigger": "axis",
            "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
            "triggerOn": 'mousemove',
        },
        "legend": {"data": list(chart_data.columns)},
        "toolbox": {"feature": {"saveAsImage": {}}},
        "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
        "xAxis": [
            {
                "type": "category",
                "boundaryGap": True,
                "triggerEvent": True,
                "data": shown['Model'].tolist(),
            }
        ],
        "yAxis": [
            {
                "type": "value",
                "min": min_value,
                "max": max_value,
                "boundaryGap": True,
            }
        ],
        "series": [
            {"name": column, "type": "bar", "data": shown[column].tolist()}
            for column in series_columns
        ],
    }
    st_echarts(options=options, height="500px")

    # Table with the full (unfiltered) sorted results below the chart.
    st.divider()
    st.dataframe(chart_data, hide_index=True, use_container_width=True)