philippds's picture
Update app.py
db08613 verified
raw
history blame
15.5 kB
import os
import json
import requests
import numpy as np
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import make_clickable_model, make_clickable_user
from typing import List # Add this import statement
# Hub dataset repository that stores one leaderboard CSV per environment.
DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data"
DATASET_REPO_URL = f"https://huggingface.co/datasets/{DATASET_REPO_ID}"
# Token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Shared Hub client, used e.g. for uploading the dataset and restarting the Space.
# (The former unused `block = gr.Blocks()` was removed: `block` is created with the
# custom CSS right before the UI is built.)
api = HfApi(token=HF_TOKEN)
# Extra CSS handed to gr.Blocks(css=custom_css) when the UI is built.
# Disabled rule, kept for reference, that set the tab-button labels to 20px:
# .tab-buttons button {
# font-size: 20px;
# }
custom_css = """
/* Full width space */
.gradio-container {
max-width: 95%!important;
}
.gr-dataframe table {
width: auto;
}
.gr-dataframe td, .gr-dataframe th {
white-space: nowrap;
text-overflow: ellipsis;
overflow: hidden;
width: 1%;
}
"""
# Display labels for the integer "pattern-id" read from model-card metadata:
# 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two Rows, 7 Field, 8 Random
pattern_map = {
    pattern_id: f"{pattern_id}: {pattern_name}"
    for pattern_id, pattern_name in enumerate(
        [
            "Default",
            "Grid",
            "Chain",
            "Circle",
            "Square",
            "Cross",
            "Two Rows",
            "Field",
            "Random",
        ]
    )
}
# One entry per leaderboard tab: display title, the id used both as the
# list_models filter and as the CSV file stem, and the number of task sub-tabs.
hivex_envs = [
    {"title": title, "hivex_env": env_id, "task_count": task_count}
    for title, env_id, task_count in (
        ("Wind Farm Control", "hivex-wind-farm-control", 2),
        ("Wildfire Resource Management", "hivex-wildfire-resource-management", 3),
        ("Drone-Based Reforestation", "hivex-drone-based-reforestation", 7),
        ("Ocean Plastic Collection", "hivex-ocean-plastic-collection", 4),
        ("Aerial Wildfire Suppression", "hivex-aerial-wildfire-suppression", 9),
    )
]
# Hub users whose models get the "Verified" mark on the leaderboard.
verified_users = ["hivex-research"]
# NOTE(review): not referenced anywhere in the visible code.
verified_models = [{"user": "hivex-research", "model": "my_model"}]
def restart():
    """Print a marker line, then ask the Hub to restart the hivex-leaderboard Space."""
    space_id = "hivex-research/hivex-leaderboard"
    print("RESTART")
    api.restart_space(repo_id=space_id)
def download_leaderboard_dataset():
    """Fetch a local snapshot of the leaderboard dataset repository.

    :return: the local directory path returned by ``snapshot_download``.
    """
    return snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
def get_total_models():
    """Count the Hub models listed for every environment in ``hivex_envs``."""
    return sum(len(get_model_ids(env["hivex_env"])) for env in hivex_envs)
def get_model_ids(hivex_env):
    """Return the repo ids of all Hub models matched by the ``hivex_env`` filter.

    :param hivex_env: filter passed to ``HfApi.list_models`` (one of the
        ``"hivex_env"`` values in ``hivex_envs``).
    :return: list of ``"<user>/<model>"`` repository ids.
    """
    # Reuse the module-level client rather than constructing a new HfApi per call,
    # and read ``ModelInfo.id`` (``modelId`` is the legacy alias of the same field).
    return [model.id for model in api.list_models(filter=hivex_env)]
def get_metadata(model_id):
    """Download a model's README.md and parse its metadata header.

    :param model_id: Hub model repository id.
    :return: the value of ``metadata_load`` for the README, or None when the
        download raises an HTTP error.
    """
    try:
        return metadata_load(
            hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        )
    except requests.exceptions.HTTPError:
        # Most commonly a 404: the model repository has no README.md.
        return None
def _build_leaderboard_row(model_id, meta):
    """Flatten one model card's metadata into a single leaderboard row.

    :param model_id: Hub model id; the part before the first "/" is the user.
    :param meta: parsed README metadata (as returned by ``get_metadata``).
    :return: dict with Verified/User/Model/Task-ID/Task, an optional Pattern or
        Difficulty entry, and one float per reported metric.
    :raises KeyError, IndexError, TypeError, ValueError: when ``meta`` does not
        follow the expected ``model-index`` layout.
    """
    user_id = model_id.split("/")[0]
    results = meta["model-index"][0]["results"][0]
    task = results["task"]
    # NOTE(review): these two literals look like mis-encoded check/cross emoji;
    # they are kept unchanged here - confirm against the file's real encoding.
    row = {
        "Verified": "βœ…" if user_id in verified_users else "❌",
        "User": user_id,
        "Model": model_id,
        "Task-ID": task["task-id"],
        "Task": task["name"],
    }
    if "pattern-id" in task:
        row["Pattern"] = pattern_map[task["pattern-id"]]
    elif "difficulty-id" in task:
        row["Difficulty"] = task["difficulty-id"]
    for metric in results["metrics"]:
        # Values look like "<mean> +/- <spread>" (presumably); keep only the part
        # before "+/-". str() also accepts values already stored as numbers.
        row[metric["name"]] = float(str(metric["value"]).split("+/-")[0].strip())
    return row


def update_leaderboard_dataset_parallel(hivex_env, path):
    """Rebuild the leaderboard CSV for one environment.

    Fetches the README metadata of every Hub model matched by ``hivex_env`` in a
    thread pool, turns each into a row, writes ``<path>/<hivex_env>.csv`` and
    returns the resulting DataFrame. Models without a README or with malformed
    metadata are skipped (and reported) instead of aborting the whole refresh.

    :param hivex_env: environment id, used as the model filter and CSV file stem.
    :param path: local directory of the leaderboard dataset snapshot.
    :return: the DataFrame that was written to disk.
    """
    model_ids = get_model_ids(hivex_env)

    def process_model(model_id):
        meta = get_metadata(model_id)
        if meta is None:
            return None
        try:
            return _build_leaderboard_row(model_id, meta)
        except (KeyError, IndexError, TypeError, ValueError) as err:
            print(f"Skipping {model_id}: unexpected model-index metadata ({err!r})")
            return None

    rows = [
        row
        for row in thread_map(process_model, model_ids, desc="Processing models")
        if row is not None
    ]
    if rows:
        leaderboard = pd.DataFrame.from_records(rows)
    else:
        # Keep the base columns so readers that index "Task-ID" (get_data,
        # get_task) still work when no model produced a row.
        leaderboard = pd.DataFrame(columns=["Verified", "User", "Model", "Task-ID", "Task"])
    leaderboard.to_csv(os.path.join(path, f"{hivex_env}.csv"), index=False)
    return leaderboard
def run_update_dataset():
    """Rebuild every environment's leaderboard CSV and push them to the dataset repo.

    Downloads the current dataset snapshot, regenerates ``<env>.csv`` for each
    entry in ``hivex_envs``, then uploads the folder to ``DATASET_REPO_ID`` with
    the commit message "Update dataset".
    """
    local_path = download_leaderboard_dataset()
    for env in hivex_envs:
        update_leaderboard_dataset_parallel(env["hivex_env"], local_path)
    api.upload_folder(
        folder_path=local_path,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
def get_data(rl_env, task_id, path) -> pd.DataFrame:
    """Load one task's leaderboard rows for display.

    Reads ``<path>/<rl_env>.csv``, keeps the rows whose ``Task-ID`` equals
    ``task_id``, drops the ``Task-ID`` and ``Task`` columns, then drops every
    column that is entirely NaN or entirely 0.0 for this task. The ``User`` and
    ``Model`` columns are converted to clickable links.

    :return: the filtered DataFrame (no columns at all if the task has no rows).
    """
    data = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    filtered_data = data[data["Task-ID"] == task_id].drop(columns=["Task-ID", "Task"])
    # Drop columns that have no data (all values are NaN)
    filtered_data = filtered_data.dropna(axis=1, how="all")
    # Drop columns where all values are 0.0; copy so the column writes below
    # operate on an independent frame rather than a slice.
    filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)].copy()
    # Vectorised link conversion; the guards cover the empty-task case, where
    # every column has been dropped above.
    if "User" in filtered_data.columns:
        filtered_data["User"] = filtered_data["User"].map(make_clickable_user)
    if "Model" in filtered_data.columns:
        filtered_data["Model"] = filtered_data["Model"].map(make_clickable_model)
    return filtered_data
def get_task(rl_env, task_id, path) -> str:
    """Look up the task name for ``task_id`` in ``<path>/<rl_env>.csv``.

    :return: the ``Task`` value of the first row whose ``Task-ID`` equals
        ``task_id``, or the string "Task not found" when no row matches.
    """
    leaderboard = pd.read_csv(f"{path}/{rl_env}.csv")
    task_names = leaderboard.loc[leaderboard["Task-ID"] == task_id, "Task"]
    if task_names.empty:
        return "Task not found"
    return task_names.iloc[0]
def convert_to_title_case(text: str) -> str:
    """Turn an underscore-separated name into space-separated Title Case.

    >>> convert_to_title_case("reduce_pollution")
    'Reduce Pollution'
    """
    return text.replace("_", " ").title()
def get_difficulty_pattern_ids_and_key(rl_env, path):
    """Find which filter column an environment's CSV uses and its distinct values.

    "Pattern" takes precedence over "Difficulty" when both columns exist.

    :return: ``(key, sorted_values)`` where ``key`` is "Pattern", "Difficulty"
        or None, and ``sorted_values`` are the column's distinct non-missing
        values ([] when ``key`` is None).
    """
    data = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    for key in ("Pattern", "Difficulty"):
        if key in data.columns:
            # dropna(): models without a pattern/difficulty id leave NaN here, and
            # sorting NaN together with strings raises TypeError.
            return key, sorted(data[key].dropna().unique())
    return None, []
def filter_data(rl_env, task_id, selected_values, path):
    """Return one task's rows restricted to the selected Pattern/Difficulty values.

    :param rl_env: environment id, passed through to ``get_data``.
    :param task_id: task id, passed through to ``get_data``.
    :param selected_values: labels chosen in the CheckboxGroup; an empty or
        None selection returns the data unfiltered.
    :param path: local directory holding the leaderboard CSVs.
    :return: the (possibly filtered) DataFrame produced by ``get_data``.
    """
    data = get_data(rl_env, task_id, path)
    if not selected_values:
        return data
    filter_column = next(
        (column for column in ("Pattern", "Difficulty") if column in data.columns),
        None,
    )
    if filter_column is None:
        # get_data drops all-NaN / all-zero columns, so this task may have no
        # filter column left; show it unfiltered instead of raising KeyError.
        return data
    # The checkbox labels are str() of the column's values, so compare as
    # strings. This works for text patterns and for int or float difficulties
    # alike (the former np.int64("2.0") conversion raised on float columns).
    wanted = {str(value) for value in selected_values}
    return data[data[filter_column].astype(str).isin(wanted)]
def update_filtered_data(selected_values, rl_env, task_id, path):
    """Gradio change-callback: re-filter one task's table for a new checkbox selection.

    The argument order matches the ``inputs`` list used when wiring the callback;
    the actual work is delegated to ``filter_data``.
    """
    return filter_data(rl_env, task_id, selected_values, path)
# Rebuild and upload all leaderboard CSVs once at startup, before the UI reads them.
run_update_dataset()

block = gr.Blocks(css=custom_css)  # Attach the custom CSS here
with block:
    with gr.Row(elem_id="header-row"):
        # TITLE IMAGE
        gr.HTML(
            """
<div style="width: 50%; margin: 0 auto; text-align: center;">
<img
src="https://huggingface.co/spaces/hivex-research/hivex-leaderboard/resolve/main/hivex_logo.png"
alt="hivex logo"
style="width: 100px; display: inline-block; border-radius:20px;"
/>
<h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
</div>
"""
        )
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"<p style='text-align: center;'>Total models: {get_total_models()}</p>"
        )
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            "<p style='text-align: center;'>Get started πŸš€ on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>"
        )

    # Local snapshot of the leaderboard dataset that every tab reads from.
    path_ = download_leaderboard_dataset()

    # ENVIRONMENT TABS
    with gr.Tabs() as tabs:
        for hivex_env in hivex_envs:
            env_name = hivex_env["hivex_env"]
            with gr.Tab(hivex_env["title"]) as env_tabs:
                # Reset per environment. Otherwise an environment without a
                # Pattern/Difficulty column would wire up the previous
                # environment's CheckboxGroup (or raise NameError if first).
                selected_checkboxes = None
                dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
                    env_name, path_
                )
                # Only show the filter when there is a column with values to filter on.
                if dp_key is not None and len(difficulty_pattern_ids) > 0:
                    selected_checkboxes = gr.CheckboxGroup(
                        [str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key
                    )
                for task_id in range(hivex_env["task_count"]):
                    task_title = convert_to_title_case(get_task(env_name, task_id, path_))
                    with gr.TabItem(f"Task {task_id}: {task_title}"):
                        # Display initial data
                        data = get_data(env_name, task_id, path_)
                        row_count = len(data)
                        gr_dataframe = gr.DataFrame(
                            value=data,
                            headers=["Verified", "User", "Model"],
                            datatype=["html", "markdown", "markdown"],
                            row_count=(row_count, "fixed"),
                        )
                        # gr.State carries this tab's environment/task/path into the callback.
                        rl_env_state = gr.State(value=env_name)
                        task_id_state = gr.State(value=task_id)
                        path_state = gr.State(value=path_)
                        # Re-filter this task's table whenever the selection changes.
                        if selected_checkboxes is not None:
                            selected_checkboxes.change(
                                fn=update_filtered_data,
                                inputs=[selected_checkboxes, rl_env_state, task_id_state, path_state],
                                outputs=gr_dataframe,
                            )

        with gr.Tab("Submit Model ✨") as submit_tab:
            with gr.Row(elem_id="header-row"):
                with gr.Column():
                    gr.HTML("<h1>Submit your own Results to the <a href='https://huggingface.co/spaces/hivex-research/hivex-leaderboard' target='_blank'>HIVEX Leaderboard</a> on Huggingface πŸ€—</h1>")
                    gr.HTML("<p>You can follow the steps in the <a href='https://github.com/hivex-research/hivex-results?tab=readme-ov-file#submit-your-own-results-to-the-hivex-leaderboard-on-huggingface-' target='_blank'>hivex-results repository</a> or stay here and follow these steps:</p>")
                    gr.HTML(
                        "<div style='padding-left: 20px; line-height: 1.6;'>"
                        "<p><strong>1.</strong> Install all dependencies as described in the <a href='https://github.com/hivex-research/hivex/tree/main' target='_blank'>HIVEX repository README</a>.</p>"
                        "<p><strong>2.</strong> Run the Train and Test Pipeline in the <a href='https://github.com/hivex-research/hivex/tree/main' target='_blank'>HIVEX repository</a>, either using <a href='https://github.com/hivex-research/hivex/tree/main?tab=readme-ov-file#-reproducing-paper-results' target='_blank'>ML-Agents</a> or with your <a href='https://github.com/hivex-research/hivex/tree/main?tab=readme-ov-file#-additional-environments-and-training-frameworks' target='_blank'>favorite framework</a>.</p>"
                        "<p><strong>3.</strong> Clone the <a href='https://github.com/hivex-research/hivex-results/tree/master' target='_blank'>hivex-results repository</a>.</p>"
                        "<p><strong>4.</strong> In your local hivex-results repository, add your results to the respective environment/train and environment/test folders. We have provided a <code>train_dummy_folder</code> and <code>test_dummy_folder</code> with results for training and testing on the Wind Farm Control environment.</p>"
                        "<p><strong>5.</strong> Run <code>find_best_models.py</code>. This script generates data from your results.</p>"
                        "<code>python tools/huggingface/find_best_models.py</code>"
                        "<p><strong>6.</strong> Run <code>generate_hf_yaml.py</code>. Uncomment the environment data parser you need for your data. For example, for our dummy data, we need <code>generate_yaml_WFC(data['WindFarmControl'], key)</code>. This script takes the data generated in the previous step and turns it into folders including the checkpoint etc. of your training run and a <code>README.md</code>, which serves as the model card including important meta-data that is needed for the automatic fetching of the leaderboard of your model.</p>"
                        "<code>python tools/huggingface/generate_hf_yaml.py</code>"
                        "<p><strong>7.</strong> Finally, upload the content of the generated folder(s) to Huggingface πŸ€— as a new model.</p>"
                        "<p><strong>8.</strong> Every 24 hours, the <a href='https://huggingface.co/spaces/hivex-research/hivex-leaderboard' target='_blank'>HIVEX Leaderboard</a> is fetching new models. We will review your model as soon as possible and add it to the verified list of models as soon as possible. If you have any questions, please feel free to reach out to p.d.siedler@gmail.com.</p>"
                        "</div>"
                    )
                    gr.HTML("<h2>Congratulations, you did it πŸš€!</h2>")
# Call restart() once a day (86400 s); presumably so that the module-level
# run_update_dataset() refresh runs again when the Space comes back up.
scheduler = BackgroundScheduler()
scheduler.add_job(restart, trigger="interval", seconds=24 * 60 * 60)
scheduler.start()

# Serve the UI; this call blocks the main thread.
block.launch()