# comparator/src/hub.py
# Last commit (albertvillanova): "Schedule Space restart to update list of models" (8a91492, verified)
import asyncio
import io
import json
import os
import httpx
from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url
from huggingface_hub.utils import build_hf_headers
import src.constants as constants
class Client:
    """Asynchronous HTTP GET helper that reports failures as ``None``.

    Wraps a single ``httpx.AsyncClient`` (redirects are followed). Requests
    that hit a read timeout are retried with a doubling delay.
    """

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """Send one GET request and return the response.

        Raises:
            httpx.HTTPError: Propagated from the request itself or from
                ``raise_for_status()`` on an error status code.
        """
        response = await self.client.get(url, headers=headers, params=params)
        response.raise_for_status()
        return response

    async def get(self, url, headers=None, params=None):
        """GET ``url`` and return the response, or ``None`` on any HTTP failure.

        A read timeout on the first attempt triggers ``retry``; every other
        ``httpx.HTTPError`` yields ``None`` immediately.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            # Retried below, outside this handler, so that errors raised by
            # the retries are handled as well.
            pass
        except httpx.HTTPError:
            return None
        try:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            # A retry may fail with a non-timeout error (e.g. an error status);
            # honour the same "None on failure" contract as the first attempt.
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Call ``func(url, **kwargs)`` repeatedly until it stops timing out.

        Args:
            func: Coroutine function to await; called as ``func(url, **kwargs)``.
            url: URL forwarded to ``func`` and shown in the failure message.
            max_retries: Maximum number of attempts.
            max_wait_time: Give up once the doubled delay exceeds this value.
            wait_time: Delay in seconds before the first attempt; doubled
                after each read timeout.
            **kwargs: Extra keyword arguments forwarded to ``func``.

        Returns:
            The result of ``func``, or ``None`` when every attempt timed out.

        Raises:
            Any exception from ``func`` other than ``httpx.ReadTimeout``.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time *= 2
                if wait_time > max_wait_time:
                    break
        # Reached both when the wait-time cap is exceeded and when the
        # attempts run out, so giving up is never silent.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None
# Module-level singletons, created once at import time and shared by the
# helpers below.
api = HfApi()  # Hub API client; used by restart_space()
client = Client()  # async HTTP client; used by the load_* helpers and list_models()
fs = HfFileSystem()  # Hub filesystem interface; used by glob()
def glob(path):
    """Return the result of globbing ``path`` on the shared ``HfFileSystem``."""
    return fs.glob(path)
async def load_json_file(path):
    """Download the file at ``path`` and return its decoded JSON content.

    Returns ``None`` when the client yields no response.
    """
    response = await client.get(to_url(path))
    return None if response is None else response.json()
async def load_jsonlines_file(path):
    """Download the JSON Lines file at ``path`` and return its records.

    The request carries the headers produced by ``build_hf_headers()``.

    Returns:
        A list with one decoded JSON value per non-blank line, or ``None``
        when the client yields no response.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    url = to_url(path)
    response = await client.get(url, headers=build_hf_headers())
    if response is None:
        return None
    # Iterating a StringIO splits on "\n" only, so separators that may legally
    # occur inside JSON strings are left intact. Blank or whitespace-only
    # lines are skipped instead of crashing json.loads.
    return [json.loads(line) for line in io.StringIO(response.text) if line.strip()]
def to_url(path):
    """Convert a Hub path into the download URL built by ``hf_hub_url``.

    Accepted forms:
        ``<prefix>/<org>/<repo>/<filename>`` where ``<prefix>`` is
        ``datasets``, ``spaces`` or ``models``; the repo type is the prefix
        without its trailing "s".
        ``<org>/<repo>/<filename>`` with no prefix; the repo type is ``None``.

    ``<filename>`` may itself contain slashes. The first segment is treated
    as a repo-type prefix only when it is one of the known prefixes, so a
    prefix-less path with a nested file (``org/repo/sub/file``) is parsed
    correctly rather than having its first segment consumed as the type.

    Raises:
        ValueError: If ``path`` has fewer than three "/"-separated segments.
    """
    parts = path.split("/", 3)
    if len(parts) == 4 and parts[0] in ("datasets", "spaces", "models"):
        repo_type = parts[0][:-1]  # "datasets" -> "dataset", etc.
        org_name, repo_name, filename = parts[1:]
    else:
        repo_type = None
        org_name, repo_name, filename = path.split("/", 2)
    return hf_hub_url(repo_id=f"{org_name}/{repo_name}", filename=filename, repo_type=repo_type)
async def load_model_card(model_id):
    """Download ``README.md`` of ``model_id`` and wrap its text in a ``ModelCard``.

    The card is built with ``ignore_metadata_errors=True``. Returns ``None``
    when the client yields no response.
    """
    response = await client.get(to_url(f"{model_id}/README.md"))
    if response is not None:
        return ModelCard(response.text, ignore_metadata_errors=True)
    return None
async def list_models(filtering=None):
    """Query the ``/models`` endpoint under ``constants.HF_API_URL``.

    Args:
        filtering: Optional value sent as the ``filter`` query parameter;
            omitted from the request when falsy.

    Returns:
        The decoded JSON response, or ``None`` when the client yields no
        response.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    return response.json() if response is not None else None
def restart_space():
    """Restart the Space named by the ``SPACE_ID`` environment variable.

    Does nothing when ``SPACE_ID`` is unset or empty. The token passed to the
    API is read from ``HF_TOKEN``.
    """
    target = os.getenv("SPACE_ID")
    if not target:
        return
    api.restart_space(repo_id=target, token=os.getenv("HF_TOKEN"))