"""Async helpers for reading files and metadata from the Hugging Face Hub.

Network failures are reported by returning ``None`` rather than raising, so
callers must check results before using them.
"""

import asyncio
import io
import json
import os

import httpx
from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url
from huggingface_hub.utils import build_hf_headers

import src.constants as constants


class Client:
    """Async HTTP client that retries timed-out GETs with exponential backoff.

    ``get`` (and ``retry`` when called with ``_get``) return ``None`` instead of
    raising when a request ultimately fails: an HTTP error status, a transport
    error, or too many timeouts.
    """

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """Issue a single GET; raise ``httpx.HTTPStatusError`` on a 4xx/5xx response."""
        r = await self.client.get(url, headers=headers, params=params)
        r.raise_for_status()
        return r

    async def get(self, url, headers=None, params=None):
        """GET ``url`` and return the response, or ``None`` if the request failed.

        A timeout of any kind hands off to :meth:`retry`; any other
        ``httpx.HTTPError`` yields ``None``.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.TimeoutException:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Call ``func(url, **kwargs)`` again after timeouts, doubling the delay each time.

        Args:
            func: Coroutine function to call, e.g. ``self._get``.
            url: First positional argument to ``func``; also used in the log message.
            max_retries: Maximum number of attempts made by this method.
            max_wait_time: Give up once the next delay would exceed this many seconds.
            wait_time: Delay in seconds before the first attempt.
            **kwargs: Forwarded to ``func``.

        Returns:
            The result of the first successful call, or ``None`` if the attempts
            keep timing out or a non-timeout ``httpx.HTTPError`` is raised.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.TimeoutException:
                wait_time *= 2
                if wait_time > max_wait_time:
                    break
            except httpx.HTTPError:
                # E.g. a 404/5xx on the retried call: keep get()'s "return None"
                # contract instead of letting the error escape from the handler.
                return None
        # Reached both when the backoff cap is hit and when max_retries runs out.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None


# Module-level singletons shared by the helpers below.
api = HfApi()
client = Client()
fs = HfFileSystem()


def glob(path):
    """Return the Hub paths matching the glob pattern ``path`` (via ``HfFileSystem.glob``)."""
    return fs.glob(path)


async def load_json_file(path):
    """Download a JSON file from the Hub and return the parsed value, or ``None`` on failure.

    ``path`` uses the ``[<repo_type>s/]<org>/<repo>/<filename>`` form accepted by ``to_url``.
    """
    url = to_url(path)
    # Send the user's Hub credentials (if configured), as load_jsonlines_file
    # does, so files in gated or private repos can be read too.
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    return r.json()


async def load_jsonlines_file(path):
    """Download a JSON Lines file and return one parsed value per non-blank line.

    Returns ``None`` if the download failed.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Whitespace-only lines are not valid JSON (json.loads would raise); skip them.
    return [json.loads(line) for line in io.StringIO(r.text) if line.strip()]


def to_url(path):
    """Convert a Hub path into a file download URL via ``hf_hub_url``.

    Accepted form: ``[<repo_type>s/]<org>/<repo>/<filename>``, where the
    optional prefix is ``datasets``, ``spaces`` or ``models`` (no prefix means
    the ``hf_hub_url`` default repo type). ``<filename>`` may itself contain
    ``/`` for files in subdirectories.

    Raises:
        ValueError: if ``path`` has fewer than three segments after the prefix.
    """
    parts = path.split("/")
    repo_type = None
    # Only strip a *recognized* prefix; otherwise "org/name/dir/file" would have
    # its org name misread as a repo type.
    if parts[0] in ("datasets", "spaces", "models"):
        repo_type = parts.pop(0)[:-1]  # "datasets" -> "dataset", etc.
    if len(parts) < 3:
        raise ValueError(f"Expected '[<repo_type>s/]<org>/<repo>/<filename>', got {path!r}")
    org_name, repo_name, *file_parts = parts
    return hf_hub_url(
        repo_id=f"{org_name}/{repo_name}",
        filename="/".join(file_parts),
        repo_type=repo_type,
    )


async def load_model_card(model_id):
    """Fetch ``<model_id>/README.md`` and parse it as a ``ModelCard``, or ``None`` on failure.

    Malformed card metadata is tolerated (``ignore_metadata_errors=True``).
    """
    url = to_url(f"{model_id}/README.md")
    # Authenticated like load_jsonlines_file so gated/private model cards load too.
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    return ModelCard(r.text, ignore_metadata_errors=True)


async def list_models(filtering=None):
    """Query the ``/models`` endpoint under ``constants.HF_API_URL``.

    Args:
        filtering: Optional value sent as the ``filter`` query parameter.

    Returns:
        The decoded JSON response, or ``None`` if the request failed.
    """
    params = {}
    if filtering:
        params["filter"] = filtering
    r = await client.get(f"{constants.HF_API_URL}/models", params=params)
    if r is None:
        return None
    return r.json()


def restart_space():
    """Restart the Space named by the ``SPACE_ID`` env var; no-op when it is unset.

    The ``HF_TOKEN`` env var, if present, is passed as the token.
    """
    space_id = os.getenv("SPACE_ID")
    if space_id:
        api.restart_space(repo_id=space_id, token=os.getenv("HF_TOKEN"))