Spaces:
Runtime error
Runtime error
File size: 4,686 Bytes
7a3d7a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
from typing import Dict, Any, Optional, List
import re
from abc import ABC, abstractmethod
from huggingface_hub import (ModelCard, comment_discussion,
create_discussion, get_discussion_details,
get_repo_discussions)
import markdown
from bs4 import BeautifulSoup
from tabulate import tabulate
from difflib import SequenceMatcher
# Value of the "KEY" environment variable, or None when it is not set.
# NOTE(review): nothing in the code shown here reads KEY — confirm whether it
# is meant to be passed as a credential to the huggingface_hub calls.
KEY = os.environ.get("KEY")
def similar(a, b):
    """Return the similarity ratio of two sequences.

    Builds a ``difflib.SequenceMatcher`` over ``a`` and ``b`` without a junk
    filter and returns the value of its ``ratio()`` method.
    """
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()
class ComplianceCheck(ABC):
    """Abstract base for one named check run against a parsed model card."""

    def __init__(self, name):
        # Label under which this check's outcome is reported.
        self.name = name

    @abstractmethod
    def check(self, card: BeautifulSoup) -> bool:
        """Evaluate the check against ``card``.

        Subclasses must override this; the base implementation only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()
class ModelProviderIdentityCheck(ComplianceCheck):
    """Check that the model card states who developed the model."""

    def __init__(self):
        super().__init__("Identity and Contact Details")

    def check(self, card: BeautifulSoup):
        """Return True when the card's "Developed by" entry is filled in.

        The first text node matching "Developed by" is located, and the
        second child of that node's grandparent element is read as the value.

        Returns:
            False when no "Developed by" text exists, when the surrounding
            element has no value child, or when the value is empty or still
            the "[More Information Needed]" placeholder; True otherwise.
        """
        label = card.find(string=re.compile("Developed by"))
        if label is None:
            # A card without the entry is non-compliant, not an error.
            return False

        holder = label.parent
        entry = holder.parent if holder is not None else None
        if entry is None:
            return False

        parts = list(entry.children)
        if len(parts) < 2:
            # The label is present but nothing follows it.
            return False

        developed_by = parts[1].text.strip()
        return bool(developed_by) and developed_by != "[More Information Needed]"
class IntendedPurposeCheck(ComplianceCheck):
    """Placeholder check for the card's intended-purpose information."""

    def __init__(self):
        super().__init__("Intended Purpose")

    def check(self, card: BeautifulSoup):
        """Always return False; the real check has not been implemented."""
        # TODO: an earlier draft looked up the "Direct Use" <h2> heading with
        # card.find_all("h2", text="Direct Use") and tested for the
        # "[More Information Needed]" placeholder.
        return False
# Check instances handed to check_compliance by run_compliance_check.
compliance_checks = [
    ModelProviderIdentityCheck(),
    IntendedPurposeCheck(),
    # Listed in the original but not implemented as checks:
    #   "General Limitations"
    #   "Computational and Hardware Requirements"
    #   "Carbon Emissions"
]
def parse_webhook_post(data: Dict[str, Any]) -> Optional[str]:
    """Extract the repository name from a webhook payload.

    Args:
        data: Payload dict; ``data["event"]["scope"]`` is always read, and
            ``data["repo"]["name"]`` / ``data["repo"]["type"]`` are read when
            the scope is ``"repo"``.

    Returns:
        The repo name for a repo-scoped event on a model repo, or None when
        the event scope is anything other than ``"repo"``.

    Raises:
        ValueError: If the repo type is not ``"model"``.
    """
    if data["event"]["scope"] != "repo":
        return None

    repo_info = data["repo"]
    name, kind = repo_info["name"], repo_info["type"]
    if kind == "model":
        return name
    raise ValueError("Incorrect repo type.")
def check_compliance(comp_checks: List[ComplianceCheck], card: BeautifulSoup) -> Dict[str, bool]:
    """Run every check against ``card`` and collect the outcomes.

    Returns:
        A dict mapping each check's ``name`` to the value returned by its
        ``check(card)`` call, in the order the checks were given.
    """
    outcomes = {}
    for current in comp_checks:
        # Read the name first, then run the check, as the key/value order of
        # a dict comprehension would.
        label = current.name
        outcomes[label] = current.check(card)
    return outcomes
def run_compliance_check(repo_name):
    """Load a repo's model card and run the registered compliance checks.

    The card is loaded with ``ModelCard.load``, its content is converted from
    markdown to HTML, parsed with BeautifulSoup's ``html.parser``, and passed
    to ``check_compliance`` together with ``compliance_checks``.

    Returns:
        The dict produced by ``check_compliance``.
    """
    model_card: ModelCard = ModelCard.load(repo_id_or_path=repo_name)
    soup = BeautifulSoup(
        markdown.markdown(model_card.content), features="html.parser"
    )
    return check_compliance(compliance_checks, soup)
def create_metadata_breakdown_table(compliance_check_dictionary):
    """Format the check results as a table using tabulate's "github" format.

    Each item of ``compliance_check_dictionary`` becomes one row: the key in
    the "Compliance Check" column and the value in the "Present" column.
    """
    rows = list(compliance_check_dictionary.items())
    return tabulate(
        rows, headers=("Compliance Check", "Present"), tablefmt="github"
    )
def create_markdown_report(
    desired_metadata_dictionary, repo_name, update: bool = False
):
    """Build the markdown text of the compliance report.

    Args:
        desired_metadata_dictionary: Check results, rendered as a table by
            ``create_metadata_breakdown_table``.
        repo_name: Repository name interpolated into the introduction.
        update: When True, "(updated)" is appended to the report heading.

    Returns:
        The report as a single markdown string.
    """
    if update:
        heading_suffix = "(updated)"
    else:
        heading_suffix = ""
    return f"""# Model Card Regulatory Compliance report card {heading_suffix}
\n
This is an automatically produced model card regulatory compliance report card for {repo_name}.
This report is meant as a POC!
\n
## Breakdown of metadata fields for your model
\n
{create_metadata_breakdown_table(desired_metadata_dictionary)}
\n
"""
# Single title used both to find and to create the report thread, so a
# thread opened by one run is recognised by the next one.
_REPORT_DISCUSSION_TITLE = "Model Card Regulatory Compliance Report Card"


def _last_comment_text(discussion_details):
    """Return the content of the most recent event that has any, else ""."""
    for event in reversed(discussion_details.events):
        # NOTE(review): non-comment events (e.g. status or title changes) are
        # expected not to carry ``content`` — confirm against huggingface_hub.
        content = getattr(event, "content", None)
        if content is not None:
            return content
    return ""


def create_or_update_report(compliance_check, repo_name):
    """Post the compliance report to the model repo's discussions.

    If an open discussion titled ``_REPORT_DISCUSSION_TITLE`` exists, an
    "(updated)" report is added as a comment, unless the latest comment
    already matches the current report (in plain or "(updated)" form).
    Otherwise a new discussion is created with the report as description.

    Args:
        compliance_check: Check results passed to ``create_markdown_report``.
        repo_name: Id of the model repository to report on.

    Returns:
        True, in every case that does not raise.
    """
    report = create_markdown_report(compliance_check, repo_name, update=False)

    for discussion in get_repo_discussions(repo_name, repo_type="model"):
        if (
            discussion.title != _REPORT_DISCUSSION_TITLE
            or discussion.status != "open"
        ):
            continue

        # An existing open report card thread.
        details = get_discussion_details(
            repo_name, discussion.num, repo_type="model"
        )
        last_comment = _last_comment_text(details)
        updated_report = create_markdown_report(
            compliance_check, repo_name, update=True
        )
        # Compare against both renderings: the latest comment may be the
        # initial report or an earlier "(updated)" one, and neither should
        # trigger a duplicate post.
        already_posted = any(
            similar(candidate, last_comment) > 0.999
            for candidate in (report, updated_report)
        )
        if not already_posted:
            comment_discussion(
                repo_name,
                discussion.num,
                comment=updated_report,
                repo_type="model",
            )
        return True

    create_discussion(
        repo_name,
        _REPORT_DISCUSSION_TITLE,
        description=report,
        repo_type="model",
    )
    return True
|