Spaces:
Runtime error
Runtime error
import hmac
import json
import os
from difflib import SequenceMatcher
from typing import Any, Dict, Optional, Tuple

from fastapi import FastAPI, Request, Response
from huggingface_hub import (
    DatasetCard,
    HfApi,
    ModelCard,
    comment_discussion,
    create_discussion,
    get_discussion_details,
    get_repo_discussions,
    login,
)
from huggingface_hub.utils import EntryNotFoundError
from tabulate import tabulate
# Shared secret expected in the X-Webhook-Secret header of incoming webhooks.
# None when the WEBHOOK_SECRET environment variable is not set.
KEY = os.getenv("WEBHOOK_SECRET")
# Hub token used by the bot; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
# Runs at import time; the module-level huggingface_hub helpers called below
# are not passed a token explicitly, so they presumably rely on this login.
login(HF_TOKEN)
app = FastAPI()
def read_root():
    """Return the bot's small HTML landing page.

    NOTE(review): no route decorator is visible in this copy of the file;
    presumably this is registered as the GET "/" handler — confirm.
    """
    page = """
<h2 style="text-align:center">Metadata Review Bot</h2>
<p style="text-align:center">This is a demo app showing how to use webhooks to automate metadata review for models and datasets shared on the Hugging Face Hub.</p>
"""
    return Response(media_type="text/html", content=page)
def similar(a, b):
    """Return how alike two sequences are, as a ratio between 0.0 and 1.0.

    Uses ``difflib.SequenceMatcher`` without a junk filter; 1.0 means the
    sequences are identical.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
def create_metadata_key_dict(card_data, repo_type: str):
    """Select the metadata fields that the report card grades for a repo.

    Args:
        card_data: Mapping of the repo card's metadata (for example the dict
            returned by ``load_repo_card_metadata``). ``None`` is treated as an
            empty mapping.
        repo_type: Either ``"model"`` or ``"dataset"``.

    Returns:
        Dict mapping each graded field name to its value in ``card_data``, or
        ``None`` when the field is absent. Shared keys come first, followed by
        the type-specific keys.

    Raises:
        ValueError: if ``repo_type`` is neither ``"model"`` nor ``"dataset"``.
            Previously the function silently returned ``None``, which only
            failed later when the result was graded.
    """
    shared_keys = ["tags", "license"]
    type_specific_keys = {
        "model": ["library_name", "datasets", "metrics", "co2", "pipeline_tag"],
        "dataset": [
            "pretty_name",
            "size_categories",
            "task_categories",
            "task_ids",
            "source_datasets",
        ],
    }
    if repo_type not in type_specific_keys:
        raise ValueError(f"Unsupported repo type: {repo_type!r}")
    card_data = card_data or {}
    keys = shared_keys + type_specific_keys[repo_type]
    return {key: card_data.get(key) for key in keys}
def create_metadata_breakdown_table(desired_metadata_dictionary):
    """Render the graded metadata fields as a GitHub-style Markdown table.

    Each row is ``(field name, value)``. Falsy values (``None``, empty strings
    or lists, ...) are shown as "Field Missing".
    """
    rows = [
        (field, value if value else "Field Missing")
        for field, value in desired_metadata_dictionary.items()
    ]
    return tabulate(
        rows, headers=("Metadata Field", "Provided Value"), tablefmt="github"
    )
def calculate_grade(desired_metadata_dictionary):
    """Return the fraction of graded metadata fields that are filled in.

    A field counts as filled when its value is truthy, so ``None``, ``""`` and
    ``[]`` all count as missing. The result lies in ``[0, 1]`` and is rounded
    to two decimals. An empty dictionary scores ``0.0`` instead of raising
    ``ZeroDivisionError``.
    """
    values = list(desired_metadata_dictionary.values())
    if not values:
        return 0.0
    filled = sum(1 for value in values if value)
    return round(filled / len(values), 2)
def create_markdown_report(
    desired_metadata_dictionary, repo_name, repo_type, score, update: bool = False
):
    """Build the Markdown body of a metadata report card.

    Args:
        desired_metadata_dictionary: Mapping of graded field name to value,
            rendered as a table by ``create_metadata_breakdown_table``.
        repo_name: Repo id shown in the report text.
        repo_type: ``"model"`` or ``"dataset"``; used in the heading and messages.
        score: Coverage as a fraction in ``[0, 1]`` (the caller in this file
            passes the result of ``calculate_grade``). Rendered as a whole
            percentage.
        update: When True, the title is suffixed with ``(updated)``.

    Returns:
        The report as a Markdown string.
    """
    kind = repo_type.title()
    title = f"# {kind} metadata report card"
    if update:
        title += " (updated)"
    if score <= 0.5:
        verdict = (
            f"We're not angry we're just disappointed! {kind} metadata is super "
            "important. Please try harder..."
        )
    else:
        verdict = f"Not too shabby! Make sure you also fill in a {repo_type} card too!"
    sections = [
        title,
        f"This is an automatically produced metadata quality report card for {repo_name}. This report is meant as a POC!",
        f"## Breakdown of metadata fields for your {repo_type}",
        create_metadata_breakdown_table(desired_metadata_dictionary),
        # score is a 0-1 fraction; `.0%` renders 0.57 as "57%". The previous
        # template printed the raw fraction followed by "%" ("0.57%").
        f"You scored a metadata coverage grade of: **{score:.0%}**",
        verdict,
    ]
    return "\n\n".join(sections) + "\n"
def parse_webhook_post(data: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Extract ``(repo_type, repo_name)`` from a Hub webhook payload.

    Returns ``None`` when the event's scope is not ``"repo"``. In that case the
    payload's ``repo`` section is never read.

    Raises:
        KeyError: if an expected key is missing from the payload.
        ValueError: if the repo type is neither ``"model"`` nor ``"dataset"``.
    """
    scope = data["event"]["scope"]
    if scope == "repo":
        repo_info = data["repo"]
        name, kind = repo_info["name"], repo_info["type"]
        if kind in ("model", "dataset"):
            return kind, name
        raise ValueError("Unknown hub type")
    return None
def load_repo_card_metadata(repo_type, repo_name):
    """Load a repo card's metadata as a plain dict.

    Returns ``{}`` when loading raises ``EntryNotFoundError``, presumably
    because the repo has no card file. Returns ``None`` for repo types other
    than ``"model"`` and ``"dataset"``.
    """
    if repo_type == "dataset":
        card_cls = DatasetCard
    elif repo_type == "model":
        card_cls = ModelCard
    else:
        return None
    try:
        return card_cls.load(repo_name).data.to_dict()
    except EntryNotFoundError:
        return {}
def _last_comment_content(repo_name, discussion_num, repo_type):
    """Return the text of the most recent event that has content, or None.

    The discussion's event list may also contain non-comment events (for
    example status or title changes) that carry no ``content``. This walks
    backwards instead of assuming ``events[-1]`` is a comment.
    """
    details = get_discussion_details(repo_name, discussion_num, repo_type=repo_type)
    for event in reversed(details.events):
        content = getattr(event, "content", None)
        if content is not None:
            return content
    return None


def create_or_update_report(data):
    """Create or refresh the "Metadata Report Card" discussion for a webhook's repo.

    If an open discussion with that title exists, an "(updated)" report is
    posted only when its last comment differs from the current report.
    Otherwise a new discussion is opened.

    Returns:
        True when the webhook was handled (report created, updated, or already
        current). A 400 ``Response`` when ``parse_webhook_post`` yields nothing.
    """
    parsed_post = parse_webhook_post(data)
    if not parsed_post:
        return Response("Unable to parse webhook data", status_code=400)
    repo_type, repo_name = parsed_post

    card_data = load_repo_card_metadata(repo_type, repo_name)
    desired_metadata_dictionary = create_metadata_key_dict(card_data, repo_type)
    score = calculate_grade(desired_metadata_dictionary)
    report = create_markdown_report(
        desired_metadata_dictionary, repo_name, repo_type, score, update=False
    )
    updated_report = create_markdown_report(
        desired_metadata_dictionary, repo_name, repo_type, score, update=True
    )

    for discussion in get_repo_discussions(repo_name, repo_type=repo_type):
        if discussion.title != "Metadata Report Card" or discussion.status != "open":
            continue
        last_comment = _last_comment_content(repo_name, discussion.num, repo_type)
        # The thread can end with either the original report or an "(updated)"
        # one. Comparing only against the non-updated text meant an unchanged
        # "(updated)" comment never matched, so a duplicate was re-posted on
        # every webhook.
        is_current = last_comment is not None and max(
            similar(report, last_comment), similar(updated_report, last_comment)
        ) > 0.999
        if not is_current:
            comment_discussion(
                repo_name,
                discussion.num,
                comment=updated_report,
                repo_type=repo_type,
            )
        return True

    create_discussion(
        repo_name,
        "Metadata Report Card",
        description=report,
        repo_type=repo_type,
    )
    return True
async def webhook(request: Request):
    """Handle a Hub webhook call and create or update the metadata report card.

    Only POST requests are processed; other methods return ``None``. The
    request must carry an ``X-Webhook-Secret`` header equal to ``KEY`` (the
    ``WEBHOOK_SECRET`` environment variable).

    NOTE(review): no route decorator is visible in this copy of the file;
    presumably this is registered as the POST webhook route — confirm.

    Returns:
        A 401 ``Response`` for a missing or wrong secret, or when no secret is
        configured. A 400 ``Response`` for a non-JSON body or an unparseable
        payload. Otherwise the string "Webhook received!".
    """
    if request.method != "POST":
        return None
    provided_secret = request.headers.get("X-Webhook-Secret")
    # Fail closed: with WEBHOOK_SECRET unset, KEY is None and a request with no
    # header used to pass because `None != None` is False. compare_digest
    # avoids leaking the secret through comparison timing.
    if (
        not KEY
        or provided_secret is None
        or not hmac.compare_digest(
            provided_secret.encode("utf-8"), KEY.encode("utf-8")
        )
    ):
        return Response("Invalid secret", status_code=401)
    try:
        data = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        return Response("Invalid JSON payload", status_code=400)
    result = create_or_update_report(data)
    # create_or_update_report returns True or an error Response. A Response is
    # truthy, so testing truthiness used to swallow the 400 and report success.
    if isinstance(result, Response):
        return result
    return "Webhook received!"