# Author: derek-thomas (HF staff)
# Commit: 6621d73 — "Adding filter for ids"
# File size: 3.78 kB (exported from the Hugging Face file viewer: raw / history / blame)
import os
import numpy as np
import pandas as pd
import requests
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client
from src.my_logger import setup_logger
# --- Configuration, read from the environment at import time ---
# SUBREDDIT, USERNAME and PROCESSED_DATASET are required: a missing one raises KeyError on import.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Hub repo of the original (source) dataset built by the dataset-creator space for SUBREDDIT.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Hub repo of the processed dataset that carries the 'embedding' column.
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
# Gradio space queried (api_name="/embed") to compute embeddings.
embeddings_space = f"{USERNAME}/nomic-embeddings"
# JSON list of post ids to drop from the original dataset before merging.
# The default points at one specific subreddit's space; set FILTER_IDS_URL to use
# another list (an unset or empty variable falls back to the default).
FILTER_IDS_URL = os.environ.get("FILTER_IDS_URL") or (
    "https://huggingface.co/spaces/reddit-tools-HF/dataset-creator-reddit-bestofredditorupdates/raw/main/filter_ids.json"
)
logger = setup_logger(__name__)
def load_datasets():
    """Fetch fresh copies of the processed and the original datasets from the Hub.

    Both repos are loaded with ``DownloadMode.FORCE_REDOWNLOAD`` so a stale
    local cache is never reused.

    :return: Tuple ``(dataset, original_dataset)``: the PROCESSED_DATASET and
        OG_DATASET objects, in that order (callers index them by split,
        e.g. ``['train']``).
    """
    def _fetch(repo_id):
        # Log before and after so a slow or failing download is attributable to one repo.
        logger.info(f"Trying to download {repo_id}")
        fetched = load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        logger.info(f"Loaded {repo_id}")
        return fetched

    processed = _fetch(PROCESSED_DATASET)
    original = _fetch(OG_DATASET)
    return processed, original
def merge_and_update_datasets(dataset, original_dataset):
    """Bring the processed dataset in line with the original and re-embed stale rows.

    The original 'train' split (minus any ids listed at FILTER_IDS_URL) is
    left-joined with the processed 'train' split on 'id'. A row's embedding is
    discarded and recomputed via the embeddings space when its 'content' differs
    from the processed copy. This includes rows absent from the processed data,
    whose joined 'content' is NaN. Rows whose embedding was already null are
    embedded as well but are not included in the returned count.

    :param dataset: Processed dataset; its 'train' split is replaced with the result.
    :param original_dataset: Original dataset; only its 'train' split is read.
    :return: Tuple ``(dataset, updated_row_count)``.
    """
    client = Client(embeddings_space)

    source_df = original_dataset['train'].to_pandas()
    processed_df = dataset['train'].to_pandas()
    # Drop filtered ids here too, in case earlier runs let some through.
    source_df = remove_filtered_rows(source_df, FILTER_IDS_URL)

    # Overlapping column names get a suffix only on the original side, so the
    # original text becomes 'content_odf' while the processed copy keeps 'content'.
    merged = source_df.merge(
        processed_df[['id', 'content', 'embedding']],
        how='left',
        on='id',
        suffixes=('_odf', ''),
    )

    stale = merged['content_odf'] != merged['content']
    updated_row_count = int(stale.sum())
    merged['embedding'] = np.where(stale, None, merged['embedding'])

    # Shape the frame like the processed dataset: discard the processed text and the
    # original's 'new'/'updated' columns (drop raises KeyError if they are missing),
    # then expose the original text under the name 'content'.
    merged = merged.drop(columns=['content', 'new', 'updated']).rename(
        columns={'content_odf': 'content'}
    )

    logger.info(f"Updating {updated_row_count} rows...")
    pending = merged.index[merged['embedding'].isnull()]
    for row_label in pending:
        # .at stores the whole array in a single object-dtype cell.
        merged.at[row_label, 'embedding'] = update_embeddings(
            content=merged.at[row_label, 'content'],
            client=client,
        )

    dataset['train'] = Dataset.from_pandas(merged)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
def remove_filtered_rows(df: pd.DataFrame, url: str, timeout: float = 30.0) -> pd.DataFrame:
    """
    Removes rows from the DataFrame where the 'id' is present in the JSON file at the given URL.

    :param df: Input DataFrame to be filtered; must have an 'id' column.
    :param url: URL to a JSON array of IDs to remove.
    :param timeout: Seconds to wait for the HTTP response. Without a timeout,
        ``requests`` can block indefinitely on an unresponsive server.
    :return: DataFrame with rows containing IDs present in the JSON file removed.
    :raises requests.HTTPError: If the server answers with a 4xx/5xx status.
    :raises requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an HTTP error instead of trying to parse an error page as JSON.
    response.raise_for_status()
    raw_ids = response.json()
    logger.info(f"Loaded {len(raw_ids)} IDs from {url}")
    # The df ids are compared as strings (astype(str) below), so normalize the
    # filter ids too; otherwise a JSON list of numeric ids would never match.
    filter_ids = {str(filter_id) for filter_id in raw_ids}
    filtered_df = df[~df['id'].astype(str).isin(filter_ids)]
    logger.info(f"Filtered {len(df) - len(filtered_df)} rows from the DataFrame")
    return filtered_df
def update_embeddings(content, client):
    """Compute the embedding of one document through a Gradio client.

    The text is prefixed with 'search_document: ' (presumably the document-side
    task prefix the nomic embedding model expects; confirm against the space)
    and sent to the client's "/embed" endpoint.

    :param content: Document text; it is concatenated with the prefix, so it must be a str.
    :param client: Object exposing ``predict(text, api_name=...)``, e.g. a gradio_client Client.
    :return: ``numpy.ndarray`` built from whatever the endpoint returns.
    """
    prompt = 'search_document: ' + content
    raw_vector = client.predict(prompt, api_name="/embed")
    return np.array(raw_vector)