Spaces:
Runtime error
Runtime error
import time | |
import gradio as gr | |
import random | |
from conversation import Conversation | |
def get_tab_arena_battle(download_bot_config, get_bot_profile, model_mapping, client):
    """Build the "Chatbot Arena (battle)" tab and wire up all of its events.

    Components are created directly, so this must be called inside an active
    Gradio Blocks/Tab layout context.

    Two distinct models are drawn at random from ``model_mapping``. They chat
    side by side with the same bot persona while their names stay hidden. The
    user votes, each vote is inserted as a row into the "Chat Arena"
    spreadsheet, and only after a vote can the model names be revealed. Once
    the names are revealed, sending, regenerating and voting stay disabled
    (including the message box, so Enter cannot bypass the lock) until
    "Restart" or "Reload bot" starts a new round.

    Args:
        download_bot_config: Callable ``bot_id -> dict``. The dict is read for
            the keys "firstMessage", "memory", "prompt" and "bot_id".
        get_bot_profile: Callable ``bot_config -> str``; its result is shown in
            a ``gr.HTML`` component.
        model_mapping: Mapping of model tag -> model object. Each model needs a
            ``generation_params`` dict with the keys "temperature",
            "repetition_penalty", "max_new_tokens", "top_k" and "top_p", and a
            ``generate_response(conversation, params)`` method. It must hold at
            least two entries (two distinct models are sampled per round).
        client: Spreadsheet client; only ``client.open("Chat Arena").sheet1``
            is used (looks like gspread -- NOTE(review): confirm).
    """
    gr.Markdown("""
# ⚔️ Chatbot Arena (battle) ⚔️
## Rules
* Chat with two anonymous models side-by-side and vote for which one is better!
* You can do multiple rounds of conversations before voting or vote for each message.
* The names of the models will be revealed at the top after you vote and press "Show models".
* Click “Restart” to start a new round with new models.
""")

    def format_bot_config(config):
        """Render a bot config's memory and prompt as Markdown."""
        return f"# Memory\n{config['memory']}\n# Prompt\n{config['prompt']}\n"

    def build_generation_sliders(model_tag):
        """Create the five generation-parameter sliders for one model.

        Must be called inside the layout context the sliders belong to. The
        initial values come from ``model_mapping[model_tag].generation_params``.
        The returned order is: temperature, repetition penalty, max new tokens,
        top-k, top-p. This matches get_generation_args() and the trailing
        parameters of respond() / regenerate_response().
        """
        params = model_mapping[model_tag].generation_params
        return (
            gr.Slider(minimum=0.0, maximum=1.0, value=params["temperature"],
                      interactive=True, label="Temperature"),
            gr.Slider(minimum=0.0, maximum=2.0, value=params["repetition_penalty"],
                      interactive=True, label="Repetition penalty"),
            gr.Slider(minimum=1, maximum=512, value=params["max_new_tokens"],
                      interactive=True, label="Max new tokens"),
            gr.Slider(minimum=1, maximum=100, value=params["top_k"],
                      interactive=True, label="Top-K"),
            gr.Slider(minimum=0.0, maximum=1.0, value=params["top_p"],
                      interactive=True, label="Top-P"),
        )

    # ---- Layout ----------------------------------------------------------
    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # The active bot config; reload_bot() replaces it.
    user_state = gr.State(bot_config)
    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")
    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(format_bot_config(bot_config))

    model_tags = list(model_mapping)
    with gr.Row():
        # Each chat starts with the bot's greeting and no user turn.
        first_message = (None, bot_config["firstMessage"])
        height = 450
        # Two distinct models per round. Their name boxes stay hidden
        # (anonymous) until "Show models" is pressed.
        model_a_value, model_b_value = random.sample(model_tags, 2)
        with gr.Column():
            model_a = gr.Textbox(value=model_a_value, label="Model A", interactive=False, visible=False)
            chatbot_a = gr.Chatbot([first_message])
            chatbot_a.style(height=height)
        with gr.Column():
            model_b = gr.Textbox(value=model_b_value, label="Model B", interactive=False, visible=False)
            chatbot_b = gr.Chatbot([first_message])
            chatbot_b.style(height=height)
    with gr.Row():
        with gr.Column(scale=3):
            msg = gr.Textbox(show_label=False, value="Hi there!", interactive=True)
        with gr.Column(scale=3):
            send = gr.Button("Send")
    with gr.Row():
        vote_a = gr.Button("👈 A is better", interactive=False)
        vote_b = gr.Button("👉 B is better", interactive=False)
        vote_tie = gr.Button("🤝 Tie", interactive=False)
        vote_bad = gr.Button("💩 Both are bad", interactive=False)
        show_models_button = gr.Button("Show models", interactive=False)
    with gr.Row():
        regenerate = gr.Button("Regenerate", interactive=False)
        clear = gr.Button("Restart")
    with gr.Accordion("Generation parameters for model A", open=False):
        sliders_a = build_generation_sliders(model_a_value)
    with gr.Accordion("Generation parameters for model B", open=False):
        sliders_b = build_generation_sliders(model_b_value)

    vote_buttons = [vote_a, vote_b, vote_tie, vote_bad]
    # Every control whose interactivity depends on "is the round still open".
    round_controls = [send, regenerate, show_models_button, msg]

    # ---- Callbacks -------------------------------------------------------
    def make_generation_params(temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Bundle the slider values into the dict passed to generate_response()."""
        return {
            'temperature': temperature,
            'repetition_penalty': repetition_penalty,
            'max_new_tokens': max_new_tokens,
            'top_k': top_k,
            'top_p': top_p,
        }

    def clear_chat(user_state):
        """Empty the message box and reset both chats to the bot's greeting."""
        greeting = [(None, user_state["firstMessage"])]
        return "", list(greeting), list(greeting)

    def reload_bot(bot_id):
        """Download bot ``bot_id``; refresh the profile, both chats, the state and the config text."""
        new_config = download_bot_config(bot_id)
        greeting = (None, new_config["firstMessage"])
        return (get_bot_profile(new_config), [greeting], [greeting],
                new_config, format_bot_config(new_config))

    def get_generation_args(model_tag):
        """Return a model's default generation params in slider order."""
        params = model_mapping[model_tag].generation_params
        return (
            params["temperature"],
            params["repetition_penalty"],
            params["max_new_tokens"],
            params["top_k"],
            params["top_p"],
        )

    def respond(message, chat_history, user_state, model_tag,
                temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Send ``message`` to model ``model_tag`` and append the exchange.

        Returns ``("", chat_history)``; the empty string clears the message box.
        """
        custom_generation_params = make_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        bot_message = model_mapping[model_tag].generate_response(conv, custom_generation_params)
        chat_history.append((message, bot_message))
        return "", chat_history

    def record_vote(user_state, vote,
                    chat_history_a, model_tag_a,
                    chat_history_b, model_tag_b):
        """Insert one vote row into the "Chat Arena" sheet and enable "Show models".

        ``vote`` is the clicked button's label (each vote button passes itself
        as an input). The chat histories are part of the event wiring but are
        not currently recorded.
        """
        if "A is better" in vote:
            vote_str = "model_a"
        elif "B is better" in vote:
            vote_str = "model_b"
        elif "Tie" in vote:
            vote_str = "tie"
        else:
            vote_str = "tie (bothbad)"
        # Insertion order of this dict is the sheet's column order.
        row = {
            "timestamp": time.time(),
            "bot_id": user_state["bot_id"],
            "vote": vote_str,
            "model_a": model_tag_a,
            "model_b": model_tag_b,
            # Accurate: the UI blocks voting once the names have been revealed.
            "is_anonymous": int(True),
        }
        sheet = client.open("Chat Arena").sheet1
        # Presumably get_all_records() excludes the header row, so the next
        # free 1-based row is len(records) + 2 (header is row 1).
        # NOTE(review): this reads the whole sheet on every vote.
        num_rows = len(sheet.get_all_records())
        sheet.insert_row(list(row.values()), index=num_rows + 2)
        return gr.Button.update(interactive=True)

    def regenerate_response(chat_history, user_state, model_tag,
                            temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Replace the bot reply to the most recent user turn with a fresh one.

        A history that is empty or holds only the bot's greeting has no user
        turn to regenerate, so it is returned unchanged.
        """
        if len(chat_history) <= 1:
            return "", chat_history
        custom_generation_params = make_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        user_message = chat_history.pop(-1)[0]
        # The last turn is passed with no bot reply, presumably so Conversation
        # treats it as awaiting one. NOTE(review): confirm against
        # Conversation.set_chat_history.
        chat_history.append((user_message, None))
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        bot_message = model_mapping[model_tag].generate_response(conv, custom_generation_params)
        chat_history[-1] = (user_message, bot_message)
        return "", chat_history

    def disable_voting():
        """Make all four vote buttons non-interactive."""
        return [gr.Button.update(interactive=False)] * 4

    def enable_voting():
        """Make all four vote buttons interactive."""
        return [gr.Button.update(interactive=True)] * 4

    def show_models():
        """Reveal both model-name textboxes."""
        return [gr.Textbox.update(visible=True)] * 2

    def hide_models():
        """Hide the model names and draw a fresh pair of distinct models."""
        new_a, new_b = random.sample(model_tags, 2)
        return [gr.Textbox.update(visible=False, value=new_a),
                gr.Textbox.update(visible=False, value=new_b)]

    def disable_send():
        """Lock the round after the reveal (outputs: ``round_controls``).

        The message box is locked as well. Otherwise pressing Enter would still
        send a message, and msg.submit would re-enable voting with the model
        names visible.
        """
        return [gr.Button.update(interactive=False)] * 3 + [gr.Textbox.update(interactive=False)]

    def enable_send():
        """Open a new round (outputs: ``round_controls``).

        Send and the message box are enabled. Regenerate is disabled (nothing to
        regenerate yet), and so is "Show models" (it unlocks only after a vote).
        """
        return [gr.Button.update(interactive=True),
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                gr.Textbox.update(interactive=True)]

    def enable_regenerate():
        """Allow regenerating once a user turn exists."""
        return gr.Button.update(interactive=True)

    # ---- Event wiring ----------------------------------------------------
    for vote in vote_buttons:
        vote.click(record_vote,
                   [user_state, vote, chatbot_a, model_a, chatbot_b, model_b],
                   [show_models_button],
                   queue=False)
        vote.click(disable_voting, None, vote_buttons, queue=False)

    show_models_button.click(show_models, None, [model_a, model_b], queue=False)
    show_models_button.click(disable_voting, None, vote_buttons, queue=False)
    show_models_button.click(disable_send, None, round_controls, queue=False)

    # Restart and Reload both begin a new anonymous round.
    for new_round_button in (clear, reload_bot_button):
        new_round_button.click(hide_models, None, [model_a, model_b], queue=False)
        new_round_button.click(enable_send, None, round_controls, queue=False)
        new_round_button.click(disable_voting, None, vote_buttons, queue=False)
    clear.click(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
    reload_bot_button.click(reload_bot, [bot_id],
                            [bot_profile, chatbot_a, chatbot_b, user_state, bot_config_text],
                            queue=False)

    sides = (
        (chatbot_a, model_a, sliders_a),
        (chatbot_b, model_b, sliders_b),
    )

    # A new model pair resets that side's sliders and both chats.
    for _chatbot, model_box, sliders in sides:
        model_box.change(get_generation_args, [model_box], list(sliders), queue=False)
    for _chatbot, model_box, _sliders in sides:
        model_box.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)

    # The Send button and pressing Enter in the message box behave identically.
    for trigger in (send.click, msg.submit):
        trigger(enable_voting, None, vote_buttons, queue=False)
        for chatbot, model_box, sliders in sides:
            trigger(respond, [msg, chatbot, user_state, model_box, *sliders],
                    [msg, chatbot], queue=False)
        trigger(enable_regenerate, None, [regenerate], queue=False)

    regenerate.click(enable_voting, None, vote_buttons, queue=False)
    for chatbot, model_box, sliders in sides:
        regenerate.click(regenerate_response,
                         [chatbot, user_state, model_box, *sliders],
                         [msg, chatbot], queue=False)