# dataset-column-search-api / data_loader.py
# Author: davanstrien (HF staff)
# Last commit: 33c1203 "improve db"
import os
from datetime import datetime
from typing import Any, Dict, List
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import HfApi
from huggingface_hub.utils import logging
from tqdm.auto import tqdm
load_dotenv()

# Required configuration. Explicit raises (not `assert`) so the checks survive
# `python -O`; `not value` also rejects variables that are set but empty.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")
USER_AGENT = os.getenv("USER_AGENT")
if not USER_AGENT:
    raise RuntimeError("You need to set USER_AGENT in your environment variables")

logger = logging.get_logger(__name__)
# USER_AGENT is required above, so actually send it with Hub requests.
api = HfApi(token=HF_TOKEN, user_agent=USER_AGENT)

# Passed as `limit` to `api.list_datasets`; None requests no limit.
MAX_DATASETS = None
def has_card_data(dataset):
    """Return True if *dataset* has a ``card_data`` attribute (its value may still be None)."""
    absent = object()
    return getattr(dataset, "card_data", absent) is not absent
def check_dataset_has_dataset_info(dataset):
    """Return True only if ``dataset.card_data.dataset_info`` exists and is not None.

    A missing ``card_data`` attribute, a ``card_data`` without a
    ``dataset_info`` attribute, or a ``dataset_info`` of None all yield False.
    """
    if not has_card_data(dataset):
        return False
    card = dataset.card_data
    if not hasattr(card, "dataset_info"):
        return False
    return card.dataset_info is not None
def parse_single_config_dataset(data):
    """Extract the config name, column names and raw features of one config entry.

    Args:
        data: A mapping that may contain ``config_name`` and ``features`` keys,
            presumably one entry of a dataset card's ``dataset_info`` (verify).

    Returns:
        A dict with:
        - ``config_name``: the entry's name, or ``"default"`` when missing/null/empty.
        - ``column_names``: the ``name`` of every dict-shaped feature, in order.
        - ``features``: the entry's feature list, or ``[]`` when it is missing,
          null, or not a list.
    """
    # `or` also covers keys present with a null value, which .get's default does not.
    config_name = data.get("config_name") or "default"
    features = data.get("features")
    if not isinstance(features, list):
        # A null/malformed features value would otherwise raise and drop the dataset.
        features = []
    # Ignore malformed (non-dict) feature entries instead of raising AttributeError.
    column_names = [
        feature.get("name") for feature in features if isinstance(feature, dict)
    ]
    return {
        "config_name": config_name,
        "column_names": column_names,
        "features": features,
    }
def parse_multiple_config_dataset(data: List[Dict[str, Any]]):
    """Parse each config entry in *data* with ``parse_single_config_dataset``, keeping order."""
    parsed_configs = []
    for config_entry in data:
        parsed_configs.append(parse_single_config_dataset(config_entry))
    return parsed_configs
def parse_dataset(dataset):
    """Collect Hub-level metadata for one dataset into a flat dict.

    Args:
        dataset: A Hub dataset listing entry (presumably ``huggingface_hub``'s
            ``DatasetInfo``) exposing ``id``, ``likes``, ``downloads``, ``tags``,
            ``created_at``, ``last_modified`` and, optionally, ``card_data``.

    Returns:
        A dict with keys ``hub_id``, ``likes``, ``downloads``, ``tags``,
        ``created_at``, ``last_modified``, ``license`` and ``language``.
        ``license`` and ``language`` are None when ``card_data`` is missing or
        None, or lacks the attribute.
    """
    # getattr guards against a missing or None card_data instead of raising.
    card_data = getattr(dataset, "card_data", None)
    return {
        "hub_id": dataset.id,
        "likes": dataset.likes,
        "downloads": dataset.downloads,
        "tags": dataset.tags,
        "created_at": dataset.created_at,
        "last_modified": dataset.last_modified,
        # Read directly rather than into a local named `license`, which shadows a builtin.
        "license": getattr(card_data, "license", None),
        "language": getattr(card_data, "language", None),
    }
def parsed_column_info(dataset_info):
    """Normalize a card's ``dataset_info`` into a list of parsed config dicts.

    A list is parsed entry by entry. A single dict is wrapped in a one-element
    list. Any other value (including None) yields None.
    """
    if isinstance(dataset_info, list):
        return parse_multiple_config_dataset(dataset_info)
    if isinstance(dataset_info, dict):
        return [parse_single_config_dataset(dataset_info)]
    return None
def ensure_list_of_strings(value):
    """Coerce *value* to a list of strings.

    None becomes ``[]``. A list has each element passed through ``str``. Any
    other value becomes a one-element list holding ``str(value)``.
    """
    if isinstance(value, list):
        return list(map(str, value))
    return [] if value is None else [str(value)]
def refresh_data() -> List[Dict[str, Any]]:
    """Build one record per (dataset, config) for Hub datasets with ``dataset_info``.

    If ``datasets_YYYY-MM-DD.parquet`` exists in the working directory for
    today's local date, its rows are returned instead of querying the Hub.

    Otherwise the function lists datasets via ``api.list_datasets`` (capped by
    ``MAX_DATASETS``) and keeps only those whose card declares
    ``dataset_info``. It then merges each dataset's metadata
    (``parse_dataset``) with each of its configs (``parsed_column_info``).
    Datasets that fail to parse are logged and skipped.

    Returns:
        A list of record dicts. ``license``, ``tags`` and ``language`` are
        normalized to lists of strings, and ``column_names`` to a list. An empty
        list is returned when no dataset could be parsed.
    """
    now = datetime.now()
    cache_path = f"datasets_{now.strftime('%Y-%m-%d')}.parquet"
    if os.path.exists(cache_path):
        df = pd.read_parquet(cache_path)
        return df.to_dict(orient="records")

    # Filter while streaming so only datasets with dataset_info are held in memory.
    datasets = [
        dataset
        for dataset in tqdm(api.list_datasets(limit=MAX_DATASETS, full=True))
        if check_dataset_has_dataset_info(dataset)
    ]

    parsed_datasets = []
    for dataset in tqdm(datasets):
        try:
            datasetinfo = parse_dataset(dataset)
            # parsed_column_info returns None for unexpected dataset_info shapes.
            column_info = parsed_column_info(dataset.card_data.dataset_info) or []
            records = [{**datasetinfo, **info} for info in column_info]
        except Exception as e:  # best-effort: one malformed card must not abort the refresh
            logger.warning("Error processing dataset %s: %s", dataset.id, e)
        else:
            parsed_datasets.extend(records)

    # An empty DataFrame has no columns, so the normalization below would KeyError.
    if not parsed_datasets:
        return []

    df = pd.DataFrame(parsed_datasets)
    # Ensure 'license', 'tags', and 'language' are lists of strings.
    for column in ("license", "tags", "language"):
        df[column] = df[column].apply(ensure_list_of_strings)
    df["column_names"] = df["column_names"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    df = df.astype({"hub_id": "string", "config_name": "string"})
    # NOTE: this function reads the dated parquet cache above but does not write it.
    return df.to_dict(orient="records")