Spaces:
Running
Running
import gradio as gr | |
from functools import lru_cache | |
import random | |
import requests | |
import logging | |
import arena_config | |
import plotly.graph_objects as go | |
from typing import Dict | |
from leaderboard import ( | |
get_current_leaderboard, | |
update_leaderboard, | |
start_backup_thread, | |
get_leaderboard, | |
get_elo_leaderboard, | |
ensure_elo_ratings_initialized | |
) | |
import sys | |
# Module setup: only ERROR-level (and higher) log records are emitted, and the
# leaderboard module's backup thread is started as soon as this file is imported.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)

start_backup_thread()
def get_available_models():
    """Return the model identifiers listed in ``arena_config.APPROVED_MODELS``.

    The first element (``[0]``) of each configured entry is taken as the
    identifier, and the configuration order is kept.
    """
    approved = arena_config.APPROVED_MODELS
    return [entry[0] for entry in approved]
def call_ollama_api(model, prompt):
    """Send ``prompt`` as a single user message to ``model`` and return the reply.

    The request is POSTed to ``{arena_config.API_URL}/v1/chat/completions``
    with ``arena_config.HEADERS`` and a 100 second timeout. Nothing is cached:
    every call performs a fresh HTTP request.

    Args:
        model: Model identifier placed in the request payload.
        prompt: Text sent as the only user message.

    Returns:
        ``choices[0].message.content`` from the JSON response. On any request
        failure, or when the response body does not have that structure, the
        string ``"Error: Unable to get response from the model."`` is
        returned instead (the error is logged).
    """
    error_message = "Error: Unable to get response from the model."
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        response = requests.post(
            f"{arena_config.API_URL}/v1/chat/completions",
            headers=arena_config.HEADERS,
            json=payload,
            timeout=100,
        )
        response.raise_for_status()
        # ValueError covers non-JSON bodies on requests versions where
        # .json() does not raise a RequestException subclass.
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        logger.error("Error calling Ollama API for model %s: %s", model, e)
        return error_message
    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        # A successful status with an unexpected body (error object, empty
        # "choices", ...) must not crash the battle; treat it as a failure.
        logger.error("Unexpected response format from model %s: %r", model, e)
        return error_message
def generate_responses(prompt):
    """Ask two randomly sampled models the same prompt.

    Returns:
        ``(response_a, response_b, model_a, model_b)``. If fewer than two
        models are available, both responses are an error message and both
        model slots are ``None``.
    """
    candidates = get_available_models()
    if len(candidates) >= 2:
        model_a, model_b = random.sample(candidates, 2)
        return (
            call_ollama_api(model_a, prompt),
            call_ollama_api(model_b, prompt),
            model_a,
            model_b,
        )
    shortage = "Error: Not enough models available"
    return shortage, shortage, None, None
def battle_arena(prompt):
    """Run one anonymous battle for ``prompt`` and build the UI updates.

    Two models answer the prompt via ``generate_responses``. The pair is
    randomly assigned to the left/right panels, and each panel is labeled with
    a nickname from ``arena_config.model_nicknames`` instead of the model name.

    Returns:
        An 11-tuple matching the ``submit_btn.click`` outputs: left chat,
        right chat, left model id, right model id, left/right Chatbot updates
        (label + history), left/right vote-button updates, the tie-button
        update, the prompt (stored as the previous prompt) and a tie count of 0.
    """
    response_a, response_b, model_a, model_b = generate_responses(prompt)

    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        # Independent random.choice calls could pick the same nickname, which
        # would label both panels and both vote buttons identically.
        nickname_left, nickname_right = random.sample(nicknames, 2)
    else:
        nickname_left = nickname_right = random.choice(nicknames)

    # Format responses for gr.Chatbot (messages format), including the prompt.
    chat_a = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_a},
    ]
    chat_b = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_b},
    ]

    # Randomize which model appears on the left.
    if random.choice([True, False]):
        left_chat, right_chat, left_model, right_model = chat_a, chat_b, model_a, model_b
    else:
        left_chat, right_chat, left_model, right_model = chat_b, chat_a, model_b, model_a

    return (
        left_chat, right_chat, left_model, right_model,
        gr.update(label=nickname_left, value=left_chat),
        gr.update(label=nickname_right, value=right_chat),
        gr.update(interactive=True, value=f"Vote for {nickname_left}"),
        gr.update(interactive=True, value=f"Vote for {nickname_right}"),
        gr.update(interactive=True, visible=True),  # Enable and show Tie button
        prompt,  # Return the original prompt
        0,  # Initialize tie count
    )
def record_vote(prompt, left_response, right_response, left_model, right_model, choice):
    """Record a vote, refresh the leaderboards and reveal the model names.

    Args:
        prompt: Current prompt text (unused; passed by the click handlers).
        left_response, right_response: Chat histories shown in each panel.
        left_model, right_model: Model identifiers behind each panel.
        choice: ``"Left is better"`` makes the left model the winner; any
            other value makes the right model the winner.

    Returns:
        An 8-tuple matching the vote-button outputs: result message, win/loss
        leaderboard, ELO leaderboard, left/right/tie button updates, the
        model-names row update and the performance chart.
    """
    # Voting only makes sense once a battle has produced two responses.
    if not left_response or not right_response or not left_model or not right_model:
        # Must return one value per wired output (8). The result box is hidden
        # by default, so make it visible for the message to be seen.
        return (
            gr.update(value="Please generate responses before voting.", visible=True),
            gr.update(),  # Leaderboard unchanged
            gr.update(),  # ELO leaderboard unchanged
            gr.update(interactive=False),  # Disable left vote button
            gr.update(interactive=False),  # Disable right vote button
            gr.update(visible=False),  # Hide tie button
            gr.update(),  # Model-names row unchanged
            gr.update(),  # Chart unchanged
        )

    if choice == "Left is better":
        winner, loser = left_model, right_model
    else:
        winner, loser = right_model, left_model

    # Update the leaderboard
    update_leaderboard(winner, loser)

    # NOTE(review): the emoji in this message appear to be mojibake
    # (mis-decoded UTF-8); restore the intended characters from the source.
    result_message = f"""
π Vote recorded! You're awesome! π
π΅ In the left corner: {get_human_readable_name(left_model)}
π΄ In the right corner: {get_human_readable_name(right_model)}
π And the champion you picked is... {get_human_readable_name(winner)}! π₯
"""
    return (
        gr.update(value=result_message, visible=True),  # Show result in the Result textbox
        get_leaderboard(),  # Update leaderboard
        get_elo_leaderboard(),  # Update ELO leaderboard
        gr.update(interactive=False),  # Disable left vote button
        gr.update(interactive=False),  # Disable right vote button
        gr.update(interactive=False),  # Disable tie button
        gr.update(visible=True),  # Show model names
        get_leaderboard_chart(),  # Update leaderboard chart
    )
def get_leaderboard_chart():
    """Build the "Model Performance" Plotly figure from the current leaderboard.

    Each model's score is ``win_rate * (1 - 1 / (battles + 1))``, which
    discounts the win rate of models with few battles. It is 0 when a model
    has no battles. Models are ranked by score, with ties broken by the total
    number of battles. Wins and losses are drawn as stacked bars, and the
    score as a line on a secondary y-axis.

    Returns:
        plotly.graph_objects.Figure
    """
    battle_results = get_current_leaderboard()

    # Keep scores in a local mapping rather than writing a "score" key into
    # the dict returned by get_current_leaderboard().
    # NOTE(review): confirm whether that function returns shared state or a copy.
    scores_by_model = {}
    for model, results in battle_results.items():
        total_battles = results["wins"] + results["losses"]
        if total_battles > 0:
            win_rate = results["wins"] / total_battles
            scores_by_model[model] = win_rate * (1 - 1 / (total_battles + 1))
        else:
            scores_by_model[model] = 0

    ranked = sorted(
        battle_results,
        key=lambda m: (
            scores_by_model[m],
            battle_results[m]["wins"] + battle_results[m]["losses"],
        ),
        reverse=True,
    )
    models = [get_human_readable_name(m) for m in ranked]
    wins = [battle_results[m]["wins"] for m in ranked]
    losses = [battle_results[m]["losses"] for m in ranked]
    scores = [scores_by_model[m] for m in ranked]

    fig = go.Figure()

    # Stacked bar chart for wins and losses
    fig.add_trace(go.Bar(
        x=models,
        y=wins,
        name='Wins',
        marker_color='#22577a'
    ))
    fig.add_trace(go.Bar(
        x=models,
        y=losses,
        name='Losses',
        marker_color='#38a3a5'
    ))

    # Line chart for scores on the secondary y-axis
    fig.add_trace(go.Scatter(
        x=models,
        y=scores,
        name='Score',
        yaxis='y2',
        line=dict(color='#ff7f0e', width=2)
    ))

    fig.update_layout(
        title='Model Performance',
        xaxis_title='Models',
        yaxis_title='Number of Battles',
        yaxis2=dict(
            title='Score',
            overlaying='y',
            side='right'
        ),
        barmode='stack',
        height=800,
        width=1450,
        autosize=True,
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='right',
            x=1
        )
    )

    # Serializing the whole figure just to report its size is expensive, so
    # only do it when debug logging is enabled, not via print on every call.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Chart size: %d bytes", sys.getsizeof(fig.to_json()))
    return fig
def new_battle():
    """Reset the battle UI and pick new panel nicknames.

    Returns:
        A 13-tuple matching the ``new_battle_btn.click`` outputs: cleared
        prompt, emptied left/right chats (relabeled with new nicknames),
        cleared model ids, disabled vote buttons, hidden tie button, cleared
        and hidden result, unchanged leaderboard, hidden model-names row,
        unchanged chart and a tie count of 0.
    """
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        # Distinct nicknames so the two panels and vote buttons differ.
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = nickname_b = random.choice(nicknames)
    return (
        "",  # Reset prompt_input
        gr.update(value=[], label=nickname_a),  # Reset left Chatbot
        gr.update(value=[], label=nickname_b),  # Reset right Chatbot
        None,  # Clear left model id
        None,  # Clear right model id
        gr.update(interactive=False, value=f"Vote for {nickname_a}"),
        gr.update(interactive=False, value=f"Vote for {nickname_b}"),
        gr.update(interactive=False, visible=False),  # Reset Tie button
        gr.update(value="", visible=False),  # Clear and hide result
        gr.update(),  # Leaderboard unchanged
        gr.update(visible=False),  # Hide model-names row
        gr.update(),  # Chart unchanged
        0,  # Reset tie_count
    )
def get_human_readable_name(model_name: str) -> str:
    """Return the display name configured for ``model_name``.

    ``arena_config.APPROVED_MODELS`` is turned into a dict, so its entries are
    (identifier, name) pairs, and the last pair wins for a repeated
    identifier. Identifiers that are not configured are returned unchanged.
    """
    display_names = dict(arena_config.APPROVED_MODELS)
    if model_name in display_names:
        return display_names[model_name]
    return model_name
def random_prompt():
    """Return one prompt chosen at random from ``arena_config.example_prompts``."""
    prompts = arena_config.example_prompts
    return random.choice(prompts)
def continue_conversation(prompt, left_chat, right_chat, left_model, right_model, previous_prompt, tie_count):
    """Handle a tie by sending a follow-up prompt to both models.

    If ``prompt`` is empty or equal to ``previous_prompt``, a random example
    prompt is used instead. Both chat histories are extended with the prompt
    and each model's reply. The tie count is incremented, and once it reaches
    3 the tie button is disabled and relabeled to ask for a vote.

    Returns:
        A 6-tuple matching the ``tie_btn.click`` outputs: left chat update,
        right chat update, cleared prompt box, tie-button update, the prompt
        actually used (stored as the previous prompt) and the new tie count.
    """
    max_ties = 3

    # Check if the prompt is empty or the same as the previous one
    if not prompt or prompt == previous_prompt:
        prompt = random.choice(arena_config.example_prompts)

    left_response = call_ollama_api(left_model, prompt)
    right_response = call_ollama_api(right_model, prompt)

    # Build new history lists instead of appending to the incoming arguments,
    # and tolerate a missing (None) history.
    left_chat = list(left_chat or []) + [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": left_response},
    ]
    right_chat = list(right_chat or []) + [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": right_response},
    ]

    tie_count += 1
    if tie_count < max_ties:
        tie_button_state = gr.update(interactive=True)
    else:
        tie_button_state = gr.update(interactive=False, value="Max ties reached. Please vote!")

    return (
        gr.update(value=left_chat),
        gr.update(value=right_chat),
        gr.update(value=""),  # Clear the prompt input
        tie_button_state,
        prompt,  # Return the new prompt
        tie_count,
    )
# Build the Gradio interface (battle tab, leaderboard tabs, chart tab) and
# wire component events to the handler functions defined above.
with gr.Blocks(css="""
#dice-button {
min-height: 90px;
font-size: 35px;
}
""") as demo:
    gr.Markdown(arena_config.ARENA_NAME)
    gr.Markdown(arena_config.ARENA_DESCRIPTION)

    # Battle Arena Tab
    with gr.Tab("Battle Arena"):
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                scale=20
            )
            # NOTE(review): this button label and the Tie/model-name labels
            # below appear to be mojibake (mis-decoded UTF-8 emoji); restore
            # the intended characters.
            random_prompt_btn = gr.Button("π²", scale=1, elem_id="dice-button")
        gr.Markdown("<br>")

        # Fill the prompt box with a random example prompt.
        random_prompt_btn.click(
            random_prompt,
            outputs=prompt_input
        )

        submit_btn = gr.Button("Generate Responses", variant="primary")

        # Pick two distinct starting nicknames. Independent random.choice
        # calls could label both panels (and the vote buttons derived from
        # their labels) identically.
        if len(arena_config.model_nicknames) >= 2:
            initial_left_nickname, initial_right_nickname = random.sample(
                arena_config.model_nicknames, 2
            )
        else:
            initial_left_nickname = initial_right_nickname = random.choice(
                arena_config.model_nicknames
            )

        with gr.Row():
            left_output = gr.Chatbot(label=initial_left_nickname, type="messages")
            right_output = gr.Chatbot(label=initial_right_nickname, type="messages")

        with gr.Row():
            left_vote_btn = gr.Button(f"Vote for {left_output.label}", interactive=False)
            tie_btn = gr.Button("Tie π Continue with a new prompt", interactive=False, visible=False)
            right_vote_btn = gr.Button(f"Vote for {right_output.label}", interactive=False)

        result = gr.Textbox(label="Result", interactive=False, visible=False)

        # Real model ids; the row stays hidden until a vote is recorded.
        with gr.Row(visible=False) as model_names_row:
            left_model = gr.Textbox(label="π΅ Left Model", interactive=False)
            right_model = gr.Textbox(label="π΄ Right Model", interactive=False)

        previous_prompt = gr.State("")  # Last prompt sent to both models
        tie_count = gr.State(0)  # Ties so far in the current battle
        new_battle_btn = gr.Button("New Battle")

    # Leaderboard Tab
    with gr.Tab("Leaderboard"):
        leaderboard = gr.HTML(label="Leaderboard")

    # Performance Chart Tab
    with gr.Tab("Performance Chart"):
        leaderboard_chart = gr.Plot(label="Model Performance Chart")

    # ELO Leaderboard Tab
    with gr.Tab("ELO Leaderboard"):
        elo_leaderboard = gr.HTML(label="ELO Leaderboard")

    # Define interactions
    submit_btn.click(
        battle_arena,
        inputs=prompt_input,
        outputs=[left_output, right_output, left_model, right_model,
                 left_output, right_output, left_vote_btn, right_vote_btn,
                 tie_btn, previous_prompt, tie_count]
    )

    left_vote_btn.click(
        lambda *args: record_vote(*args, "Left is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )

    right_vote_btn.click(
        lambda *args: record_vote(*args, "Right is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )

    tie_btn.click(
        continue_conversation,
        inputs=[prompt_input, left_output, right_output, left_model, right_model, previous_prompt, tie_count],
        outputs=[left_output, right_output, prompt_input, tie_btn, previous_prompt, tie_count]
    )

    new_battle_btn.click(
        new_battle,
        outputs=[prompt_input, left_output, right_output, left_model,
                 right_model, left_vote_btn, right_vote_btn, tie_btn,
                 result, leaderboard, model_names_row, leaderboard_chart, tie_count]
    )

    # Populate both leaderboards and the chart whenever the page is loaded.
    demo.load(get_leaderboard, outputs=leaderboard)
    demo.load(get_elo_leaderboard, outputs=elo_leaderboard)
    demo.load(get_leaderboard_chart, outputs=leaderboard_chart)
def main():
    """Initialize ELO ratings, then launch the Gradio app with ``show_api=False``."""
    # Ratings are initialized before the app is launched.
    ensure_elo_ratings_initialized()
    demo.launch(show_api=False)


if __name__ == "__main__":
    main()