Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import json | |
import random | |
import httpx | |
import polars as pl | |
from huggingface_hub import list_datasets | |
from tqdm import tqdm | |
from tqdm.asyncio import tqdm_asyncio | |
# Module-wide asynchronous HTTP client shared by all requests below
# (60-second timeout, HTTP/2 enabled).
client = httpx.AsyncClient(http2=True, timeout=60)
async def generate_dataset_prompt(dataset_name, num_rows=2):
    """Build a plain-text summary prompt for a single dataset.

    Queries the ``/splits``, ``/info`` and ``/first-rows`` endpoints of
    ``datasets-server.huggingface.co`` through the module-level ``client``
    and formats the JSON responses into a prompt that lists the features,
    the config/split pairs, the example counts and a JSON dump of the first
    rows. Info and sample rows are taken from the first config/split that
    ``/splits`` reports.

    Args:
        dataset_name: Dataset ID passed as the ``dataset`` query parameter.
        num_rows: Maximum number of sample rows to include in the prompt.

    Returns:
        The formatted prompt string, or ``None`` when the dataset reports no
        splits or when any step raises (the error is printed, not re-raised,
        so one bad dataset does not abort a whole batch).
    """
    base_url = "https://datasets-server.huggingface.co"

    def format_feature(name, details):
        """Render one feature as ``- name (type)``."""
        if isinstance(details, dict):
            # Prefer the primitive dtype; fall back to the feature class name.
            feature_type = details.get(
                "dtype", details.get("_type", "unknown type")
            )
        elif isinstance(details, list):
            feature_type = "list"
        else:
            feature_type = type(details).__name__
        return f"- {name} ({feature_type})"

    try:
        # Query values go through ``params`` so httpx percent-encodes them;
        # config/split names may contain characters such as "&", "#" or
        # spaces that would corrupt a hand-built query string.
        splits_response = await client.get(
            f"{base_url}/splits", params={"dataset": dataset_name}
        )
        splits_data = splits_response.json()
        if not splits_data.get("splits"):
            return None

        # Use the first config and split for the info and sample requests.
        first_split = splits_data["splits"][0]
        config_name = first_split["config"]
        split_name = first_split["split"]

        # Dataset info for the selected config.
        info_response = await client.get(
            f"{base_url}/info",
            params={"dataset": dataset_name, "config": config_name},
        )
        info_data = info_response.json()

        # First rows for the selected config and split.
        first_rows_response = await client.get(
            f"{base_url}/first-rows",
            params={
                "dataset": dataset_name,
                "config": config_name,
                "split": split_name,
            },
        )
        first_rows_data = first_rows_response.json()

        # Extract the pieces that end up in the prompt.
        dataset_info = info_data.get("dataset_info", {})
        features = dataset_info.get("features", {})
        splits = dataset_info.get("splits", {})
        total_examples = sum(
            split.get("num_examples", 0) for split in splits.values()
        )

        formatted_features = "\n".join(
            format_feature(name, details) for name, details in features.items()
        )
        split_listing = ", ".join(
            f"{split['config']}/{split['split']}"
            for split in splits_data["splits"]
        )
        split_sizes = ", ".join(
            f"{split}: {info['num_examples']}" for split, info in splits.items()
        )
        # Sample data, truncated to the requested number of rows.
        sample_data = json.dumps(
            first_rows_data.get("rows", [])[:num_rows], indent=2
        )

        prompt_lines = [
            f'Dataset: "{dataset_name}"',
            "Features:",
            formatted_features,
            "Splits and Configs:",
            split_listing,
            "Size Statistics:",
            f"Total Examples: {total_examples}",
            f"Split Sizes: {split_sizes}",
            f"Data Sample ({num_rows} rows out of {total_examples} total):",
            sample_data,
        ]
        return "\n".join(prompt_lines).strip()
    except Exception as e:
        # Best-effort: report the failure and let the caller drop this dataset.
        print(f"Error for {dataset_name}: {e}")
        return None
async def process_batch(batch):
    """Generate prompts for every dataset ID in ``batch`` concurrently.

    Each ID is handed to ``generate_dataset_prompt`` and the coroutines are
    awaited together via ``tqdm_asyncio.gather`` (``leave=False``). Results
    are paired with ``batch`` by position.

    Returns:
        A list of ``(dataset_id, prompt)`` tuples, omitting every dataset
        whose prompt came back as ``None``.
    """
    pending = [generate_dataset_prompt(name) for name in batch]
    prompts = await tqdm_asyncio.gather(*pending, leave=False)

    successful = []
    for name, text in zip(batch, prompts):
        if text is None:
            # Prompt generation failed or the dataset had no splits.
            continue
        successful.append((name, text))
    return successful
def _results_to_frame(results):
    """Convert ``(dataset_id, prompt)`` pairs into a two-column DataFrame."""
    return pl.DataFrame(
        {
            "dataset_id": [dataset_id for dataset_id, _ in results],
            "formatted_prompt": [prompt for _, prompt in results],
        }
    )


async def prep_data(
    sample_size=200_000, min_likes=1, batch_size=500, checkpoint_every=1000
):
    """Sample datasets and build a prompt for each of them.

    Lists all datasets via ``list_datasets()``, drops the IDs already present
    in the processed descriptions parquet file, optionally filters by like
    count, draws a random sample and generates prompts batch by batch with
    ``process_batch``.

    Args:
        sample_size: Maximum number of dataset IDs to sample.
        min_likes: Minimum like count a dataset needs to be kept; a falsy
            value disables the filter.
        batch_size: Number of dataset IDs processed concurrently per batch.
            Must be positive (it is used as the ``range`` step).
        checkpoint_every: Write an intermediate parquet file each time at
            least this many new results have accumulated since the previous
            checkpoint; a falsy value disables checkpointing.

    Returns:
        A polars DataFrame with the columns ``dataset_id`` and
        ``formatted_prompt``, one row per successfully generated prompt.
    """
    # Load the dataset containing the dataset IDs that were already processed.
    df = pl.read_parquet(
        "hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
    )
    # Datasets already in the train or test set are excluded.
    # TODO: remove this exclusion once the model works well enough.
    in_train_or_test = set(df["dataset_id"].unique().to_list())

    # Keep unseen datasets that meet the like threshold (when one is set).
    # ``likes or 0`` guards against a missing like count -- presumably it can
    # be None for some entries; verify against huggingface_hub.
    dataset_ids = [
        dataset.id
        for dataset in list_datasets()
        if dataset.id not in in_train_or_test
        and (not min_likes or (dataset.likes or 0) >= min_likes)
    ]

    # Sample datasets, capped at the number actually available.
    dataset_ids = random.sample(dataset_ids, min(sample_size, len(dataset_ids)))

    all_results = []
    last_checkpoint_size = 0
    for start in tqdm(
        range(0, len(dataset_ids), batch_size), desc="Processing batches"
    ):
        batch = dataset_ids[start : start + batch_size]
        all_results.extend(await process_batch(batch))

        # Failed datasets are dropped, so the result count rarely lands on an
        # exact multiple of the interval; checkpoint on growth since the last
        # save instead, which also avoids writing an empty file at count 0.
        new_results = len(all_results) - last_checkpoint_size
        if checkpoint_every and new_results >= checkpoint_every:
            _results_to_frame(all_results).write_parquet(
                f"dataset_prompts_intermediate_{len(all_results)}.parquet"
            )
            last_checkpoint_size = len(all_results)

    return _results_to_frame(all_results)