# Spaces: Runtime error
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion

# Load Panel's Perspective extension, required by the pn.pane.Perspective results tab.
pn.extension("perspective")
# Candidate litellm model identifiers. Each chat panel independently draws one
# at random for every new round, so both panels may end up with the same model.
MODELS = [
    f"mistral/mistral-{size}"
    for size in ("tiny", "small", "medium", "large-latest")
]
# Labels of the vote buttons. click_vote dispatches on these exact strings by
# position: [0] credits model A, [1] credits both models, [2] credits neither,
# [3] credits model B. The emoji must be real Unicode characters; they were
# previously mis-decoded and rendered as "π"/"π€" on the buttons.
VOTING_LABELS = [
    "👈 A is better",
    "🤝 About the same",
    "👎 Both not good",
    "👉 B is better",
]
async def respond(content, user, instance):
    """
    Stream a reply from the model assigned to ``instance`` into that chat.

    On the first message of a round (no model stored in ``chat_models`` under
    this interface's name), the chat history is trimmed to the newest message,
    both headers are reset to the anonymous "A"/"B" labels, and a model is
    drawn at random from ``MODELS``.

    Parameters
    ----------
    content : str
        The user's message text.
    user : str
        Sender name. Unused here; kept to match the callback signature.
    instance : pn.chat.ChatInterface
        The chat interface to answer in. It is disabled while the reply
        streams and re-enabled afterwards, even if the request fails.
    """
    try:
        instance.disabled = True
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: drop the previous conversation (keeping only the
            # message that starts this one) and hide the model names again.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        messages = instance.serialize()
        # The triggering message is normally already the last entry in the
        # feed: the interface that received it holds it, and forward_message
        # appends it to the other one. Add it only if it is missing, so the
        # model does not receive the user's prompt twice.
        user_message = {"role": "user", "content": content}
        if not messages or messages[-1] != user_message:
            messages.append(user_message)

        # A key typed into the header widget takes precedence over the environment.
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")
        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128, api_key=api_key
        )
        message = None
        async for chunk in response:
            if not chunk.choices:
                # Some streams emit bookkeeping chunks without choices.
                continue
            delta_content = chunk.choices[0].delta["content"]
            if not delta_content:
                continue
            # Passing the previously returned message back in makes stream()
            # extend that chat bubble rather than start a new one.
            message = instance.stream(delta_content, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    Mirror a message into the opposite arena panel, then answer in both panels.

    This is the ``callback`` of both chat interfaces. Whichever panel the user
    typed into, the other panel receives a copy of the message, and both
    panels' models are asked to respond concurrently.
    """
    mirror = chat_interface_b if instance is chat_interface_a else chat_interface_a
    mirror.append(pn.chat.ChatMessage(content, user=user))
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )
def click_vote(event):
    """
    Record a vote, reveal which models answered, and persist the tally.

    ``event.obj.name`` is the clicked button's label, one of ``VOTING_LABELS``.
    Index 0 credits model A and index 3 credits model B. Index 1 credits both
    models, or credits once when A and B were the same model. Index 2 credits
    neither.

    In every case the headers are updated to show the real model names, and
    the round is ended by clearing ``chat_models``. The results table and the
    ``voting_counts`` SQLite table are then refreshed.

    Does nothing unless both panels have a model assigned for the current round.
    """
    model_a = chat_models.get(chat_interface_a.name)
    model_b = chat_models.get(chat_interface_b.name)
    if model_a is None or model_b is None:
        # No round in progress (or only half-started): nothing to credit or
        # reveal, and indexing chat_models would raise KeyError.
        return

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        credited = [model_a]
    elif voting_label == VOTING_LABELS[3]:
        credited = [model_b]
    elif voting_label == VOTING_LABELS[1]:
        # A tie credits each distinct model exactly once.
        credited = [model_a] if model_a == model_b else [model_a, model_b]
    else:
        credited = []
    for model in credited:
        voting_counts[model] = voting_counts.get(model, 0) + 1

    # Reveal the identities, then end the round so the next message draws new models.
    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    votes = pd.DataFrame(list(voting_counts.items()), columns=["Model", "Votes"])
    perspective.object = votes.set_index("Model")
    # sqlite3's connection context manager only commits/rolls back; closing()
    # is what actually releases the connection.
    with closing(sqlite3.connect("voting_counts.db")) as conn:
        with conn:
            votes.to_sql("voting_counts", conn, if_exists="replace", index=False)
# initialize
# Current round's model per chat-interface name ("A"/"B"); emptied after each vote.
chat_models = {}
# Load persisted tallies (model name -> vote count), creating the table on
# first run. sqlite3's connection context manager commits but does not close
# the connection, so closing() is used to avoid keeping it open all session.
with closing(sqlite3.connect("voting_counts.db")) as conn:
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
        )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header
# Optional Mistral key entered by the user. When left blank, respond() falls
# back to the MISTRAL_API_KEY environment variable.
api_key_input = pn.widgets.PasswordInput(
    stylesheets=[".bk-input {color: black};"],
    placeholder="Mistral API Key",
)
# main
# All page content lives in tabs: the arena first, the results second.
tabs = pn.Tabs()

# tab 1
# Both arena panels share one configuration: every sent message is routed
# through forward_message, and the undo/rerun/clear/stop controls are hidden.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
    "callback_exception": "verbose",
}
header_a = pn.pane.Markdown("## Model: A")
chat_interface_a = pn.chat.ChatInterface(
    name="A", header=header_a, **chat_interface_kwargs
)
header_b = pn.pane.Markdown("## Model: B")
chat_interface_b = pn.chat.ChatInterface(
    name="B", header=header_b, **chat_interface_kwargs
)


def _make_vote_button(label):
    """Return a full-width button labelled *label* that calls click_vote."""
    vote_button = pn.widgets.Button(name=label, sizing_mode="stretch_width")
    vote_button.on_click(click_vote)
    return vote_button


button_row = pn.Row(*[_make_vote_button(label) for label in VOTING_LABELS])
arena = pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)
tabs.append(("Chat", arena))
# tab 2
# Results table indexed by model name with a single "Votes" column. It is
# built from voting_counts.items(), the same (Model, Votes) layout that
# click_vote persists to SQLite, rather than transposing a one-row frame and
# melting it back. This also yields a well-formed empty table on first run.
perspective = pn.pane.Perspective(
    pd.DataFrame(list(voting_counts.items()), columns=["Model", "Votes"]).set_index(
        "Model"
    ),
    sizing_mode="stretch_both",
    editable=False,
)
tabs.append(("Voting Results", perspective))
# layout
# Assemble the page (API-key field in the header, tabs in the main area) and
# mark it servable so `panel serve` renders it.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
template.servable()