|
import asyncio |
|
import io |
|
import json |
|
import os |
|
|
|
import httpx |
|
from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url |
|
from huggingface_hub.utils import build_hf_headers |
|
|
|
import src.constants as constants |
|
|
|
|
|
class Client:
    """Best-effort asynchronous HTTP GET helper built on ``httpx.AsyncClient``.

    ``get`` returns the response on success and ``None`` when the request
    fails; read timeouts are retried with a doubling delay via ``retry``.
    """

    def __init__(self):
        # One shared async client; redirects are followed automatically.
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """Send a single GET request and return the response.

        Raises whatever ``self.client.get`` or ``raise_for_status`` raise;
        no error handling is done at this level.
        """
        response = await self.client.get(url, headers=headers, params=params)
        response.raise_for_status()
        return response

    async def get(self, url, headers=None, params=None):
        """GET ``url``, retrying on read timeouts.

        Args:
            url: Target URL.
            headers: Optional request headers.
            params: Optional query parameters.

        Returns:
            The response, or ``None`` if the request (or every retry) failed
            with an ``httpx.HTTPError``.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            # Retry below, outside this handler, so that errors raised while
            # retrying are handled the same way as first-attempt errors.
            pass
        except httpx.HTTPError:
            return None
        try:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            # A retry can fail with a non-timeout error (bad status, connection
            # error, ...); report it as ``None`` like a first-attempt failure.
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Call ``func(url, **kwargs)`` again after a delay that doubles on each timeout.

        Args:
            func: Async callable invoked as ``func(url, **kwargs)``.
            url: URL forwarded to ``func`` and shown in the give-up message.
            max_retries: Maximum number of attempts.
            max_wait_time: Retrying stops once the doubled delay exceeds this value.
            wait_time: Delay in seconds before the first attempt.
            **kwargs: Extra keyword arguments forwarded to ``func``.

        Returns:
            The result of the first attempt that does not time out, or ``None``
            once retries are exhausted. Exceptions other than
            ``httpx.ReadTimeout`` raised by ``func`` propagate to the caller.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time *= 2
                if wait_time > max_wait_time:
                    break
        # Reached both when the wait cap is exceeded and when ``max_retries``
        # runs out first, so giving up is never silent.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None
|
|
|
|
|
# Shared module-level instances used by the helper functions in this module.
api: HfApi = HfApi()  # Hub API handle (used by restart_space)
client: Client = Client()  # async HTTP client (used by the load_* / list_models helpers)
fs: HfFileSystem = HfFileSystem()  # Hub filesystem (used by glob)
|
|
|
|
|
def glob(path):
    """Return the Hub paths matching the glob pattern ``path`` (delegates to ``fs.glob``)."""
    return fs.glob(path)
|
|
|
|
|
async def load_json_file(path):
    """Download the file at Hub ``path`` and decode it as JSON.

    Returns:
        The decoded JSON payload, or ``None`` when the request failed.
    """
    response = await client.get(to_url(path))
    return None if response is None else response.json()
|
|
|
|
|
async def load_jsonlines_file(path):
    """Download the JSON Lines file at Hub ``path`` and parse it.

    The request is sent with the headers returned by ``build_hf_headers()``.

    Args:
        path: Hub path, converted to a download URL by ``to_url``.

    Returns:
        A list with one decoded JSON value per non-blank line, or ``None``
        when the request failed.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Iterating a StringIO splits on "\n" only; str.splitlines() would also
    # break on characters such as U+2028 that may appear inside JSON strings.
    # Blank / whitespace-only lines are skipped instead of being passed to
    # json.loads, which would raise JSONDecodeError on them.
    return [json.loads(line) for line in io.StringIO(r.text) if line.strip()]
|
|
|
|
|
def to_url(path):
    """Convert a Hub path into a download URL via ``hf_hub_url``.

    Accepted shapes:
        ``"<org>/<repo>/<filename>"`` -> ``repo_type=None``
        ``"datasets/<org>/<repo>/<filename>"`` (also ``models/``, ``spaces/``)
            -> the prefix without its trailing "s" is used as ``repo_type``.

    ``<filename>`` may itself contain slashes.

    Raises:
        ValueError: if ``path`` has fewer than three "/"-separated segments.
    """
    segments = path.split("/", 3)
    # Only treat the first segment as a repo-type prefix when it really is one;
    # otherwise "org/repo/sub/file" would turn "org" into a bogus repo type.
    if len(segments) == 4 and segments[0] in ("datasets", "models", "spaces"):
        repo_type = segments[0][:-1]  # "datasets" -> "dataset"
        org_name, ds_name, filename = segments[1:]
    else:
        repo_type = None
        org_name, ds_name, filename = path.split("/", 2)
    return hf_hub_url(repo_id=f"{org_name}/{ds_name}", filename=filename, repo_type=repo_type)
|
|
|
|
|
async def load_model_card(model_id):
    """Download ``README.md`` of ``model_id`` and wrap its text in a ``ModelCard``.

    Returns:
        A ``ModelCard`` built with ``ignore_metadata_errors=True``, or ``None``
        when the request failed.
    """
    response = await client.get(to_url(f"{model_id}/README.md"))
    if response is not None:
        return ModelCard(response.text, ignore_metadata_errors=True)
    return None
|
|
|
|
|
async def list_models(filtering=None):
    """Query the ``/models`` endpoint under ``constants.HF_API_URL``.

    Args:
        filtering: Optional value sent as the ``filter`` query parameter;
            omitted from the request when falsy.

    Returns:
        The decoded JSON response, or ``None`` when the request failed.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    return None if response is None else response.json()
|
|
|
|
|
def restart_space():
    """Restart the Space named by the ``SPACE_ID`` environment variable.

    Does nothing when ``SPACE_ID`` is unset or empty. The token passed to
    ``api.restart_space`` is read from the ``HF_TOKEN`` environment variable.
    """
    target = os.getenv("SPACE_ID")
    if not target:
        return
    api.restart_space(repo_id=target, token=os.getenv("HF_TOKEN"))
|
|