list-of-demos-test / demo_list.py
hysts's picture
hysts HF staff
Migrate from yapf to black
85576c8
import dataclasses
import datetime
import html
import operator
import pathlib

import pandas as pd
import tqdm.auto
import yaml
from huggingface_hub import HfApi

from constants import (
    OWNER_CHOICES,
    SLEEP_TIME_INT_TO_STR,
    SLEEP_TIME_STR_TO_INT,
    WHOAMI,
)
@dataclasses.dataclass(frozen=True)
class DemoInfo:
space_id: str
url: str
title: str
owner: str
sdk: str
sdk_version: str
likes: int
status: str
last_modified: str
sleep_time: int
replicas: int
private: bool
hardware: str
suggested_hardware: str
created: str = ""
arxiv: list[str] = dataclasses.field(default_factory=list)
github: list[str] = dataclasses.field(default_factory=list)
tags: list[str] = dataclasses.field(default_factory=list)
def __post_init__(self):
object.__setattr__(self, "last_modified", DemoInfo.convert_timestamp(self.last_modified))
object.__setattr__(self, "created", DemoInfo.convert_timestamp(self.created))
@staticmethod
def convert_timestamp(timestamp: str) -> str:
try:
return datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d %H:%M:%S")
except ValueError:
return timestamp
@classmethod
def from_space_id(cls, space_id: str) -> "DemoInfo":
api = HfApi()
space_info = api.space_info(repo_id=space_id)
card = space_info.cardData
runtime = space_info.runtime
resources = runtime["resources"]
return cls(
space_id=space_id,
url=f"https://huggingface.co/spaces/{space_id}",
title=card["title"] if "title" in card else "",
owner=space_id.split("/")[0],
sdk=card["sdk"],
sdk_version=card.get("sdk_version", ""),
likes=space_info.likes,
status=runtime["stage"],
last_modified=space_info.lastModified,
sleep_time=runtime["gcTimeout"] or 0,
replicas=resources["replicas"] if resources is not None else 0,
private=space_info.private,
hardware=runtime["hardware"]["current"] or runtime["hardware"]["requested"],
suggested_hardware=card.get("suggested_hardware", ""),
)
def get_df_from_yaml(path: pathlib.Path | str) -> pd.DataFrame:
    """Load the demo list YAML and build a DataFrame of ``DemoInfo`` rows.

    Each top-level key of the YAML mapping is a Space id. Its value is an
    optional mapping of field overrides (e.g. ``arxiv``, ``github``, ``tags``,
    ``created``) applied on top of the metadata fetched by
    ``DemoInfo.from_space_id``. An entry with no value (``owner/space:``) means
    "no overrides".

    Args:
        path: Path to the YAML file.

    Returns:
        One row per Space, with one column per ``DemoInfo`` field. The columns
        are present even when the file lists no Spaces.

    Raises:
        TypeError: if an override key is not a ``DemoInfo`` field.
    """
    with pathlib.Path(path).open(encoding="utf-8") as f:
        # An empty file loads as None; treat it as "no Spaces".
        data = yaml.safe_load(f) or {}
    demo_info = []
    for space_id, overrides in tqdm.auto.tqdm(data.items()):
        base_info = DemoInfo.from_space_id(space_id)
        # replace() re-runs __init__/__post_init__, so overridden timestamps
        # are normalized just like fetched ones.
        demo_info.append(dataclasses.replace(base_info, **(overrides or {})))
    columns = [field.name for field in dataclasses.fields(DemoInfo)]
    return pd.DataFrame([dataclasses.asdict(info) for info in demo_info], columns=columns)
class Prettifier:
@staticmethod
def get_arxiv_link(links: list[str]) -> str:
links = [Prettifier.create_link(link.split("/")[-1], link) for link in links]
return "\n".join(links)
@staticmethod
def get_github_link(links: list[str]) -> str:
links = [Prettifier.create_link("github", link) for link in links]
return "\n".join(links)
@staticmethod
def get_tag_list(tags: list[str]) -> str:
return ", ".join(tags)
@staticmethod
def create_link(text: str, url: str) -> str:
return f'<a href={url} target="_blank">{text}</a>'
@staticmethod
def to_div(text: str | None, category_name: str) -> str:
if text is None:
text = ""
class_name = f"{category_name}-{text.lower()}"
return f'<div class="{class_name}">{text}</div>'
@staticmethod
def add_div_tag_to_replicas(replicas: int) -> str:
if replicas == 0:
return ""
if replicas == 1:
return "1"
return f'<div class="multiple-replicas">{replicas}</div>'
@staticmethod
def add_div_tag_to_sleep_time(sleep_time_s: str, hardware: str) -> str:
if hardware == "cpu-basic":
return f'<div class="sleep-time-cpu-basic">{sleep_time_s}</div>'
s = sleep_time_s.replace(" ", "-")
return f'<div class="sleep-time-{s}">{sleep_time_s}</div>'
def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
new_rows = []
for _, row in df.iterrows():
new_row = dict(row) | {
"status": self.to_div(row.status, "status"),
"hardware": self.to_div(row.hardware, "hardware"),
"suggested_hardware": self.to_div(row.suggested_hardware, "hardware"),
"title": self.create_link(row.title, row.url),
"owner": self.create_link(row.owner, f"https://huggingface.co/{row.owner}"),
"sdk": self.to_div(row.sdk, "sdk"),
"sleep_time": self.add_div_tag_to_sleep_time(SLEEP_TIME_INT_TO_STR[row.sleep_time], row.hardware),
"replicas": self.add_div_tag_to_replicas(row.replicas),
"arxiv": self.get_arxiv_link(row.arxiv),
"github": self.get_github_link(row.github),
"tags": self.get_tag_list(row.tags),
}
new_rows.append(new_row)
return pd.DataFrame(new_rows, columns=df.columns)
class DemoList:
    """Hold the raw demo table and produce prettified (optionally filtered) views of it."""

    # (column name, display datatype) pairs, in display order.
    COLUMN_INFO = [
        ["status", "markdown"],
        ["hardware", "markdown"],
        ["title", "markdown"],
        ["owner", "markdown"],
        ["arxiv", "markdown"],
        ["github", "markdown"],
        ["likes", "number"],
        ["tags", "str"],
        ["last_modified", "str"],
        ["created", "str"],
        ["sdk", "markdown"],
        ["sdk_version", "str"],
        ["suggested_hardware", "markdown"],
        ["sleep_time", "markdown"],
        ["replicas", "markdown"],
        ["private", "bool"],
    ]

    def __init__(self, df: pd.DataFrame):
        self.df_raw = df
        self._prettifier = Prettifier()
        self.df_prettified = self._prettifier(df).loc[:, self.column_names]

    @property
    def column_names(self):
        """Column names from ``COLUMN_INFO``, in display order."""
        return list(map(operator.itemgetter(0), self.COLUMN_INFO))

    @property
    def column_datatype(self):
        """Display datatypes from ``COLUMN_INFO``, aligned with ``column_names``."""
        return list(map(operator.itemgetter(1), self.COLUMN_INFO))

    def filter(
        self,
        status: list[str],
        hardware: list[str],
        sleep_time: list[str],
        multiple_replicas: bool,
        sdk: list[str],
        visibility: list[str],
        owner: list[str],
    ) -> pd.DataFrame:
        """Return the prettified rows of ``df_raw`` matching every criterion.

        Args:
            status, hardware, sdk: Allowed values for the respective raw columns.
            sleep_time: Allowed sleep-time labels, mapped to raw values via
                ``SLEEP_TIME_STR_TO_INT`` (``KeyError`` for unknown labels).
            multiple_replicas: If true, keep only rows with more than one replica.
            visibility: ``["public"]`` keeps only public Spaces and
                ``["private"]`` keeps only private ones; any other value
                applies no visibility filter.
            owner: If it equals ``OWNER_CHOICES`` as a set, no owner filter is
                applied. Otherwise, if it contains ``WHOAMI``, keep only
                ``WHOAMI``'s Spaces; else exclude them.
        """
        df = self.df_raw
        sleep_time_int = [SLEEP_TIME_STR_TO_INT[s] for s in sleep_time]
        # Build a single mask over df_raw, so every condition shares one index.
        # Indexing a filtered frame with Series from df_raw relies on pandas
        # reindexing the key.
        mask = (
            df["status"].isin(status)
            & df["hardware"].isin(hardware)
            & df["sdk"].isin(sdk)
            & df["sleep_time"].isin(sleep_time_int)
        )
        if multiple_replicas:
            mask &= df["replicas"] > 1
        # astype(bool) keeps ~ a logical NOT even if the column is object-typed.
        private = df["private"].astype(bool)
        if visibility == ["public"]:
            mask &= ~private
        elif visibility == ["private"]:
            mask &= private
        if set(owner) != set(OWNER_CHOICES):
            if WHOAMI in owner:
                mask &= df["owner"] == WHOAMI
            else:
                mask &= df["owner"] != WHOAMI
        return self._prettifier(df[mask]).loc[:, self.column_names]