# File-viewer header captured together with this source (not part of the program):
# t-montes · backend · cbf5f99 · 7.32 kB
import logging
import os
import random

import gradio as gr
import requests  # HTTP client for the backend's /vote and /get endpoints
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) before BACK_HOST and
# ACCESS_KEY are read below.
load_dotenv()

# Models that can appear in a comparison. The indices stored in
# `current_images` point into this list, so its order must stay in sync with
# `images` (which is derived from it below).
model_names = [
    "dalle_desc_25", "dalle_desc_50", "dalle_desc_100", "dalle_desc_150", "dalle_desc_250",
    "desc_25_threshold_250", "desc_25_threshold_500", "desc_25_threshold_1000",
    "desc_250_threshold_250", "desc_250_threshold_500", "desc_250_threshold_1000",
    "jpeg_scale_2", "jpeg_scale_4", "jpeg_scale_8", "jpeg_scale_16", "jpeg_scale_32",
    "sa30_desc_50", "sa30_desc_100", "sa30_desc_150", "sa30_desc_250",
    "sd30_desc_25",
    "sd35_desc_25", "sd35_desc_50", "sd35_desc_100", "sd35_desc_150", "sd35_desc_250",
]

# One path template per model; the literal "OBJ" placeholder is replaced by a
# concrete object name when a round is drawn.
images = ["3d/" + name + "/OBJ.png" for name in model_names]

# State of the round currently on screen.
# NOTE(review): these are process-wide globals, so concurrent visitors share
# (and overwrite) each other's round; per-session gr.State would isolate them.
current_images = [0, 0]  # indices into model_names for the left/right images
current_obj = None  # object name drawn for the current round; None until the first draw

# Backend connection settings; each is None when the variable is not set.
BACK_HOST = os.environ.get("BACK_HOST")
ACCESS_KEY = os.environ.get("ACCESS_KEY")

# Objects a round can be drawn for; each name is substituted for "OBJ" in
# `images` and also selects the reference image 3d/original/<obj>.png.
objs = [
    "axe", "barrel", "bed", "bottle", "canon",
    "car", "chair", "chair2", "chair3", "chair4",
]
def get_new_images():
    """Draw a new comparison round and record it in the module-level state.

    Picks two distinct models and one object at random. The model indices are
    stored in ``current_images`` and the object in ``current_obj``;
    ``vote_and_randomize`` reads both back when the vote is cast.

    Returns:
        dict with keys ``original`` (reference image path), ``image1`` /
        ``image2`` (left/right model renders), ``label1`` / ``label2``
        (vote-button captions) and ``obj`` (the drawn object name).
    """
    global current_images, current_obj
    # No reseeding here: the module-level RNG is already seeded from OS
    # entropy at import, and calling random.seed() on every round would
    # clobber any deliberate seeding done elsewhere in the process.
    left_idx, right_idx = random.sample(range(len(images)), 2)
    obj = random.choice(objs)
    current_images = [left_idx, right_idx]
    current_obj = obj  # remembered so the upcoming vote is attributed to this object
    return {
        "original": f"3d/original/{obj}.png",
        # Only the two selected templates are materialised.
        "image1": images[left_idx].replace("OBJ", obj),
        "image2": images[right_idx].replace("OBJ", obj),
        "label1": "Left",
        "label2": "Right",
        "obj": obj,
    }
def vote_and_randomize(choice):
    """Submit a vote for the pair currently on screen, then draw a new round.

    Args:
        choice: ``"left"`` credits the left image's model; any other value
            credits the right one.

    Returns:
        A 7-tuple in the order the vote buttons' ``outputs`` expect: status
        message, reference image path, left and right image paths, left and
        right button labels, and freshly fetched leaderboard rows.

    Backend failures while recording the vote are logged and reported to the
    user through the status message instead of raising; a new round is drawn
    either way.
    """
    logger = logging.getLogger(__name__)
    # current_images / current_obj are only read here; get_new_images rebinds them.
    left_idx, right_idx = current_images
    if choice == "left":
        winner_index, loser_index = left_idx, right_idx
    else:
        winner_index, loser_index = right_idx, left_idx
    winner_model = model_names[winner_index]
    loser_model = model_names[loser_index]

    payload = {
        "winner": winner_model,
        "loser": loser_model,
        "object": current_obj,  # object drawn for this round by get_new_images
    }
    headers = {
        "Authorization": f"Bearer {ACCESS_KEY}",
        "Content-Type": "application/json",
    }
    message = "Error recording vote. Please try again."
    try:
        # Bounded timeout: without one, a stalled backend would block this
        # handler (and the user's UI) indefinitely.
        response = requests.post(
            f"{BACK_HOST}/vote", headers=headers, json=payload, timeout=10
        )
        resp_json = response.json()
    except Exception:
        # UI-callback boundary: surface a friendly message, but keep a trace
        # instead of silently discarding the error.
        logger.exception("Recording vote at %s/vote failed", BACK_HOST)
    else:
        if (
            isinstance(resp_json, dict)
            and resp_json.get("message") == "Vote recorded successfully"
        ):
            message = f"Thanks for voting for {winner_model}!"
        else:
            logger.warning(
                "Vote not acknowledged (HTTP %s): %r", response.status_code, resp_json
            )

    new_state = get_new_images()
    updated_leaderboard = get_leaderboard_data()  # refresh the table after voting
    return (
        message,
        new_state["original"],
        new_state["image1"],
        new_state["image2"],
        new_state["label1"],
        new_state["label2"],
        updated_leaderboard,
    )
def start_voting():
    """Reveal the voting UI and populate it with a first random round.

    Returns:
        A 7-tuple for the Start button's ``outputs``: an update hiding the
        Start button, an update showing the voting column, then the reference,
        left and right image paths and the two vote-button labels.
    """
    first_round = get_new_images()
    hide_start = gr.update(visible=False)
    show_voting = gr.update(visible=True)
    shown = tuple(
        first_round[key] for key in ("original", "image1", "image2", "label1", "label2")
    )
    return (hide_start, show_voting) + shown
def get_leaderboard_data():
    """Fetch the leaderboard from the backend and shape it for ``gr.DataFrame``.

    Returns:
        A list of ``[name, elo, ""]`` rows (the empty string fills the
        "Description" column), in the order the API returned them. Returns an
        empty list when the request fails, the server answers with a non-200
        status, or the body is not a JSON object.
    """
    logger = logging.getLogger(__name__)
    headers = {"Authorization": f"Bearer {ACCESS_KEY}"}
    try:
        # Bounded timeout: this function also supplies the leaderboard table's
        # initial value, where a stalled backend would otherwise block indefinitely.
        response = requests.get(f"{BACK_HOST}/get", headers=headers, timeout=10)
        if response.status_code != 200:
            logger.warning("Leaderboard request returned HTTP %s", response.status_code)
            return []
        data = response.json()
    except Exception:
        # Best-effort: an empty table is preferable to a crashed UI, but log why.
        logger.exception("Could not fetch leaderboard from %s/get", BACK_HOST)
        return []
    if not isinstance(data, dict):
        logger.warning("Unexpected leaderboard payload type: %s", type(data).__name__)
        return []
    # NOTE(review): rows keep the API's order; sort by Elo here if the backend
    # does not already return them ranked.
    return [[name, elo, ""] for name, elo in data.items()]
def refresh_leaderboard():
    """Return freshly fetched leaderboard rows.

    Used as the handler for the "Refresh Leaderboard" button and for
    selecting the Leaderboard tab.
    """
    rows = get_leaderboard_data()
    return rows
# ---- Gradio UI -------------------------------------------------------------
# NOTE(review): the pasted source lost its indentation; the nesting below is a
# reconstruction (placing the "Vote Result" textbox inside the hidden voting
# column is an assumption - confirm against the deployed app).
with gr.Blocks(css="""
#main-image {
margin: auto; /* Center the image */
display: block;
}
""") as demo:
    with gr.Tabs() as tabs:
        # Tab 1: pairwise voting.
        with gr.Tab("Voting"):
            gr.Markdown("### Vote for your favorite option!")
            # Start button; its click handler hides it and reveals the voting column.
            with gr.Column(elem_id="start-container"):
                # NOTE(review): Gradio documents `scale` as an integer; the
                # fractional values here and on main_image may trigger a
                # warning - confirm the intended sizing before changing them.
                start_btn = gr.Button("Start!", scale=0.5)
            # Voting interface, hidden until Start is clicked.
            with gr.Column(visible=False) as voting_container:
                # Reference image on its own row; the #main-image CSS rule centres it.
                with gr.Row(equal_height=True):
                    main_image = gr.Image(
                        value=None,
                        label="Original",
                        interactive=False,
                        show_download_button=True,
                        elem_id="main-image",
                        scale=0.25,
                    )
                # The two candidate renders, side by side.
                with gr.Row():
                    left_image = gr.Image(
                        value=None, label="Left Option", interactive=False, show_download_button=False
                    )
                    right_image = gr.Image(
                        value=None, label="Right Option", interactive=False, show_download_button=False
                    )
                # Vote buttons; their captions are filled in by start_voting / vote_and_randomize.
                with gr.Row():
                    vote_1 = gr.Button(value="")
                    vote_2 = gr.Button(value="")
                output = gr.Textbox(label="Vote Result", interactive=False)
        # Tab 2: leaderboard.
        with gr.Tab("Leaderboard") as leaderboard_tab:
            gr.Markdown("### Leaderboard")
            leaderboard_table = gr.DataFrame(
                headers=["Name", "Elo", "Description"],
                # Pass the function itself, not its result: Gradio calls it on
                # every page load, so visitors see current standings and
                # building the UI no longer blocks on a backend request.
                value=get_leaderboard_data,
                interactive=False,
            )
            refresh_btn = gr.Button("Refresh Leaderboard")

    # Start: hide the button, show the voting column, and fill in the first round.
    start_btn.click(
        fn=start_voting,
        outputs=[
            start_btn,
            voting_container,
            main_image,
            left_image,
            right_image,
            vote_1,
            vote_2,
        ],
    )

    # Both vote buttons share one output list; its order must match the tuple
    # returned by vote_and_randomize.
    vote_outputs = [output, main_image, left_image, right_image, vote_1, vote_2, leaderboard_table]
    vote_1.click(fn=lambda: vote_and_randomize("left"), outputs=vote_outputs)
    vote_2.click(fn=lambda: vote_and_randomize("right"), outputs=vote_outputs)

    # Re-fetch the leaderboard on demand and whenever its tab is opened.
    refresh_btn.click(fn=refresh_leaderboard, outputs=leaderboard_table)
    leaderboard_tab.select(fn=refresh_leaderboard, outputs=leaderboard_table)

if __name__ == "__main__":
    demo.launch()