Spaces:
Runtime error
Runtime error
from typing import Any, Dict, Iterator, List, Optional

import requests
from huggingface_hub import add_collection_item, create_collection
from tqdm.auto import tqdm
class DatasetSearchClient:
    """Client for the librarian-bots dataset column search API.

    The API's ``/search`` endpoint takes a list of column names and returns
    matching datasets in pages; :meth:`search` hides the pagination and
    yields the individual result dicts.
    """

    # Keys every page returned by ``/search`` must contain.
    _REQUIRED_KEYS = frozenset({"total", "page", "page_size", "results"})

    def __init__(
        self,
        base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
        timeout: Optional[float] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the search API. A trailing slash is removed
                so the request URL never contains ``//search``.
            timeout: Seconds to wait for each HTTP request, forwarded to
                ``requests.get``. ``None`` (the default) waits indefinitely;
                pass a number to avoid hanging on an unresponsive server.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def search(
        self, columns: List[str], match_all: bool = False, page_size: int = 100
    ) -> Iterator[Dict[str, Any]]:
        """
        Search datasets using the provided API, automatically handling pagination.

        Pages are requested one at a time and their results are yielded lazily.
        Iteration stops once ``total`` results (as reported by the first page)
        have been yielded, or as soon as a page comes back empty. Counting the
        yielded items, rather than assuming ``page_size`` items per page, keeps
        the loop correct if the server serves fewer items per page than
        requested. The empty-page check prevents an endless request loop if
        ``total`` overstates what the server can actually return.

        Args:
            columns (List[str]): List of column names to search for.
            match_all (bool, optional): If True, match all columns. If False, match any column. Defaults to False.
            page_size (int, optional): Number of results per page to request. Must be >= 1. Defaults to 100.

        Yields:
            Dict[str, Any]: Each dataset result from all pages.

        Raises:
            requests.RequestException: If there's an error with the HTTP request.
            ValueError: If ``page_size`` < 1 (raised on first iteration, since this
                is a generator), or if the API returns a body that is not JSON or
                lacks the expected keys.
        """
        if page_size < 1:
            raise ValueError("page_size must be a positive integer")

        page = 1
        yielded = 0
        total_results = None
        while total_results is None or yielded < total_results:
            data = self._fetch_page(
                {
                    "columns": columns,
                    "match_all": str(match_all).lower(),
                    "page": page,
                    "page_size": page_size,
                }
            )
            if total_results is None:
                total_results = data["total"]
            results = data["results"]
            if not results:
                break
            # Yield outside any try/except so exceptions thrown into the
            # generator by the consumer are not relabelled as API errors.
            yield from results
            yielded += len(results)
            page += 1

    def _fetch_page(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Request one page from ``/search`` and return its validated JSON body.

        Raises:
            requests.RequestException: On connection errors, timeouts or HTTP
                error status codes.
            ValueError: If the body is not JSON, not a JSON object, or is
                missing any of the expected keys.
        """
        try:
            response = requests.get(
                f"{self.base_url}/search", params=params, timeout=self.timeout
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(
                f"Error connecting to the API: {str(e)}"
            ) from e

        try:
            data = response.json()
        except ValueError as e:
            # A body that is not JSON is a response-format problem, not a
            # connection problem (requests' JSONDecodeError subclasses ValueError).
            raise ValueError(f"Error processing API response: {str(e)}") from e

        if not isinstance(data, dict) or not self._REQUIRED_KEYS.issubset(data):
            raise ValueError(
                "Error processing API response: "
                "Unexpected response format from the API"
            )
        return data
# Module-level client used by update_collection_for_dataset() for its searches.
client: DatasetSearchClient = DatasetSearchClient()
def update_collection_for_dataset(
    collection_name: Optional[str] = None,
    dataset_columns: Optional[List[str]] = None,
    collection_description: Optional[str] = None,
    collection_namespace: Optional[str] = None,
) -> str:
    """Create (or reuse) a Hub collection and add every dataset having all given columns.

    Args:
        collection_name: Title of the collection. Required: the ``None``
            default only keeps keyword-style calls working, and a falsy value
            raises ``ValueError``.
        dataset_columns: Column names a dataset must have *all* of
            (searched with ``match_all=True``). Required and non-empty.
        collection_description: Description passed to ``create_collection``.
        collection_namespace: User/organisation that owns the collection;
            ``None`` leaves the choice to ``create_collection``'s default.

    Returns:
        str: URL of the collection on huggingface.co.

    Raises:
        ValueError: If ``collection_name`` or ``dataset_columns`` is missing or empty.
    """
    if not collection_name:
        raise ValueError("collection_name is required")
    if not dataset_columns:
        # An empty column filter would match arbitrary datasets.
        raise ValueError("dataset_columns must be a non-empty list of column names")

    # Search first so that a failing search does not leave a freshly created,
    # empty collection behind on the Hub.
    results = list(
        tqdm(
            client.search(dataset_columns, match_all=True),
            desc="Searching datasets...",
            leave=False,
        )
    )

    # namespace=None is create_collection's own default, so a single call
    # covers both the namespaced and the default-namespace cases.
    collection = create_collection(
        collection_name,
        exists_ok=True,
        description=collection_description,
        namespace=collection_namespace,
    )

    for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
        hub_id = result.get("hub_id")
        if not hub_id:
            print(f"Skipping search result without a 'hub_id': {result}")
            continue
        try:
            add_collection_item(
                collection.slug, hub_id, item_type="dataset", exists_ok=True
            )
        except Exception as e:
            # Best effort: one failing dataset must not abort the whole run.
            print(
                f"Error adding dataset {hub_id} to collection {collection_name}: {str(e)}"
            )
    return f"https://huggingface.co/collections/{collection.slug}"
# Collection specs: each dict holds the keyword arguments for one
# update_collection_for_dataset() call (the namespace is supplied separately).
# NOTE(review): this name shadows the stdlib ``collections`` module in this
# file; kept as-is because it is module-level API.
collections = [
    dict(
        dataset_columns=["chosen", "rejected", "prompt"],
        collection_description="Datasets suitable for DPO based on having 'chosen', 'rejected', and 'prompt' columns. Created using librarian-bots/dataset-column-search-api",
        collection_name="Direct Preference Optimization Datasets",
    ),
    dict(
        dataset_columns=["image", "chosen", "rejected"],
        collection_description="Datasets suitable for Image Preference Optimization based on having 'image','chosen', and 'rejected' columns",
        collection_name="Image Preference Optimization Datasets",
    ),
    dict(
        collection_name="Alpaca Style Datasets",
        dataset_columns=["instruction", "input", "output"],
        collection_description="Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
    ),
]

# Driver, disabled because running it creates/modifies collections on the Hub.
# Uncomment to (re)build every collection above under "librarian-bots":
# results = [
#     update_collection_for_dataset(**collection, collection_namespace="librarian-bots")
#     for collection in collections
# ]
# print(results)