ahuang11's picture
Update app.py
13c31d1 verified
raw
history blame contribute delete
No virus
5.47 kB
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion
# Load Panel's Perspective extension; the "Voting Results" tab renders a
# Perspective pane.
pn.extension("perspective")

# litellm model strings. At the start of every round each chat panel draws one
# of these at random, so A and B may even end up with the same model.
MODELS = [
    "mistral/mistral-tiny",
    "mistral/mistral-small",
    "mistral/mistral-medium",
    "mistral/mistral-large-latest",
]
# Labels of the four voting buttons. click_vote() dispatches on position
# (0: A is better, 1: tie, 2: both bad, 3: B is better), so keep this order.
# The emoji are written as \N{...} escapes so the source stays pure ASCII and
# cannot be mangled again by a wrong text decoding.
VOTING_LABELS = [
    "\N{WHITE LEFT POINTING BACKHAND INDEX} A is better",
    "\N{HUGGING FACE} About the same",
    "\N{FACE WITH COLD SWEAT} Both not good",
    "\N{WHITE RIGHT POINTING BACKHAND INDEX} B is better",
]
async def respond(content, user, instance):
    """
    Stream a reply from the model assigned to ``instance`` into that chat panel.

    If ``chat_models`` has no model recorded for this panel's name, a new round
    starts: the panel's earlier conversation is dropped (only the last message
    is kept), both headers are reset to the anonymous "A"/"B" titles, and a
    model is drawn at random from ``MODELS``. The conversation is then sent to
    that model through litellm, and the streamed tokens are written into one
    "Assistant" message. The panel is disabled while this runs and re-enabled
    afterwards, even if the request raises.

    Parameters
    ----------
    content : str
        Text of the user's latest message.
    user : str
        Sender name. Unused here; kept to match the ChatInterface callback
        signature.
    instance : pn.chat.ChatInterface
        The chat panel to respond in.
    """
    try:
        instance.disabled = True
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: keep only the just-sent message and hide the model
            # names revealed by the previous vote.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        messages = instance.serialize()
        # The user's message is normally already the panel's last object:
        # forward_message appends it to the other panel, and Panel appends it
        # to the sending panel before invoking the callback (NOTE(review):
        # confirm for the pinned Panel version). Append it only when missing,
        # so the model is not sent the same prompt twice.
        last = messages[-1] if messages else {}
        if last.get("role") != "user" or last.get("content") != content:
            messages.append({"role": "user", "content": content})

        # A key typed into the header field takes precedence over the env var.
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")
        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128, api_key=api_key
        )
        message = None
        async for chunk in response:
            token = chunk.choices[0].delta.content
            if not token:
                # Some chunks (e.g. role-only or final chunks) carry no text.
                continue
            # Feed back the message returned by the first stream() call so all
            # tokens accumulate in the same chat message.
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    ChatInterface callback shared by both panels.

    Copy the user's message into whichever panel did not receive it, so both
    panels hold the same prompt. Then let each panel's model answer, running
    panel A's and panel B's responses concurrently.
    """
    mirror = chat_interface_b if instance is chat_interface_a else chat_interface_a
    mirror.append(pn.chat.ChatMessage(content, user=user))
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )
def click_vote(event):
    """
    Record a vote, reveal which models played A and B, and end the round.

    ``event.obj.name`` is the label of the clicked button (one of VOTING_LABELS):

    - index 0 ("A is better") / index 3 ("B is better"): +1 for that side's model;
    - index 1 ("About the same"): +1 for each side's model, but a model that was
      drawn for both sides is counted only once;
    - index 2 ("Both not good"): no vote is recorded.

    In every case the model names are written into the panel headers and
    ``chat_models`` is cleared, so the next message starts a new round. The
    tallies are then pushed to the Perspective pane and written to
    ``voting_counts.db`` (table ``voting_counts``, replaced on every vote).
    """
    label_a = chat_interface_a.name
    label_b = chat_interface_b.name
    # Ignore clicks until both panels have been assigned a model this round;
    # looking up a missing side would raise KeyError.
    if label_a not in chat_models or label_b not in chat_models:
        return
    model_a = chat_models[label_a]
    model_b = chat_models[label_b]

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        winners = (model_a,)
    elif voting_label == VOTING_LABELS[3]:
        winners = (model_b,)
    elif voting_label == VOTING_LABELS[1]:
        # dict.fromkeys de-duplicates while keeping A-then-B insertion order,
        # so new models are added to voting_counts in a stable order.
        winners = tuple(dict.fromkeys((model_a, model_b)))
    else:
        winners = ()
    for model in winners:
        voting_counts[model] = voting_counts.get(model, 0) + 1

    # Reveal the models, then forget them so the next message draws new ones.
    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    perspective.object = (
        pd.DataFrame(voting_counts, index=["Votes"])
        .melt(var_name="Model", value_name="Votes")
        .set_index("Model")
    )
    # sqlite3's connection context manager only commits or rolls back; it
    # does not close the connection, so closing() is what releases it.
    with closing(sqlite3.connect("voting_counts.db")) as conn:
        with conn:
            pd.DataFrame(voting_counts.items(), columns=["Model", "Votes"]).to_sql(
                "voting_counts", conn, if_exists="replace", index=False
            )
# initialize
# Maps a chat panel's name ("A"/"B") to the model drawn for it in the current
# round; click_vote() empties it after every vote.
chat_models = {}

# Load the persisted tallies (model -> vote count), creating the table on the
# first run. sqlite3's connection context manager only commits or rolls back;
# closing() is what actually releases the connection.
with closing(sqlite3.connect("voting_counts.db")) as conn:
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
        )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header
# Optional Mistral API key. When left blank, respond() falls back to the
# MISTRAL_API_KEY environment variable. The stylesheet keeps the typed text
# black; the stray ";" after the rule (invalid CSS) has been removed.
api_key_input = pn.widgets.PasswordInput(
    placeholder="Mistral API Key", stylesheets=[".bk-input {color: black}"]
)
# main
tabs = pn.Tabs()

# tab 1: two anonymous chat panels side by side, with the voting buttons below.
# Both panels route messages through forward_message; their undo/rerun/clear/
# stop buttons and button names are hidden.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
}
header_a = pn.pane.Markdown("## Model: A")
chat_interface_a = pn.chat.ChatInterface(name="A", header=header_a, **chat_interface_kwargs)
header_b = pn.pane.Markdown("## Model: B")
chat_interface_b = pn.chat.ChatInterface(name="B", header=header_b, **chat_interface_kwargs)

# One full-width button per voting label, all handled by click_vote.
button_kwargs = {"sizing_mode": "stretch_width"}
vote_buttons = [pn.widgets.Button(name=label, **button_kwargs) for label in VOTING_LABELS]
for vote_button in vote_buttons:
    vote_button.on_click(click_vote)
button_row = pn.Row(*vote_buttons)

chat_column = pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)
tabs.append(("Chat", chat_column))
# tab 2
# Leaderboard: one row per model (index "Model") with a single "Votes" column.
votes_table = (
    pd.DataFrame(voting_counts, index=["Votes"])
    .melt(var_name="Model", value_name="Votes")
    .set_index("Model")
)
perspective = pn.pane.Perspective(
    votes_table,
    editable=False,
    sizing_mode="stretch_both",
)
tabs.append(("Voting Results", perspective))
# layout
# Page shell: the API-key field goes in the template header, the tabs fill the
# main area, and .servable() registers the page with `panel serve`.
app = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
app.servable()