# cd1 / main.py
# Ezi Ozoani
# 1st
# 4c5342d
import json
import os
from difflib import SequenceMatcher
from typing import Any, Dict, Optional, Tuple
from fastapi import FastAPI, Request, Response
from huggingface_hub import (DatasetCard, HfApi, ModelCard, comment_discussion,
create_discussion, get_discussion_details,
get_repo_discussions, login)
from huggingface_hub.utils import EntryNotFoundError
from tabulate import tabulate
# Shared secret that the /webhook handler compares against the
# "X-Webhook-Secret" request header. None when WEBHOOK_SECRET is unset.
KEY = os.environ.get("WEBHOOK_SECRET")
# Hugging Face access token read from the environment. None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub API client built with that token (not referenced elsewhere in the visible code).
api = HfApi(token=HF_TOKEN)
# Log this process in to the Hub with the same token; runs at import time.
login(HF_TOKEN)
# FastAPI application on which the route handlers in this file are registered.
app = FastAPI()
@app.get("/")
def read_root():
    """Serve the landing page: a short HTML description of the bot."""
    landing_page = """
<h2 style="text-align:center">Metadata Review Bot</h2>
<p style="text-align:center">This is a demo app showing how to use webhooks to automate metadata review for models and datasets shared on the Hugging Face Hub.</p>
"""
    return Response(media_type="text/html", content=landing_page)
def similar(a, b):
    """Return the difflib similarity ratio of two sequences.

    The ratio is 1.0 for identical sequences and 0.0 when nothing matches.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
def create_metadata_key_dict(card_data, repo_type: str):
    """Select the metadata fields that are reviewed for a given repo type.

    Args:
        card_data: Mapping of card metadata; fields are read with ``.get``.
        repo_type: ``"model"`` or ``"dataset"``.

    Returns:
        A dict mapping each reviewed field name to its value in ``card_data``
        (``None`` when the field is absent). Fields shared by both repo types
        come first, followed by the type-specific ones. ``None`` is returned
        for any other ``repo_type``.
    """
    wanted_fields = ["tags", "license"]
    if repo_type == "model":
        wanted_fields += [
            "library_name",
            "datasets",
            "metrics",
            "co2",
            "pipeline_tag",
        ]
    elif repo_type == "dataset":
        wanted_fields += [
            "pretty_name",
            "size_categories",
            "task_categories",
            "task_ids",
            "source_datasets",
        ]
    else:
        # Unrecognised repo types produce no result.
        return None

    selected = {}
    for field in wanted_fields:
        selected[field] = card_data.get(field)
    return selected
def create_metadata_breakdown_table(desired_metadata_dictionary):
    """Render the reviewed metadata fields as a GitHub-style markdown table.

    Each row pairs a field name with its value; falsy values (missing or
    empty) are shown as "Field Missing".
    """
    rows = [
        (field, value or "Field Missing")
        for field, value in desired_metadata_dictionary.items()
    ]
    column_titles = ("Metadata Field", "Provided Value")
    return tabulate(rows, headers=column_titles, tablefmt="github")
def calculate_grade(desired_metadata_dictionary):
    """Return the fraction of reviewed metadata fields that are filled in.

    Args:
        desired_metadata_dictionary: Mapping of field name to provided value.
            A field counts as filled when its value is truthy.

    Returns:
        A float between 0.0 and 1.0, rounded to two decimal places. An empty
        mapping scores 0.0 instead of raising ``ZeroDivisionError``.
    """
    total_fields = len(desired_metadata_dictionary)
    if not total_fields:
        # Nothing to grade; avoid dividing by zero.
        return 0.0
    filled_fields = sum(1 for value in desired_metadata_dictionary.values() if value)
    return round(filled_fields / total_fields, 2)
def create_markdown_report(
    desired_metadata_dictionary, repo_name, repo_type, score, update: bool = False
):
    """Build the markdown text of a metadata report card.

    Args:
        desired_metadata_dictionary: Mapping of reviewed field name to value;
            rendered as a table by ``create_metadata_breakdown_table``.
        repo_name: Name of the repo the report is about.
        repo_type: Repo type string, used in the headings and messages.
        score: Coverage grade as a fraction between 0 and 1.
        update: When true, the title is marked "(updated)".

    Returns:
        The report as a markdown string.
    """
    title_suffix = "(updated)" if update else ""
    if score <= 0.5:
        verdict = f"We're not angry we're just disappointed! {repo_type.title()} metadata is super important. Please try harder..."
    else:
        verdict = f"Not too shabby! Make sure you also fill in a {repo_type} card too!"
    breakdown_table = create_metadata_breakdown_table(desired_metadata_dictionary)
    # The score is a 0-1 fraction, so it is formatted with the "%" presentation
    # type (0.57 -> "57%") rather than printed raw with a "%" appended.
    report = f"""# {repo_type.title()} metadata report card {title_suffix}
\n
This is an automatically produced metadata quality report card for {repo_name}. This report is meant as a POC!
\n
## Breakdown of metadata fields for your {repo_type}
\n
{breakdown_table}
\n
You scored a metadata coverage grade of: **{score:.0%}** \n {verdict}
"""
    return report
def parse_webhook_post(data: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Extract the repo type and name from a webhook payload.

    Args:
        data: Decoded webhook payload; must contain ``event["scope"]`` and,
            for repo-scoped events, ``repo["name"]`` and ``repo["type"]``.

    Returns:
        ``(repo_type, repo_name)`` for repo-scoped events, or ``None`` when
        the event scope is anything other than ``"repo"``.

    Raises:
        ValueError: If the repo type is neither ``"model"`` nor ``"dataset"``.
        KeyError: If an expected key is missing from the payload.
    """
    if data["event"]["scope"] != "repo":
        return None
    repo_info = data["repo"]
    name, kind = repo_info["name"], repo_info["type"]
    if kind in {"model", "dataset"}:
        return kind, name
    raise ValueError("Unknown hub type")
def load_repo_card_metadata(repo_type, repo_name):
    """Load a repo card's metadata as a plain dict.

    Uses ``DatasetCard`` for ``"dataset"`` repos and ``ModelCard`` for
    ``"model"`` repos. If loading raises ``EntryNotFoundError`` an empty dict
    is returned. Any other ``repo_type`` yields ``None``.
    """
    if repo_type == "dataset":
        card_class = DatasetCard
    elif repo_type == "model":
        card_class = ModelCard
    else:
        return None

    try:
        return card_class.load(repo_name).data.to_dict()
    except EntryNotFoundError:
        return {}
def create_or_update_report(data):
    """Create or refresh the "Metadata Report Card" discussion for a repo.

    The webhook payload is parsed, the repo card metadata is loaded and
    graded, and a markdown report is produced. If the repo already has an
    open discussion titled "Metadata Report Card", an "(updated)" report is
    posted as a comment there, but only when the report differs from the
    most recent comment. Otherwise a new discussion is opened.

    Args:
        data: Decoded webhook payload, as accepted by ``parse_webhook_post``.

    Returns:
        ``True`` once the report has been handled, or a 400 ``Response`` when
        the payload is not a repo-scoped event.

    Raises:
        ValueError: Propagated from ``parse_webhook_post`` for unknown repo
            types.
    """
    parsed_post = parse_webhook_post(data)
    if not parsed_post:
        return Response("Unable to parse webhook data", status_code=400)
    repo_type, repo_name = parsed_post

    card_data = load_repo_card_metadata(repo_type, repo_name)
    desired_metadata_dictionary = create_metadata_key_dict(card_data, repo_type)
    score = calculate_grade(desired_metadata_dictionary)
    report = create_markdown_report(
        desired_metadata_dictionary, repo_name, repo_type, score, update=False
    )

    for discussion in get_repo_discussions(repo_name, repo_type=repo_type):
        if discussion.title != "Metadata Report Card" or discussion.status != "open":
            continue

        # An existing open report card thread.
        discussion_details = get_discussion_details(
            repo_name, discussion.num, repo_type=repo_type
        )
        # The newest event is not necessarily a comment (it can be a status
        # or title change, which carries no ``content``), so walk backwards
        # to the most recent event that has text.
        last_comment = ""
        for event in reversed(discussion_details.events):
            content = getattr(event, "content", None)
            if content:
                last_comment = content
                break

        updated_report = create_markdown_report(
            desired_metadata_dictionary,
            repo_name,
            repo_type,
            score,
            update=True,
        )
        # The last comment may be either the initial report or an earlier
        # "(updated)" one. Compare against both, otherwise the title suffix
        # alone makes an unchanged report look different and triggers a
        # duplicate comment on every webhook.
        best_match = max(
            similar(report, last_comment),
            similar(updated_report, last_comment),
        )
        if best_match <= 0.999:
            comment_discussion(
                repo_name,
                discussion.num,
                comment=updated_report,
                repo_type=repo_type,
            )
        return True

    # No open report card thread exists yet: start one.
    create_discussion(
        repo_name,
        "Metadata Report Card",
        description=report,
        repo_type=repo_type,
    )
    return True
@app.post("/webhook")
async def webhook(request: Request):
    """Handle a Hub webhook call and trigger the metadata report.

    The request must carry the shared secret in the "X-Webhook-Secret"
    header. The JSON body is passed to ``create_or_update_report``.

    Returns:
        A 401 ``Response`` when the secret is missing, wrong, or not
        configured on the server; a 400 ``Response`` when the body is not
        valid JSON; the error ``Response`` produced by
        ``create_or_update_report`` when it returns one; otherwise the
        string "Webhook received!".
    """
    import hmac  # local import: only this handler needs it

    provided_secret = request.headers.get("X-Webhook-Secret")
    # Fail closed: with no secret configured, a request that omits the header
    # would otherwise compare None to None and be accepted.
    if (
        not KEY
        or provided_secret is None
        # Constant-time comparison on bytes, so non-ASCII header values are
        # rejected instead of raising.
        or not hmac.compare_digest(
            provided_secret.encode("utf-8"), KEY.encode("utf-8")
        )
    ):
        return Response("Invalid secret", status_code=401)

    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return Response("Invalid JSON payload", status_code=400)

    # NOTE(review): create_or_update_report is a synchronous function called
    # directly from this coroutine.
    result = create_or_update_report(data)
    if isinstance(result, Response):
        # Response objects are truthy, so pass error responses through
        # explicitly rather than reporting success.
        return result
    return "Webhook received!" if result else result