# Spaces status: Runtime error
"""Streamlit front end for Aidan Bench.

The page collects an OpenRouter and an OpenAI API key and lists the models
published by the OpenRouter API. The user then picks a contestant model, a
judge model, thresholds and questions. The app runs either
``benchmark_model_sequential`` or ``benchmark_model_multithreaded`` from
``main`` and shows the results as a table and a CSV download.
"""

import pandas as pd
import requests
import streamlit as st

from main import benchmark_model_multithreaded, benchmark_model_sequential
from prompts import questions as predefined_questions

OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"
# Without a timeout, requests.get can block indefinitely and freeze the page.
REQUEST_TIMEOUT_SECONDS = 30
TOKENS_PER_MILLION = 1_000_000


def _fetch_models():
    """Return the model list from the OpenRouter API, sorted by model ID.

    Raises:
        requests.exceptions.RequestException: on network errors, timeouts or
            a non-2xx response.
        ValueError: if the response body is not valid JSON.
        KeyError, TypeError: if the body lacks a ``"data"`` list of objects
            that each carry an ``"id"``.
    """
    response = requests.get(OPENROUTER_MODELS_URL, timeout=REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()  # Raise an exception for bad status codes
    models = response.json()["data"]
    return sorted(models, key=lambda model: model["id"])


def _price_per_million(pricing, field):
    """Scale the per-token price ``pricing[field]`` to a per-million-token price.

    A missing field counts as 0. Returns None when the value is present but
    cannot be parsed as a number (e.g. it is None or a non-numeric string).
    """
    try:
        return float(pricing.get(field, 0)) * TOKENS_PER_MILLION
    except (TypeError, ValueError):
        return None


def _show_pricing(model):
    """Write the prompt/completion pricing of ``model`` (an API model dict) to the page.

    Writes "N/A" when the model is unknown or its prices cannot be parsed.
    """
    if not model:
        st.write("**Pricing:** N/A")
        return
    # `or {}` also covers an explicit null "pricing" value.
    pricing = model.get("pricing") or {}
    prompt_price = _price_per_million(pricing, "prompt")
    completion_price = _price_per_million(pricing, "completion")
    if prompt_price is None or completion_price is None:
        st.write("**Pricing:** N/A")
        return
    st.write(f"**Prompt Pricing:** ${prompt_price:.2f}/Million tokens (if applicable)")
    st.write(f"**Completion Pricing:** ${completion_price:.2f}/Million tokens")


def _build_results_table(results, contestant_model, judge_model):
    """Flatten benchmark results into one table row per generated answer.

    Each entry of ``results`` is read via its "question", "answers",
    "coherence_score" and "novelty_score" keys.
    """
    rows = []
    for result in results:
        for answer in result["answers"]:
            rows.append({
                "Question": result["question"],
                "Answer": answer,
                "Contestant Model": contestant_model,
                "Judge Model": judge_model,
                "Coherence Score": result["coherence_score"],
                "Novelty Score": result["novelty_score"],
            })
    return rows


# Set the title in the browser tab (must be the first Streamlit command).
st.set_page_config(page_title="Aidan Bench - Generator")
st.title("Aidan Bench - Generator")

# API Key Inputs with Security and User Experience Enhancements
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
if "open_router_key" not in st.session_state:
    st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
    st.session_state.openai_api_key = ""

open_router_key = st.text_input("Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key)
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key)

if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")

# Everything below needs confirmed keys; stop this run early otherwise.
if not (st.session_state.open_router_key and st.session_state.openai_api_key):
    st.warning("Please confirm your API keys first.")
    st.stop()

# Fetch models from OpenRouter API
try:
    all_models = _fetch_models()
except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as exc:
    st.error(f"Error fetching models from OpenRouter API: {exc}")
    all_models = []  # Falls through to the "No models available" error below

# --- Dictionaries for easy model lookup (insertion order = sorted by ID) ---
models_by_id = {model["id"]: model for model in all_models}
model_names = list(models_by_id)
judge_models = sorted(model_id for model_id in model_names if "gpt" in model_id)

# Model Selection
if not model_names:
    st.error("No models available. Please check your API connection.")
    st.stop()
model_name = st.selectbox("Select a Contestant Model", model_names)
_show_pricing(models_by_id.get(model_name))

# Judge Model Selection
if not judge_models:
    st.error("No judge models available. Please check your API connection.")
    st.stop()
judge_model_name = st.selectbox("Select a Judge Model", judge_models)
_show_pricing(models_by_id.get(judge_model_name))

# User-entered questions must survive Streamlit reruns, so keep them in session state.
if "user_questions" not in st.session_state:
    st.session_state.user_questions = []

# Threshold Sliders
st.sidebar.subheader("Threshold Sliders")
coherence_threshold = st.sidebar.slider("Coherence Threshold (0-5):", 0, 5, 3)
novelty_threshold = st.sidebar.slider("Novelty Threshold (0-1):", 0.0, 1.0, 0.1)
st.sidebar.subheader("Temp Sliders")
temp_threshold = st.sidebar.slider("Temperature (0-2):", 0.0, 2.0, 1.0)
top_p = st.sidebar.slider("Top P (0-1):", 0.0, 1.0, 1.0)

# Workflow Selection
workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

if workflow == "Use Predefined Questions":
    st.header("Question Selection")
    selected_questions = st.multiselect(
        "Select questions to benchmark:",
        predefined_questions,
        predefined_questions,  # Select all by default
    )
else:  # "Use User-Defined Questions"
    st.header("Question Input")
    new_question = st.text_input("Enter a new question:")
    if st.button("Add Question"):
        cleaned_question = new_question.strip()  # Remove leading/trailing whitespace
        if cleaned_question and cleaned_question not in st.session_state.user_questions:
            st.session_state.user_questions.append(cleaned_question)
            st.success(f"Question '{cleaned_question}' added successfully.")
        else:
            st.warning("Question already exists or is empty!")
    # Display multiselect with updated user questions
    selected_questions = st.multiselect(
        "Select your custom questions:",
        options=st.session_state.user_questions,
        default=st.session_state.user_questions,
    )

st.write("Selected Questions:", selected_questions)

# Choose execution mode; only the multithreaded mode uses a thread-pool size.
execution_mode = st.radio("Execution Mode:", ["Sequential", "Multithreaded"])
if execution_mode == "Multithreaded":
    max_threads = st.slider("Maximum Number of Threads:", 1, 10, 4)  # Default to 4 threads
else:
    max_threads = None

# Benchmark Execution
if st.button("Start Benchmark"):
    if not selected_questions:
        st.warning("Please select at least one question.")
    else:
        # A Streamlit button is only True on the rerun its own click triggers,
        # so this value is always False here and cannot be polled mid-run.
        # Clicking it just starts a new rerun.
        # NOTE(review): real cancellation would need support inside main's benchmark functions.
        st.button("Stop Benchmark")

        if execution_mode == "Sequential":
            results = benchmark_model_sequential(
                model_name,
                selected_questions,
                st.session_state.open_router_key,
                st.session_state.openai_api_key,
                judge_model_name,
                coherence_threshold,
                novelty_threshold,
                temp_threshold,
                top_p,
            )
        else:  # Multithreaded
            results = benchmark_model_multithreaded(
                model_name,
                selected_questions,
                st.session_state.open_router_key,
                st.session_state.openai_api_key,
                max_threads,
                judge_model_name,
                coherence_threshold,
                novelty_threshold,
                temp_threshold,
                top_p,
            )

        # Display results in a table and offer them as a CSV download.
        st.write("Results:")
        results_table = _build_results_table(results, model_name, judge_model_name)
        st.table(results_table)
        csv = pd.DataFrame(results_table).to_csv(index=False).encode("utf-8")
        st.download_button(
            label="Export Results as CSV",
            data=csv,
            file_name="benchmark_results.csv",
            mime="text/csv",
        )
        st.success("Benchmark completed!")