Spaces:
Runtime error
Runtime error
import os | |
import json | |
import requests | |
import gradio as gr | |
import pandas as pd | |
from huggingface_hub import HfApi, hf_hub_download, snapshot_download | |
from huggingface_hub.repocard import metadata_load | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from tqdm.contrib.concurrent import thread_map | |
from utils import make_clickable_model, make_clickable_user | |
# Hub dataset repository that stores one leaderboard CSV per environment.
DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data"
# Browser URL of the same dataset, derived from the id so the two cannot drift.
DATASET_REPO_URL = f"https://huggingface.co/datasets/{DATASET_REPO_ID}"

# Access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared Hub client used for uploading the dataset and restarting the Space.
api = HfApi(token=HF_TOKEN)
# Stylesheet passed to gr.Blocks(css=...).
# A rule enlarging the tab buttons is currently disabled:
#   .tab-buttons button { font-size: 20px; }
custom_css = (
    "\n"
    "/* Full width space */\n"
    ".gradio-container {\n"
    "max-width: 95%!important;\n"
    "}\n"
)
# One entry per HIVEX environment shown on the leaderboard:
#   title      - label of the environment tab
#   hivex_env  - tag used to list models on the Hub and to name the CSV file
#   task_count - number of task tabs rendered for the environment
hivex_envs = [
    {"title": title, "hivex_env": env_tag, "task_count": task_count}
    for title, env_tag, task_count in (
        ("Wind Farm Control", "hivex-wind-farm-control", 2),
        ("Wildfire Resource Management", "hivex-wildfire-resource-management", 3),
        ("Drone-Based Reforestation", "hivex-drone-based-reforestation", 7),
        ("Ocean Plastic Collection", "hivex-ocean-plastic-collection", 4),
        ("Aerial Wildfire Suppression", "hivex-aerial-wildfire-suppression", 9),
    )
]
def restart():
    """Print a marker line and ask the Hub to restart the leaderboard Space."""
    print("RESTART")
    space_id = "hivex-research/hivex-leaderboard"
    api.restart_space(repo_id=space_id)
def download_leaderboard_dataset():
    """Download a snapshot of the leaderboard dataset repo.

    :return: whatever ``snapshot_download`` returns for the dataset
        (used by callers as the local folder path of the snapshot)
    """
    return snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
def get_total_models():
    """Count the models listed on the Hub across every HIVEX environment.

    :return: sum of the number of model ids found for each entry of
        ``hivex_envs``
    """
    return sum(len(get_model_ids(env["hivex_env"])) for env in hivex_envs)
def get_model_ids(hivex_env):
    """List the ids of all Hub models matching the given environment filter.

    :param hivex_env: value passed as ``filter`` to ``HfApi.list_models``
    :return: list with the ``modelId`` of every model returned
    """
    hub_client = HfApi()
    return [model.modelId for model in hub_client.list_models(filter=hivex_env)]
def get_metadata(model_id):
    """Fetch a model's README.md and return the metadata loaded from it.

    :param model_id: Hub repository id of the model
    :return: the result of ``metadata_load`` on the downloaded README, or
        ``None`` when the download fails with an HTTP error
    """
    try:
        return metadata_load(
            hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        )
    except requests.exceptions.HTTPError:
        # e.g. 404: the repository has no README.md
        return None
def _leaderboard_row(model_id, meta):
    """Build one leaderboard row from a model card's metadata.

    Reads the first result of the first ``model-index`` entry. Raises
    KeyError/IndexError/TypeError/ValueError when the metadata does not have
    that structure or a metric value cannot be parsed as a number.
    """
    result = meta["model-index"][0]["results"][0]
    row = {
        "User": model_id.split("/")[0],
        "Model": model_id,
        "Task": result["task"]["task-id"],
    }
    for metric in result["metrics"]:
        # Only the part before "+/-" is kept (presumably "mean +/- std").
        # str() lets plain numeric values through as well.
        row[metric["name"]] = float(str(metric["value"]).split("+/-")[0].strip())
    return row


def update_leaderboard_dataset_parallel(hivex_env, path):
    """Rebuild the leaderboard CSV of one environment from the Hub model cards.

    Every model listed for ``hivex_env`` is processed in a worker thread.
    Models without metadata, or with metadata that cannot be turned into a
    row, are skipped so that a single bad model card does not abort the whole
    refresh.

    :param hivex_env: environment tag; also the base name of the CSV file
    :param path: folder in which ``<hivex_env>.csv`` is written
    :return: the DataFrame that was written (one row per usable model)
    """
    model_ids = get_model_ids(hivex_env)

    def process_model(model_id):
        meta = get_metadata(model_id)
        if meta is None:
            return None
        try:
            return _leaderboard_row(model_id, meta)
        except (KeyError, IndexError, TypeError, ValueError) as err:
            print(f"Skipping {model_id}: unusable model card metadata ({err!r})")
            return None

    processed = thread_map(process_model, model_ids, desc="Processing models")
    rows = [row for row in processed if row is not None]

    if rows:
        leaderboard = pd.DataFrame.from_records(rows)
    else:
        # Keep the header so the CSV stays readable and filterable by "Task"
        # even when no model qualified.
        leaderboard = pd.DataFrame(columns=["User", "Model", "Task"])

    leaderboard.to_csv(os.path.join(path, f"{hivex_env}.csv"), index=False)
    return leaderboard
def run_update_dataset():
    """Regenerate every environment's CSV and push the folder to the Hub.

    Downloads the current dataset snapshot, rewrites one CSV per entry of
    ``hivex_envs`` inside it, then uploads the whole folder back to the
    dataset repository in a single commit.
    """
    snapshot_dir = download_leaderboard_dataset()

    for env in hivex_envs:
        update_leaderboard_dataset_parallel(env["hivex_env"], snapshot_dir)

    api.upload_folder(
        folder_path=snapshot_dir,
        repo_id="hivex-research/hivex-leaderboard-data",
        repo_type="dataset",
        commit_message="Update dataset",
    )
def get_data(rl_env, task, path) -> pd.DataFrame:
    """
    Load the leaderboard rows of one task of an environment.

    Reads ``<path>/<rl_env>.csv``, keeps the rows whose "Task" column equals
    ``task``, drops the "Task" column and converts the "User" and "Model"
    columns to clickable links.

    A CSV that is empty or lacks the "User", "Model" or "Task" column (e.g.
    an environment for which no model has been published yet) yields an empty
    table instead of raising.

    :param rl_env: environment name, used as the CSV base name
    :param task: value the "Task" column must equal
    :param path: folder containing the CSV files
    :return: filtered data as a pandas DataFrame without the Task column
    """
    csv_path = os.path.join(path, f"{rl_env}.csv")
    try:
        data = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # Zero-byte file: nothing to show for this environment.
        return pd.DataFrame(columns=["User", "Model"])

    if not {"User", "Model", "Task"}.issubset(data.columns):
        # Written from an empty frame, so the expected header is missing.
        return pd.DataFrame(columns=["User", "Model"])

    selected = data[data["Task"] == task].drop(columns=["Task"])
    # assign() returns a new frame, so the filtered slice is never mutated.
    return selected.assign(
        User=selected["User"].map(make_clickable_user),
        Model=selected["Model"].map(make_clickable_model),
    )
# Refresh the published CSVs once at startup so the tables built below are
# based on the current model cards.
run_update_dataset()

# Markup for the banner image shown at the top of the page.
HEADER_IMAGE_HTML = (
    "\n"
    '<div align="left">\n'
    '<div style="border-radius: 20px; width: 50%;">\n'
    "<img\n"
    'src="https://huggingface.co/spaces/hivex-research/hivex-leaderboard/resolve/main/hivex_thumb_cropped.png"\n'
    'alt="hivex header image"\n'
    'style="width: 100%; border-radius: 20px;"\n'
    "/>\n"
    "</div>\n"
    "</div>\n"
)

# The Space restarts itself once a day; startup re-runs the dataset refresh.
RESTART_INTERVAL_SECONDS = 24 * 60 * 60

block = gr.Blocks(css=custom_css)
with block:
    # TITLE IMAGE
    with gr.Row(elem_id="header-row"):
        gr.HTML(HEADER_IMAGE_HTML)
    with gr.Row(elem_id="header-row"):
        gr.HTML("<h1>HIVEX-Leaderboard</h1>")
    with gr.Row(elem_id="header-row"):
        gr.HTML(f"<p>Total models: {get_total_models()}</p>")
    with gr.Row(elem_id="header-row"):
        gr.HTML("<p>To get started, please check out <a href='https://github.com/hivex-research/hivex'>our GitHub repository</a>.</p>")

    path_ = download_leaderboard_dataset()

    # ENVIRONMENT TABS
    with gr.Tabs():
        for hivex_env in hivex_envs:
            with gr.Tab(hivex_env["title"]):
                # TASK TABS: one per task index, which is also the value
                # matched against the CSV "Task" column.
                for j in range(hivex_env["task_count"]):
                    with gr.TabItem(f"Task {j}"):
                        with gr.Row():
                            gr.components.Dataframe(
                                value=get_data(hivex_env["hivex_env"], j, path_),
                                headers=["User", "Model"],
                                datatype=["markdown", "markdown"],
                                row_count=(100, "fixed"),
                            )

scheduler = BackgroundScheduler()
scheduler.add_job(restart, "interval", seconds=RESTART_INTERVAL_SECONDS)
scheduler.start()

block.launch()