# model-evaluation/tabs/arena_battle.py
# Hugging Face file-view header: last commit d7f914d ("updates") by AlekseyKorshuk; file size 12.9 kB.
import time
import gradio as gr
import random
from conversation import Conversation
def get_tab_arena_battle(download_bot_config, get_bot_profile, model_mapping, client):
    """Build the anonymous "Chatbot Arena (battle)" tab and wire its event handlers.

    Must be called inside an active Gradio Blocks/Tab context. The function
    creates components and registers listeners as side effects, and returns None.

    Two distinct models are drawn at random from ``model_mapping``. The user chats
    with both side-by-side, votes, and may then reveal which models they were.
    Each vote is appended as a row to the "Chat Arena" spreadsheet.

    Args:
        download_bot_config: Callable ``bot_id -> dict``. The dict is read for the
            keys ``"memory"``, ``"prompt"``, ``"firstMessage"`` and ``"bot_id"``,
            and is also passed to ``Conversation``.
        get_bot_profile: Callable ``bot_config -> str`` returning the HTML shown
            in the bot profile panel.
        model_mapping: Dict from model tag to a model object that exposes a
            ``generation_params`` dict and ``generate_response(conv, params)``.
            Needs at least two entries, otherwise ``random.sample`` raises
            ``ValueError``.
        client: Object whose ``open("Chat Arena").sheet1`` is a worksheet with an
            ``append_row`` method (presumably a ``gspread`` client -- confirm).
    """

    def format_bot_config(config):
        """Render the bot's memory and prompt as Markdown for the config accordion."""
        return f"# Memory\n{config['memory']}\n# Prompt\n{config['prompt']}\n"

    def initial_history(config):
        """Return a fresh chat history that contains only the bot's greeting."""
        return [(None, config["firstMessage"])]

    def build_generation_sliders(model_tag):
        """Create the five generation-parameter sliders.

        Each slider starts at the given model's default value. The sliders are
        returned in the order temperature, repetition_penalty, max_new_tokens,
        top_k, top_p, which matches the trailing parameters of ``respond`` and
        ``regenerate_response``.
        """
        params = model_mapping[model_tag].generation_params
        return [
            gr.Slider(minimum=0.0, maximum=1.0, value=params["temperature"],
                      interactive=True, label="Temperature"),
            gr.Slider(minimum=0.0, maximum=2.0, value=params["repetition_penalty"],
                      interactive=True, label="Repetition penalty"),
            gr.Slider(minimum=1, maximum=512, value=params["max_new_tokens"],
                      interactive=True, label="Max new tokens"),
            gr.Slider(minimum=1, maximum=100, value=params["top_k"],
                      interactive=True, label="Top-K"),
            gr.Slider(minimum=0.0, maximum=1.0, value=params["top_p"],
                      interactive=True, label="Top-P"),
        ]

    # ------------------------------------------------------------------ layout
    gr.Markdown("""
# ⚔️ Chatbot Arena (battle) ⚔️
## Rules
* Chat with two anonymous models side-by-side and vote for which one is better!
* You can do multiple rounds of conversations before voting or vote for each message.
* The names of the models will be revealed at the top after you have voted and pressed "Show models".
* Click “Restart” to start a new round with new models.
""")
    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # Session state holding the active bot config; replaced by "Reload bot".
    user_state = gr.State(bot_config)
    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")
    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(format_bot_config(bot_config))

    values = list(model_mapping.keys())
    height = 450
    # NOTE(review): this pair is drawn once when the UI is built, so presumably every
    # new session starts with the same two models until "Restart" is pressed.
    model_a_value, model_b_value = random.sample(values, 2)
    with gr.Row():
        with gr.Column():
            # Hidden until "Show models" so the battle stays anonymous.
            model_a = gr.Textbox(value=model_a_value, label="Model A", interactive=False, visible=False)
            chatbot_a = gr.Chatbot(initial_history(bot_config))
            chatbot_a.style(height=height)
        with gr.Column():
            model_b = gr.Textbox(value=model_b_value, label="Model B", interactive=False, visible=False)
            chatbot_b = gr.Chatbot(initial_history(bot_config))
            chatbot_b.style(height=height)
    with gr.Row():
        with gr.Column(scale=3):
            msg = gr.Textbox(show_label=False, value="Hi there!", interactive=True)
        with gr.Column(scale=3):
            send = gr.Button("Send")
    with gr.Row():
        vote_a = gr.Button("👈 A is better", interactive=False)
        vote_b = gr.Button("👉 B is better", interactive=False)
        vote_tie = gr.Button("🤝 Tie", interactive=False)
        vote_bad = gr.Button("💩 Both are bad", interactive=False)
        show_models_button = gr.Button("Show models", interactive=False)
    with gr.Row():
        regenerate = gr.Button("Regenerate", interactive=False)
        clear = gr.Button("Restart")
    with gr.Accordion("Generation parameters for model A", open=False):
        sliders_a = build_generation_sliders(model_a.value)
    with gr.Accordion("Generation parameters for model B", open=False):
        sliders_b = build_generation_sliders(model_b.value)

    voting_buttons = [vote_a, vote_b, vote_tie, vote_bad]
    # Controls whose interactivity is toggled per round:
    # send, regenerate, show_models_button, msg (in this order).
    round_controls = [send, regenerate, show_models_button, msg]

    # ---------------------------------------------------------------- handlers
    def make_generation_params(temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Pack slider values into the dict passed to ``generate_response``."""
        return {
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "top_k": top_k,
            "top_p": top_p,
        }

    def clear_chat(user_state):
        """Empty the input box and reset both chats to the bot's greeting."""
        return "", initial_history(user_state), initial_history(user_state)

    def reload_bot(bot_id):
        """Download ``bot_id``'s config and reset the profile, both chats, state and config view."""
        new_config = download_bot_config(bot_id)
        return (get_bot_profile(new_config), initial_history(new_config), initial_history(new_config),
                new_config, format_bot_config(new_config))

    def get_generation_args(model_tag):
        """Return the model's default generation params, in slider order."""
        params = model_mapping[model_tag].generation_params
        return (params["temperature"], params["repetition_penalty"], params["max_new_tokens"],
                params["top_k"], params["top_p"])

    def respond(message, chat_history, user_state, model_tag,
                temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Send ``message`` to one model and append the (message, reply) turn.

        Returns ("", updated history) so the input box is cleared.
        """
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        params = make_generation_params(temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        bot_message = model_mapping[model_tag].generate_response(conv, params)
        chat_history.append((message, bot_message))
        return "", chat_history

    # Substring of the clicked button's label -> value stored in the "vote" column.
    vote_labels = (("A is better", "model_a"), ("B is better", "model_b"), ("Tie", "tie"))

    def record_vote(user_state, vote, chat_history_a, model_tag_a, chat_history_b, model_tag_b):
        """Append one vote row to the "Chat Arena" sheet and enable "Show models".

        ``vote`` is the label of the clicked button. A label not matched by
        ``vote_labels`` is the "Both are bad" button. The chat histories are part
        of the wired inputs but are not stored.
        """
        vote_str = next((tag for label, tag in vote_labels if label in vote), "tie (bothbad)")
        row = {
            "timestamp": time.time(),
            "bot_id": user_state["bot_id"],
            "vote": vote_str,
            "model_a": model_tag_a,
            "model_b": model_tag_b,
            "is_anonymous": int(True),
        }
        sheet = client.open("Chat Arena").sheet1
        # A single append request replaces the previous approach: download the whole
        # sheet to count rows, then insert at that index in a second request that
        # other sessions could interleave with.
        sheet.append_row(list(row.values()))
        return gr.Button.update(interactive=True)

    def regenerate_response(chat_history, user_state, model_tag,
                            temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Re-generate the model's reply to the last user message.

        Returns ("", updated history).
        """
        # Nothing beyond the greeting (or nothing at all): no user turn to redo.
        if len(chat_history) <= 1:
            return "", chat_history
        user_message = chat_history[-1][0]
        # Drop the previous reply before rebuilding the conversation.
        # Presumably Conversation treats a None reply as the pending turn -- confirm.
        chat_history[-1] = (user_message, None)
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        params = make_generation_params(temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        bot_message = model_mapping[model_tag].generate_response(conv, params)
        chat_history[-1] = (user_message, bot_message)
        return "", chat_history

    def disable_voting():
        return [gr.Button.update(interactive=False)] * len(voting_buttons)

    def enable_voting():
        return [gr.Button.update(interactive=True)] * len(voting_buttons)

    def show_models():
        return [gr.Textbox.update(visible=True)] * 2

    def hide_models():
        """Draw a fresh model pair for a new round and hide the names again."""
        new_a, new_b = random.sample(values, 2)
        return [gr.Textbox.update(visible=False, value=new_a),
                gr.Textbox.update(visible=False, value=new_b)]

    def lock_after_reveal():
        """Freeze the round once the models are revealed.

        The input box is locked too. Otherwise pressing Enter would re-enable
        voting with the model names visible.
        """
        return [gr.Button.update(interactive=False)] * 3 + [gr.Textbox.update(interactive=False)]

    def unlock_for_new_round():
        """Reset the round controls for a fresh, anonymous round."""
        return [
            gr.Button.update(interactive=True),    # send
            gr.Button.update(interactive=False),   # regenerate: no reply to redo yet
            gr.Button.update(interactive=False),   # show models: must vote first
            gr.Textbox.update(interactive=True),   # msg
        ]

    def enable_regenerate():
        return gr.Button.update(interactive=True)

    # ------------------------------------------------------------------ wiring
    for vote_button in voting_buttons:
        vote_button.click(record_vote,
                          [user_state, vote_button, chatbot_a, model_a, chatbot_b, model_b],
                          [show_models_button],
                          queue=False)
        vote_button.click(disable_voting, None, voting_buttons, queue=False)

    show_models_button.click(show_models, None, [model_a, model_b], queue=False)
    show_models_button.click(disable_voting, None, voting_buttons, queue=False)
    show_models_button.click(lock_after_reveal, None, round_controls, queue=False)

    # "Restart" and "Reload bot" both start a new round with a fresh model pair.
    for new_round_button in (clear, reload_bot_button):
        new_round_button.click(hide_models, None, [model_a, model_b], queue=False)
        new_round_button.click(unlock_for_new_round, None, round_controls, queue=False)
        new_round_button.click(disable_voting, None, voting_buttons, queue=False)
    clear.click(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
    reload_bot_button.click(reload_bot, [bot_id], [bot_profile, chatbot_a, chatbot_b, user_state, bot_config_text],
                            queue=False)

    # A new model resets its sliders to that model's defaults and restarts both chats.
    for model_box, sliders in ((model_a, sliders_a), (model_b, sliders_b)):
        model_box.change(get_generation_args, [model_box], sliders, queue=False)
        model_box.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)

    # Clicking "Send" and pressing Enter in the textbox behave identically.
    for trigger in (send.click, msg.submit):
        trigger(enable_voting, None, voting_buttons, queue=False)
        trigger(respond, [msg, chatbot_a, user_state, model_a, *sliders_a], [msg, chatbot_a], queue=False)
        trigger(respond, [msg, chatbot_b, user_state, model_b, *sliders_b], [msg, chatbot_b], queue=False)
        trigger(enable_regenerate, None, [regenerate], queue=False)

    regenerate.click(enable_voting, None, voting_buttons, queue=False)
    regenerate.click(regenerate_response, [chatbot_a, user_state, model_a, *sliders_a], [msg, chatbot_a],
                     queue=False)
    regenerate.click(regenerate_response, [chatbot_b, user_state, model_b, *sliders_b], [msg, chatbot_b],
                     queue=False)