"""Mistral Chat Arena: blind side-by-side comparison of Mistral models.

Each round, the user's message is sent to two chat interfaces ("A" and "B"),
each backed by a randomly chosen model whose name stays hidden until the user
votes. Votes are persisted in a local SQLite database and shown in a
Perspective table.
"""

import os
import asyncio
import random
import sqlite3
from contextlib import closing

import panel as pn
import pandas as pd
from litellm import acompletion

pn.extension("perspective")

MODELS = [
    "mistral/mistral-tiny",
    "mistral/mistral-small",
    "mistral/mistral-medium",
    "mistral/mistral-large-latest",
]
VOTING_LABELS = [
    "👈 A is better",
    "🤗 About the same",
    "😓 Both not good",
    "👉 B is better",
]
# SQLite file holding the persistent vote tally (table ``voting_counts``).
DB_PATH = "voting_counts.db"
# Upper bound on the length of each streamed model reply.
MAX_TOKENS = 128


async def respond(content, user, instance):
    """Stream a reply from the model hidden behind ``instance``.

    On the first message of a round (no model assigned to this interface
    yet), the interface's history is trimmed to just the new message, both
    headers are re-masked, and a random model from ``MODELS`` is assigned.

    Args:
        content: The user's message text (already present in ``instance``).
        user: Name of the sender (unused; kept for the callback signature).
        instance: The ``pn.chat.ChatInterface`` to stream the reply into.
    """
    instance.disabled = True
    try:
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if model is None:
            # New round: drop past history, keeping only the new message.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        # The new user message is already the last entry of the feed
        # (ChatInterface appends it before invoking the callback, and
        # forward_message appends it to the other interface), so it must not
        # be appended again here.
        messages = instance.serialize()
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")
        response = await acompletion(
            model=model,
            messages=messages,
            stream=True,
            max_tokens=MAX_TOKENS,
            api_key=api_key,
        )
        message = None
        async for chunk in response:
            token = chunk.choices[0].delta.content
            if not token:  # role-only or final chunks carry no text
                continue
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False


async def forward_message(content, user, instance):
    """Mirror the user's message into the other interface, then query both.

    Args:
        content: The user's message text.
        user: Name of the sender, reused for the mirrored message.
        instance: The interface the user typed into.
    """
    other_instance = (
        chat_interface_b if instance is chat_interface_a else chat_interface_a
    )
    other_instance.append(pn.chat.ChatMessage(content, user=user))
    await asyncio.gather(
        *(
            respond(content, user, chat_interface)
            for chat_interface in (chat_interface_a, chat_interface_b)
        )
    )


def _votes_frame(counts):
    """Return ``{model: votes}`` as a DataFrame indexed by ``Model``."""
    return pd.DataFrame(
        list(counts.items()), columns=["Model", "Votes"]
    ).set_index("Model")


def _load_voting_counts():
    """Ensure the votes table exists and return its contents as a dict."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute(
                "CREATE TABLE IF NOT EXISTS voting_counts "
                "(Model TEXT PRIMARY KEY, Votes INTEGER)"
            )
            rows = conn.execute("SELECT Model, Votes FROM voting_counts").fetchall()
    return dict(rows)


def _credit_models(models):
    """Add one vote to each model in ``models`` directly in the database.

    Rows are incremented in place instead of rewriting the whole table, so
    votes cast concurrently from other sessions are not overwritten. The
    UPDATE-then-INSERT form also works on tables that lack a primary key.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # one transaction for the whole vote
            for model in models:
                cursor = conn.execute(
                    "UPDATE voting_counts SET Votes = Votes + 1 WHERE Model = ?",
                    (model,),
                )
                if cursor.rowcount == 0:
                    conn.execute(
                        "INSERT INTO voting_counts (Model, Votes) VALUES (?, 1)",
                        (model,),
                    )


def click_vote(event):
    """Record a vote, reveal both models, and refresh the results table.

    Does nothing unless both interfaces have a hidden model assigned, i.e.
    a round is in progress. "Both not good" reveals the models but credits
    neither of them.

    Args:
        event: Panel click event; ``event.obj.name`` is the voting label.
    """
    label_a, label_b = chat_interface_a.name, chat_interface_b.name
    if label_a not in chat_models or label_b not in chat_models:
        return
    model_a, model_b = chat_models[label_a], chat_models[label_b]

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        winners = [model_a]
    elif voting_label == VOTING_LABELS[3]:
        winners = [model_b]
    elif voting_label == VOTING_LABELS[1]:
        # De-duplicate so an A-vs-A round credits the model only once.
        winners = list(dict.fromkeys((model_a, model_b)))
    else:
        winners = []

    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()  # end the round; next message picks new models

    if winners:
        _credit_models(winners)
    # Re-read so the table also reflects votes from other sessions.
    voting_counts.clear()
    voting_counts.update(_load_voting_counts())
    perspective.object = _votes_frame(voting_counts)


# initialize
chat_models = {}  # interface name ("A"/"B") -> model chosen for this round
voting_counts = _load_voting_counts()

# header
api_key_input = pn.widgets.PasswordInput(
    placeholder="Mistral API Key", stylesheets=[".bk-input { color: black; }"]
)

# main
tabs = pn.Tabs()

# tab 1
chat_interface_kwargs = dict(
    callback=forward_message,
    show_undo=False,
    show_rerun=False,
    show_clear=False,
    show_stop=False,
    show_button_name=False,
)
header_a = pn.pane.Markdown("## Model: A")
chat_interface_a = pn.chat.ChatInterface(
    name="A", header=header_a, **chat_interface_kwargs
)
header_b = pn.pane.Markdown("## Model: B")
chat_interface_b = pn.chat.ChatInterface(
    name="B", header=header_b, **chat_interface_kwargs
)

button_kwargs = dict(sizing_mode="stretch_width")
button_row = pn.Row()
for voting_label in VOTING_LABELS:
    button = pn.widgets.Button(name=voting_label, **button_kwargs)
    button.on_click(click_vote)
    button_row.append(button)

tabs.append(("Chat", pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)))

# tab 2
perspective = pn.pane.Perspective(
    _votes_frame(voting_counts),
    sizing_mode="stretch_both",
    editable=False,
)
tabs.append(("Voting Results", perspective))

# layout
pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
).servable()