# AIW-responses — app.py
# Gradio demo for browsing collected LLM responses from the
# "Alice in Wonderland" (AIW) reasoning benchmark (LAION-AI/AIW).
import gradio as gr
import pandas as pd
# Maps raw model identifiers (as they appear in the collected
# responses.jsonl) to the short display names used in the UI.
# NOTE(review): several distinct raw IDs map to the same display name
# (e.g. three "Mistral 7b" variants, two "Mixtral 8x22b"/"8x7b" entries),
# so those checkpoints are merged under one label in filters and plots —
# confirm this aggregation is intentional.
MODEL_MAPPINGS = {
    "gpt-4o-2024-05-13": "GPT-4o",
    "gpt-4-0613": "GPT-4",
    "gpt-4-turbo-2024-04-09": "GPT-4 Turbo",
    "gpt-4-0125-preview": "GPT-4 Preview",
    "gpt-3.5": "GPT-3.5",
    "gpt-3.5-turbo-0125": "GPT-3.5 Turbo",
    "claude-3-opus-20240229": "Claude-3 O",
    "claude-3-sonnet-20240229": "Claude-3 S",
    "claude-3-haiku-20240307": "Claude-3 H",
    "claude-3-5-sonnet-20240620": "Claude-3.5 S",
    "llama-2-70b-chat": "Llama-2 70b",
    "llama-2-13b-chat": "Llama-2 13b",
    "llama-2-7b-chat": "Llama-2 7b",
    "llama-3-8b-chat": "Llama-3 8b",
    "llama-3-70b-chat": "Llama-3 70b",
    "codellama-70b-instruct": "Codellama 70b",
    "mistral-large-2402": "Mistral Large",
    "mistral-medium-2312": "Mistral Medium",
    "open-mixtral-8x22b-instruct-v0.1": "Mixtral 8x22b",
    "open-mixtral-8x7b-instruct": "Mixtral 8x7b",
    "open-mistral-7b-instruct": "Mistral 7b",
    "open-mistral-7b": "Mistral 7b",
    "open-mixtral-8x22b": "Mixtral 8x22b",
    "open-mixtral-8x7b": "Mixtral 8x7b",
    "open-mistral-7b-instruct-v0.1": "Mistral 7b",
    "dbrx-instruct": "DBRX",
    "command-r-plus": "Command R Plus",
    "gemma-7b-it": "Gemma 7b",
    "gemma-2b-it": "Gemma 2b",
    "gemini-1.5-pro-latest": "Gemini 1.5",
    "gemini-pro": "Gemini 1.0",
    "qwen1.5-7b-chat": "Qwen 1.5 7b",
    "qwen1.5-14b-chat": "Qwen 1.5 14b",
    "qwen1.5-32b-chat": "Qwen 1.5 32b",
    "qwen1.5-72b-chat": "Qwen 1.5 72b",
    "qwen1.5-0.5b-chat": "Qwen 1.5 0.5b",
    "qwen1.5-1.8b-chat": "Qwen 1.5 1.8b",
    "qwen2-72b-instruct": "Qwen 2 72b",
    "codestral-2405": "Codestral"
}
# Download the collected responses (JSON Lines) straight from the AIW repo
# at import time; the whole app operates on this single DataFrame.
resp_url = 'https://github.com/LAION-AI/AIW/raw/main/collected_responses/responses.jsonl?download='
df = pd.read_json(resp_url, lines=True)

# Replace raw model identifiers with their short display names.
# NOTE(review): models absent from MODEL_MAPPINGS become NaN here.
df['model'] = df['model'].map(MODEL_MAPPINGS)

# Label each prompt as "<prompt text> [<prompt_id>]" so the UI can show a
# single human-readable selector value.
df['prompt'] = df.apply(lambda row: f"{row['prompt']} [{row['prompt_id']}]", axis=1)

model_list = df['model'].unique()


def _prompt_numeric_id(label):
    # Pull the integer id back out of a "<prompt> [<id>]" label.
    return int(label.split('[')[1].split(']')[0])


# Prompt labels sorted by their numeric id rather than lexically.
prompt_id_list = sorted(df['prompt'].unique(), key=_prompt_numeric_id)
def response(num_responses, model, correct, prompt_ids):
    """Return a random sample of collected model responses.

    Filters the module-level ``df`` and samples rows from what remains.

    Parameters
    ----------
    num_responses : int
        Requested sample size; clamped to the number of rows that
        survive the filters so ``DataFrame.sample`` never raises.
    model : list[str] | None
        Display names to keep (falsy keeps all models).
    correct : list[bool] | None
        Correctness values to keep (falsy keeps both).
    prompt_ids : list[str] | None
        Prompt labels ("<prompt> [<id>]") to keep (falsy keeps all).

    Returns
    -------
    pandas.DataFrame
        Sampled rows with columns model, prompt, model_response, correct.
    """
    responses = df
    if model:
        responses = responses[responses['model'].isin(model)]
    if correct:
        responses = responses[responses['correct'].isin(correct)]
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    # Clamp rather than raise when fewer rows remain than were requested.
    num_responses = min(num_responses, len(responses))
    return responses.sample(num_responses)[['model', 'prompt', 'model_response', 'correct']]
def barplot_for_prompt_id(prompt_ids, models):
    """Build a bar plot of mean correctness per (model, prompt_id).

    Parameters
    ----------
    prompt_ids : list[str] | None
        Prompt labels ("<prompt> [<id>]") to keep (falsy keeps all).
    models : list[str] | None
        Display names to keep (falsy keeps all models).

    Returns
    -------
    gr.BarPlot
        Grouped by model, colored by prompt_id, with the plotted prompt
        ids listed in the title.
    """
    responses = df
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    if models:
        responses = responses[responses['model'].isin(models)]
    # Fraction of correct answers for every (model, prompt_id) pair.
    means = responses.groupby(['model', 'prompt_id'])['correct'].mean().reset_index()
    means['prompt_id'] = means['prompt_id'].astype(str)
    # Sort the unique ids numerically: the original list(set(...)) made the
    # title's id order vary between runs (set iteration order is arbitrary).
    shown_ids = sorted(means['prompt_id'].unique(), key=int)
    shown_ids_str = ', '.join(shown_ids)
    return gr.BarPlot(
        means,
        x='prompt_id',
        y='correct',
        group='model',
        color='prompt_id',
        group_title="",
        title=f'Correctness for Prompt IDs: {shown_ids_str}',
        x_title="",
    )
# Page title rendered in the header row below.
title= "🎩🐇 Alice in Wonderland: Simple Tasks Showing Complete Reasoning Breakdown in State-Of-the-Art Large Language Models"

# Top-level Gradio app: a header (title, links, description) plus two tabs.
# Each tab wraps one of the functions above in a gr.Interface, so the
# inputs listed here must match that function's parameter order.
with gr.Blocks() as demo:
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"""<h1 style='font-size: 30px; font-weight: bold; text-align: center;'>{title}</h1>
<h4 align="center"><a href="https://marianna13.github.io/aiw/" target="_blank">🌐Homepage</a> | <a href="https://arxiv.org/pdf/2406.02061" target="_blank"> 📝Paper</a> | <a href="https://github.com/LAION-AI/AIW"target="_blank">🛠️Code</a></h4>
<p style='color: #000000; font-size: 20px; text-align: center;'>This demo shows the responses of different models to a set of prompts. The responses are categorized as correct or incorrect. You can choose the number of responses, the model, the correctness of the responses, and the prompt IDs to see the responses.</p>
<p style='color: #000000; font-size: 20px; text-align: center;'>You can also see the correctness of the responses for different prompt IDs using the robustness plot tab.</p>
"""
        )
    # Tab 1: browse a random sample of raw responses with optional filters.
    with gr.Tab("Responses"):
        gr.Interface(
            response,
            [
                # Sample size; response() clamps it to the filtered row count.
                gr.Slider(2, 20, value=4, label="Number of responses", info="Choose between 2 and 20"),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
                # Choices are (label, value) pairs mapping to the boolean
                # 'correct' column.
                gr.CheckboxGroup([("Correct", True), ("Incorrect", False)], label="Correct or not", info="Choose to see correct or incorrect responses"),
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
            ],
            gr.DataFrame(type="pandas", wrap=True, label="Responses"),
        )
    # Tab 2: mean-correctness bar plot per (model, prompt_id).
    with gr.Tab("Robustness plot"):
        gr.Interface(
            barplot_for_prompt_id,
            [
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                )],
            # Placeholder output; the function returns a fully-configured plot.
            gr.BarPlot( title="Correctness for Prompt IDs"),
        )
if __name__ == "__main__":
    demo.launch()