import os import json import requests import numpy as np import gradio as gr import pandas as pd from huggingface_hub import HfApi, hf_hub_download, snapshot_download from huggingface_hub.repocard import metadata_load from apscheduler.schedulers.background import BackgroundScheduler from tqdm.contrib.concurrent import thread_map from utils import make_clickable_model, make_clickable_user from typing import List # Add this import statement DATASET_REPO_URL = ( "https://huggingface.co/datasets/hivex-research/hivex-leaderboard-data" ) DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data" HF_TOKEN = os.environ.get("HF_TOKEN") block = gr.Blocks() api = HfApi(token=HF_TOKEN) # .tab-buttons button { # font-size: 20px; # } custom_css = """ /* Full width space */ .gradio-container { max-width: 95%!important; } .gr-dataframe table { width: auto; } .gr-dataframe td, .gr-dataframe th { white-space: nowrap; text-overflow: ellipsis; overflow: hidden; width: 1%; } """ # Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random pattern_map = { 0: "0: Default", 1: "1: Grid", 2: "2: Chain", 3: "3: Circle", 4: "4: Square", 5: "5: Cross", 6: "6: Two Rows", 7: "7: Field", 8: "8: Random", } hivex_envs = [ { "title": "Wind Farm Control", "hivex_env": "hivex-wind-farm-control", "task_count": 2, }, { "title": "Wildfire Resource Management", "hivex_env": "hivex-wildfire-resource-management", "task_count": 3, }, { "title": "Drone-Based Reforestation", "hivex_env": "hivex-drone-based-reforestation", "task_count": 7, }, { "title": "Ocean Plastic Collection", "hivex_env": "hivex-ocean-plastic-collection", "task_count": 4, }, { "title": "Aerial Wildfire Suppression", "hivex_env": "hivex-aerial-wildfire-suppression", "task_count": 9, }, ] verified_users = ["hivex-research"] verified_models = [{"user": "hivex-research", "model": "my_model"}] def restart(): print("RESTART") api.restart_space(repo_id="hivex-research/hivex-leaderboard") def download_leaderboard_dataset(): path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset") return path def get_total_models(): total_models = 0 for hivex_env in hivex_envs: model_ids = get_model_ids(hivex_env["hivex_env"]) total_models += len(model_ids) return total_models def get_model_ids(hivex_env): api = HfApi() models = api.list_models(filter=hivex_env) model_ids = [x.modelId for x in models] return model_ids def get_metadata(model_id): try: readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180) return metadata_load(readme_path) except requests.exceptions.HTTPError: # 404 README.md not found return None def update_leaderboard_dataset_parallel(hivex_env, path): # Get model ids associated with hivex_env model_ids = get_model_ids(hivex_env) def process_model(model_id): meta = get_metadata(model_id) # LOADED_MODEL_METADATA[model_id] = meta if meta is not None else '' if meta is None: return None user_id = model_id.split("/")[0] row = {} row["Verified"] = "✅" if user_id in verified_users else "❌" row["User"] = user_id row["Model"] = model_id results = meta["model-index"][0]["results"][0] row["Task-ID"] = results["task"]["task-id"] row["Task"] = results["task"]["name"] if "pattern-id" in results["task"] or "difficulty-id" in results["task"]: key = "Pattern" if "pattern-id" in results["task"] else "Difficulty" row[key] = ( pattern_map[results["task"]["pattern-id"]] if "pattern-id" in results["task"] else results["task"]["difficulty-id"] ) results_metrics = results["metrics"] for result in results_metrics: row[result["name"]] = float(result["value"].split("+/-")[0].strip()) return row data = list(thread_map(process_model, model_ids, desc="Processing models")) # Filter out None results (models with no metadata) data = [row for row in data if row is not None] # ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data)) ranked_dataframe = pd.DataFrame.from_records(data) new_history = ranked_dataframe file_path = path + "/" + hivex_env + ".csv" new_history.to_csv(file_path, index=False) return ranked_dataframe def run_update_dataset(): path_ = download_leaderboard_dataset() for i in range(0, len(hivex_envs)): hivex_env = hivex_envs[i] update_leaderboard_dataset_parallel(hivex_env["hivex_env"], path_) api.upload_folder( folder_path=path_, repo_id="hivex-research/hivex-leaderboard-data", repo_type="dataset", commit_message="Update dataset", ) def get_data(rl_env, task_id, path) -> pd.DataFrame: """ Get data from rl_env, filter by the given task_id, and drop the Task-ID column. Also drops any columns that have no data (all values are NaN) or all values are 0.0. :return: filtered data as a pandas DataFrame without the Task-ID column """ csv_path = path + "/" + rl_env + ".csv" data = pd.read_csv(csv_path) # Filter the data to only include rows where the "Task-ID" column matches the given task_id filtered_data = data[data["Task-ID"] == task_id] # Drop the "Task-ID" column filtered_data = filtered_data.drop(columns=["Task-ID"]) # Drop the "Task" column filtered_data = filtered_data.drop(columns=["Task"]) # Drop columns that have no data (all values are NaN) filtered_data = filtered_data.dropna(axis=1, how="all") # Drop columns where all values are 0.0 filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)] # Convert User and Model columns to clickable links for index, row in filtered_data.iterrows(): user_id = row["User"] filtered_data.loc[index, "User"] = make_clickable_user(user_id) model_id = row["Model"] filtered_data.loc[index, "Model"] = make_clickable_model(model_id) return filtered_data def get_task(rl_env, task_id, path) -> str: """ Get the task name from the leaderboard dataset based on the rl_env and task_id. :return: The task name as a string """ csv_path = path + "/" + rl_env + ".csv" data = pd.read_csv(csv_path) # Filter the data to find the row with the matching task_id task_row = data[data["Task-ID"] == task_id] # Check if the task exists and return the task name if not task_row.empty: task_name = task_row.iloc[0]["Task"] return task_name else: return "Task not found" def convert_to_title_case(text: str) -> str: # Replace underscores with spaces text = text.replace("_", " ") # Convert each word to title case (capitalize the first letter) title_case_text = text.title() return title_case_text def get_difficulty_pattern_ids_and_key(rl_env, path): csv_path = path + "/" + rl_env + ".csv" data = pd.read_csv(csv_path) if "Pattern" in data.columns: key = "Pattern" difficulty_pattern_ids = sorted(data[key].unique()) elif "Difficulty" in data.columns: key = "Difficulty" difficulty_pattern_ids = sorted(data[key].unique()) else: key = None difficulty_pattern_ids = [] return key, difficulty_pattern_ids def filter_data(rl_env, task_id, selected_values, path): """ Filters the data based on the selected difficulty/pattern values. """ data = get_data(rl_env, task_id, path) # If there are selected values, filter the DataFrame if selected_values: filter_column = "Pattern" if "Pattern" in data.columns else "Difficulty" if filter_column == "Difficulty": selected_values = [np.int64(sv) for sv in selected_values] data = data[data[filter_column].isin(selected_values)] return data def update_filtered_data(selected_values, rl_env, task_id, path): filtered_data = filter_data(rl_env, task_id, selected_values, path) return filtered_data run_update_dataset() block = gr.Blocks(css=custom_css) # Attach the custom CSS here with block: with gr.Row(elem_id="header-row"): # TITLE IMAGE gr.HTML( """
Total models: {get_total_models()}
" ) with gr.Row(elem_id="header-row"): gr.HTML( f"Get started 🚀 on our GitHub repository!
" ) path_ = download_leaderboard_dataset() # ENVIRONMENT TABS with gr.Tabs() as tabs: for env_index in range(0, len(hivex_envs)): hivex_env = hivex_envs[env_index] with gr.Tab(f"{hivex_env['title']}") as env_tabs: dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key( hivex_env["hivex_env"], path_ ) # Check if dp_key is defined and difficulty_pattern_ids is not empty if dp_key is not None and len(difficulty_pattern_ids) > 0: selected_checkboxes = gr.CheckboxGroup( [str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key ) for task_id in range(0, hivex_env["task_count"]): task_title = convert_to_title_case( get_task(hivex_env["hivex_env"], task_id, path_) ) with gr.TabItem(f"Task {task_id}: {task_title}"): # Display initial data data = get_data(hivex_env["hivex_env"], task_id, path_) row_count = len(data) gr_dataframe = gr.DataFrame( value=data, headers=["Verified", "User", "Model"], datatype=["html", "markdown", "markdown"], row_count=(row_count, "fixed"), ) # Use gr.State to hold environment and task information rl_env_state = gr.State(value=hivex_env["hivex_env"]) task_id_state = gr.State(value=task_id) path_state = gr.State(value=path_) # Add a callback to update the DataFrame when checkboxes are changed if selected_checkboxes: selected_checkboxes.change( fn=update_filtered_data, inputs=[selected_checkboxes, rl_env_state, task_id_state, path_state], outputs=gr_dataframe, ) with gr.Tab("Submit Model ✨") as submit_tab: with gr.Row(elem_id="header-row"): with gr.Column(): gr.HTML("You can follow the steps in the hivex-results repository or stay here and follow these steps:
") gr.HTML("1. Install all dependencies as described in the HIVEX repository README.
\2. Run the Train and Test Pipeline in the HIVEX repository, either using ML-Agents or with your favorite framework.
\3. Clone the hivex-results repository.
\4. In your local hivex-results repository, add your results to the respective environment/train and environment/test folders. We have provided a train_dummy_folder
and test_dummy_folder
with results for training and testing on the Wind Farm Control environment.
5. Run find_best_models.py
. This script generates data from your results.
python tools/huggingface/find_best_models.py
\
6. Run generate_hf_yaml.py
. Uncomment the environment data parser you need for your data. For example, for our dummy data, we need generate_yaml_WFC(data['WindFarmControl'], key)
. This script takes the data generated in the previous step and turns it into folders including the checkpoint etc. of your training run and a README.md
, which serves as the model card including important meta-data that is needed for the automatic fetching of the leaderboard of your model.
python tools/huggingface/generate_hf_yaml.py
\
7. Finally, upload the content of the generated folder(s) to Huggingface 🤗 as a new model.
\8. Every 24 hours, the HIVEX Leaderboard is fetching new models. We will review your model as soon as possible and add it to the verified list of models as soon as possible. If you have any questions, please feel free to reach out to p.d.siedler@gmail.com.
\