Spaces:
Runtime error
Runtime error
File size: 5,465 Bytes
eaf5b04 a7d2870 ab13803 a7d2870 ee2c314 a7d2870 ee2c314 a7d2870 1e333df a7d2870 3ebfb41 a7d2870 3ebfb41 a7d2870 ab13803 a7d2870 eaf5b04 a7d2870 ab13803 a7d2870 ab13803 a7d2870 eaf5b04 a7d2870 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion
# Load Panel's Perspective extension so the voting-results table can render.
pn.extension("perspective")

# Candidate models. respond() draws one at random for a pane whenever that
# pane has no model assigned for the current round.
MODELS = [
    f"mistral/{variant}"
    for variant in (
        "mistral-tiny",
        "mistral-small",
        "mistral-medium",
        "mistral-large-latest",
    )
]

# Vote button labels, in display order. click_vote matches the clicked label
# against these by position: [0] favors A, [1] is a tie, [2] awards no vote,
# [3] favors B.
VOTING_LABELS = [
    "👈 A is better",
    "🤗 About the same",
    "😓 Both not good",
    "👉 B is better",
]
async def respond(content, user, instance):
    """
    Stream a model reply to the latest user message into one chat interface.

    On the first message of a round (no model assigned yet to this
    interface's ``name`` in ``chat_models``), the interface history is trimmed
    to just that message. Both headers are then masked back to "A"/"B", and a
    model is drawn at random from ``MODELS``. The serialized conversation is
    sent to that model and the reply is streamed into the interface as the
    "Assistant" user. The interface is disabled while this runs and
    re-enabled afterwards, even on error.

    Parameters
    ----------
    content : str
        Text of the user's message. It is already the last entry of
        ``instance`` when this runs (it is what the new-round trim keeps, and
        ``forward_message`` appends it to the other pane first). It is
        therefore not appended to the outgoing messages a second time.
    user : str
        Sender name; unused here, kept for the chat callback signature.
    instance : pn.chat.ChatInterface
        Interface to answer in; its ``name`` is the key into ``chat_models``.
    """
    instance.disabled = True
    try:
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: drop earlier turns, keep only the triggering message.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        # serialize() already ends with the user's message; appending
        # `content` again would send the prompt to the model twice.
        messages = instance.serialize()

        # A key typed into the header widget takes precedence over the env var.
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")
        response = await acompletion(
            model=model,
            messages=messages,
            stream=True,
            max_tokens=128,
            api_key=api_key,
        )
        message = None
        async for chunk in response:
            # Skip chunks with an empty `choices` list instead of IndexError-ing.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta["content"]
            if not token:
                continue
            # Feed the returned message back so later tokens extend the same
            # chat message (Panel ChatFeed.stream `message=` argument).
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    Mirror a newly sent message into the opposite pane, then answer in both.

    This is the shared callback of both chat interfaces. Whichever pane the
    user typed into, the same text is appended to the other pane under the
    same user name. ``respond`` then runs concurrently for pane A and pane B.
    """
    other_instance = (
        chat_interface_b if instance is chat_interface_a else chat_interface_a
    )
    other_instance.append(pn.chat.ChatMessage(content, user=user))

    # Run both model calls side by side rather than one after the other.
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )
def click_vote(event):
    """
    Record a vote for the current round, reveal both models, start a new round.

    ``event.obj.name`` is the clicked button's label. It is matched against
    ``VOTING_LABELS`` by position:

    * ``[0]``: one vote for the model behind pane A;
    * ``[3]``: one vote for the model behind pane B;
    * ``[1]``: one vote each for A and B, or a single vote if both panes
      drew the same model;
    * anything else (``[2]``): no votes awarded.

    In every counted case, the real model names are written into both
    headers and ``chat_models`` is emptied so the next message draws fresh
    models. The Perspective table is refreshed and the full tally is saved
    to ``voting_counts.db``. Clicks made before both panes have a model
    assigned are ignored.
    """
    label_a = chat_interface_a.name
    label_b = chat_interface_b.name
    # Ignore clicks until both panes have a model this round; a partially
    # filled mapping would otherwise raise KeyError halfway through.
    if label_a not in chat_models or label_b not in chat_models:
        return
    model_a = chat_models[label_a]
    model_b = chat_models[label_b]

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        voted_models = [model_a]
    elif voting_label == VOTING_LABELS[3]:
        voted_models = [model_b]
    elif voting_label == VOTING_LABELS[1]:
        # dict.fromkeys de-duplicates while keeping A before B.
        voted_models = list(dict.fromkeys([model_a, model_b]))
    else:
        voted_models = []
    for model in voted_models:
        voting_counts[model] = voting_counts.get(model, 0) + 1

    # Reveal which model answered on each side, then reset for the next round.
    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    perspective.object = (
        pd.DataFrame(voting_counts, index=["Votes"])
        .melt(var_name="Model", value_name="Votes")
        .set_index("Model")
    )

    # Rewrite the stored tally in place (DELETE + INSERT) so the table keeps
    # the schema it was created with. closing() is needed because the sqlite3
    # connection's own context manager only commits/rolls back, never closes.
    rows = [(model, int(votes)) for model, votes in voting_counts.items()]
    with closing(sqlite3.connect("voting_counts.db")) as conn:
        with conn:
            conn.execute(
                "CREATE TABLE IF NOT EXISTS voting_counts "
                "(Model TEXT PRIMARY KEY, Votes INTEGER)"
            )
            conn.execute("DELETE FROM voting_counts")
            conn.executemany(
                "INSERT INTO voting_counts (Model, Votes) VALUES (?, ?)", rows
            )
# initialize
# Model assigned to each chat pane for the current round, keyed by the pane's
# name. It is emptied after every vote so the next message starts a new round.
chat_models = {}

# Load the persisted tally ({model: votes}), creating the table on first run.
# The sqlite3 connection's context manager only commits or rolls back; closing()
# is what actually releases the connection.
with closing(sqlite3.connect("voting_counts.db")) as conn:
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS voting_counts "
            "(Model TEXT PRIMARY KEY, Votes INTEGER)"
        )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header
# Optional per-session Mistral API key. When left empty, respond() falls back
# to the MISTRAL_API_KEY environment variable.
api_key_input = pn.widgets.PasswordInput(
    placeholder="Mistral API Key",
    # Render the typed text in black; a well-formed rule has no ';' after '}'.
    stylesheets=[".bk-input {color: black;}"],
)
# main
tabs = pn.Tabs()

# tab 1
# Both panes share one configuration: every message goes through
# forward_message, and the undo/rerun/clear/stop controls are hidden.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
}

# Headers start masked; click_vote later replaces them with the real models.
header_a = pn.pane.Markdown("## Model: A")
chat_interface_a = pn.chat.ChatInterface(
    name="A", header=header_a, **chat_interface_kwargs
)
header_b = pn.pane.Markdown("## Model: B")
chat_interface_b = pn.chat.ChatInterface(
    name="B", header=header_b, **chat_interface_kwargs
)

# One full-width vote button per label, all handled by click_vote.
button_kwargs = dict(sizing_mode="stretch_width")
voting_buttons = [
    pn.widgets.Button(name=label, **button_kwargs) for label in VOTING_LABELS
]
for voting_button in voting_buttons:
    voting_button.on_click(click_vote)
button_row = pn.Row(*voting_buttons)

chat_panes = pn.Row(chat_interface_a, chat_interface_b)
tabs.append(("Chat", pn.Column(chat_panes, button_row)))
# tab 2
# Leaderboard built from the loaded tally: one row per model (index "Model")
# with its vote count in a "Votes" column.
votes_table = (
    pd.DataFrame(voting_counts, index=["Votes"])
    .melt(var_name="Model", value_name="Votes")
    .set_index("Model")
)
perspective = pn.pane.Perspective(
    votes_table,
    sizing_mode="stretch_both",
    editable=False,
)
tabs.append(("Voting Results", perspective))
# layout
# Page shell: API-key box in the header bar, tabs in the main area.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
template.servable()
|