Spaces:
Sleeping
Sleeping
import logging
import os
import random

import gradio as gr
import requests  # added import for API calls
from dotenv import load_dotenv
load_dotenv()

# Model variants compared in the arena. Each name is also the directory under
# "3d/" that holds that variant's renders (see `images` below).
model_names = [
    "dalle_desc_25",
    "dalle_desc_50",
    "dalle_desc_100",
    "dalle_desc_150",
    "dalle_desc_250",
    "desc_25_threshold_250",
    "desc_25_threshold_500",
    "desc_25_threshold_1000",
    "desc_250_threshold_250",
    "desc_250_threshold_500",
    "desc_250_threshold_1000",
    "jpeg_scale_2",
    "jpeg_scale_4",
    "jpeg_scale_8",
    "jpeg_scale_16",
    "jpeg_scale_32",
    "sa30_desc_50",
    "sa30_desc_100",
    "sa30_desc_150",
    "sa30_desc_250",
    "sd30_desc_25",
    "sd35_desc_25",
    "sd35_desc_50",
    "sd35_desc_100",
    "sd35_desc_150",
    "sd35_desc_250",
]

# Per-model image path template; "OBJ" is substituted with an object name when
# a round is drawn. images[i] belongs to model_names[i].
images = [f"3d/{model}/OBJ.png" for model in model_names]

# Indices into model_names of the pair currently displayed (left, right).
# NOTE(review): these are process-wide globals rather than per-session state,
# so with several users connected one user's new round can overwrite the pair
# another user is about to vote on. Consider gr.State for per-session rounds.
current_images = [0, 0]
current_obj = None  # object name used in the current voting round

# Backend API base URL and bearer token, read from the environment / .env file.
BACK_HOST = os.getenv("BACK_HOST")
ACCESS_KEY = os.getenv("ACCESS_KEY")
if BACK_HOST:
    # Callers build f"{BACK_HOST}/vote"; strip a trailing slash to avoid "//vote".
    BACK_HOST = BACK_HOST.rstrip("/")
else:
    logging.warning("BACK_HOST is not set; voting and leaderboard requests will fail.")
if not ACCESS_KEY:
    logging.warning("ACCESS_KEY is not set; backend requests will not carry a valid token.")

# Objects with renders available; each has a reference image at 3d/original/<obj>.png.
objs = ['axe', 'barrel', 'bed', 'bottle', 'canon', 'car', 'chair', 'chair2', 'chair3', 'chair4']
def get_new_images():
    """Draw a random pair of distinct models and a random object for a new round.

    Side effects:
        Sets the module-level ``current_images`` to ``[left_idx, right_idx]``
        (indices into ``model_names``) and ``current_obj`` to the chosen object,
        so the next vote can be attributed to this pair.

    Returns:
        dict with keys ``original`` (reference render path), ``image1`` and
        ``image2`` (left/right candidate render paths), ``label1`` and
        ``label2`` (vote-button captions) and ``obj`` (the chosen object name).
    """
    global current_images, current_obj
    # No random.seed() here: the module-level generator is already seeded from
    # OS entropy at import, and re-seeding on every call would discard any seed
    # set elsewhere (e.g. for reproducible runs).
    idx1, idx2 = random.sample(range(len(images)), 2)
    current_images = [idx1, idx2]
    obj = random.choice(objs)
    current_obj = obj  # remembered so vote_and_randomize knows the object
    return {
        "original": f"3d/original/{obj}.png",
        # Only the two displayed templates need the object substituted in.
        "image1": images[idx1].replace("OBJ", obj),
        "image2": images[idx2].replace("OBJ", obj),
        "label1": "Left",
        "label2": "Right",
        "obj": obj,
    }
def _record_vote(winner_model, loser_model, obj):
    """POST one vote to the backend and return a user-facing status message.

    Network errors, a non-JSON body, or an unexpected reply all produce the
    generic error message (and a log entry) instead of raising, so the UI
    keeps working when the backend misbehaves.
    """
    url = f"{BACK_HOST}/vote"
    headers = {
        "Authorization": f"Bearer {ACCESS_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "winner": winner_model,
        "loser": loser_model,
        "object": obj,
    }
    try:
        # Without a timeout a stalled backend would block this handler forever.
        response = requests.post(url, headers=headers, json=payload, timeout=10)
        resp_json = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a response body that is not valid JSON.
        logging.exception("Failed to record vote at %s", url)
        return "Error recording vote. Please try again."
    if isinstance(resp_json, dict) and resp_json.get("message") == "Vote recorded successfully":
        return f"Thanks for voting for {winner_model}!"
    logging.warning(
        "Unexpected vote response (HTTP %s): %r", response.status_code, resp_json
    )
    return "Error recording vote. Please try again."


def vote_and_randomize(choice):
    """Record a vote for the currently displayed pair, then start a new round.

    Args:
        choice: ``"left"`` votes for the left image; any other value is
            treated as a vote for the right image.

    Returns:
        A 7-tuple matching the vote buttons' Gradio outputs: status message,
        original image path, left and right image paths, left and right button
        labels, and refreshed leaderboard rows.
    """
    # NOTE(review): current_images/current_obj are shared module globals; with
    # concurrent users the pair recorded here may not be the one this user saw.
    left_idx, right_idx = current_images
    if choice == "left":
        winner_index, loser_index = left_idx, right_idx
    else:
        winner_index, loser_index = right_idx, left_idx
    winner_model = model_names[winner_index]
    loser_model = model_names[loser_index]

    message = _record_vote(winner_model, loser_model, current_obj)

    new_state = get_new_images()
    updated_leaderboard = get_leaderboard_data()  # show the effect of this vote
    return (
        message,
        new_state["original"],
        new_state["image1"],
        new_state["image2"],
        new_state["label1"],
        new_state["label2"],
        updated_leaderboard,
    )
def start_voting():
    """Reveal the voting interface and load the first comparison.

    Returns a 7-tuple for the start button's outputs, in this order:
    hide the start button, show the voting container, then the original
    image, the left and right images, and the two vote-button labels.
    """
    first_round = get_new_images()
    round_fields = ("original", "image1", "image2", "label1", "label2")
    hide_start = gr.update(visible=False)
    show_voting = gr.update(visible=True)
    return (hide_start, show_voting, *(first_round[key] for key in round_fields))
def get_leaderboard_data():
    """Fetch the leaderboard from the backend as rows for the DataFrame.

    Expects ``GET {BACK_HOST}/get`` to return a JSON object mapping names to
    Elo ratings.

    Returns:
        A list of ``[name, elo, ""]`` rows; the empty third column fills the
        table's "Description" header. Returns ``[]`` (and logs why) if the
        request fails, the status is not 200, or the body is not a JSON object.
    """
    url = f"{BACK_HOST}/get"
    headers = {"Authorization": f"Bearer {ACCESS_KEY}"}
    try:
        # This also populates the UI table, so a stalled backend must not
        # block indefinitely.
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            logging.warning(
                "Leaderboard request to %s returned HTTP %s", url, response.status_code
            )
            return []
        data = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a response body that is not valid JSON.
        logging.exception("Failed to fetch leaderboard from %s", url)
        return []
    if not isinstance(data, dict):
        logging.warning("Unexpected leaderboard payload: %r", data)
        return []
    return [[name, elo, ""] for name, elo in data.items()]
def refresh_leaderboard():
    """Event-handler wrapper: return freshly fetched leaderboard rows."""
    rows = get_leaderboard_data()
    return rows
with gr.Blocks(css="""
#main-image {
    margin: auto; /* Center the image */
    display: block;
}
""") as demo:
    with gr.Tabs() as tabs:
        # Tab 1: pairwise voting
        with gr.Tab("Voting"):
            gr.Markdown("### Vote for your favorite option!")
            # Start button; start_voting hides it once the first round loads.
            with gr.Column(elem_id="start-container"):
                # NOTE(review): Gradio documents `scale` as an integer and may
                # warn about fractional values; confirm the intended layout.
                start_btn = gr.Button("Start!", scale=0.5)
            # Voting interface (hidden until Start is clicked)
            with gr.Column(visible=False) as voting_container:
                # Reference render of the object, centered via #main-image CSS
                with gr.Row(equal_height=True):
                    main_image = gr.Image(
                        value=None,
                        label="Original",
                        interactive=False,
                        show_download_button=True,
                        elem_id="main-image",
                        scale=0.25,
                    )
                # The two candidate renders being compared
                with gr.Row():
                    left_image = gr.Image(value=None, label="Left Option", interactive=False, show_download_button=False)
                    right_image = gr.Image(value=None, label="Right Option", interactive=False, show_download_button=False)
                # Vote buttons; captions are filled in by the round handlers
                with gr.Row():
                    vote_1 = gr.Button(value="")
                    vote_2 = gr.Button(value="")
                output = gr.Textbox(label="Vote Result", interactive=False)
        # Tab 2: leaderboard
        with gr.Tab("Leaderboard") as leaderboard_tab:
            gr.Markdown("### Leaderboard")
            leaderboard_table = gr.DataFrame(
                headers=["Name", "Elo", "Description"],
                # Pass the function itself (not its result) so Gradio fetches
                # fresh rows on each page load instead of once at import time.
                value=get_leaderboard_data,
                interactive=False,
            )
            refresh_btn = gr.Button("Refresh Leaderboard")

    # Start: hide the start button, show the voting UI, load the first pair.
    start_btn.click(
        fn=start_voting,
        outputs=[
            start_btn,
            voting_container,
            main_image,
            left_image,
            right_image,
            vote_1,
            vote_2,
        ],
    )
    # Each vote records the choice, loads a new pair and refreshes the leaderboard.
    vote_1.click(
        fn=lambda: vote_and_randomize("left"),
        outputs=[output, main_image, left_image, right_image, vote_1, vote_2, leaderboard_table],
    )
    vote_2.click(
        fn=lambda: vote_and_randomize("right"),
        outputs=[output, main_image, left_image, right_image, vote_1, vote_2, leaderboard_table],
    )
    # Manual refresh of the leaderboard table
    refresh_btn.click(
        fn=refresh_leaderboard,
        outputs=leaderboard_table,
    )
    # Also refresh whenever the leaderboard tab is selected
    leaderboard_tab.select(
        fn=refresh_leaderboard,
        outputs=leaderboard_table,
    )
def main():
    """Start the Gradio server for the voting app."""
    demo.launch()


if __name__ == "__main__":
    main()