Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import pandas as pd | |
import requests | |
from datasets import Dataset, DownloadMode, load_dataset | |
from gradio_client import Client | |
from src.my_logger import setup_logger | |
# --- Required environment -----------------------------------------------------
# Looked up with os.environ[...] so that a missing variable raises KeyError as
# soon as the module is imported.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
PROCESSED_DATASET = os.environ["PROCESSED_DATASET"]

# --- Derived identifiers ------------------------------------------------------
# Dataset name handed to load_dataset() for the original (unprocessed) data.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Space name handed to gradio_client.Client() to compute embeddings.
embeddings_space = f"{USERNAME}/nomic-embeddings"

# JSON document listing the ids that remove_filtered_rows() strips out.
FILTER_IDS_URL = "https://huggingface.co/spaces/reddit-tools-HF/dataset-creator-reddit-bestofredditorupdates/raw/main/filter_ids.json"

logger = setup_logger(__name__)
def load_datasets():
    """Download fresh copies of the processed and the original dataset.

    Both downloads use ``DownloadMode.FORCE_REDOWNLOAD``, so any locally
    cached copy is bypassed.

    :return: Tuple ``(dataset, original_dataset)`` where the first item was
        loaded from ``PROCESSED_DATASET`` and the second from ``OG_DATASET``.
    """
    fetched = []
    # Processed dataset first, then the original one.
    for repo_name in (PROCESSED_DATASET, OG_DATASET):
        logger.info(f"Trying to download {repo_name}")
        fetched.append(
            load_dataset(repo_name, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        )
        logger.info(f"Loaded {repo_name}")

    dataset, original_dataset = fetched
    return dataset, original_dataset
def merge_and_update_datasets(dataset, original_dataset):
    """Rebuild the processed 'train' split from the original data, embedding new or changed rows.

    The original rows (minus the ids listed at ``FILTER_IDS_URL``) are
    left-merged with the processed rows on ``id``. A row keeps its existing
    embedding only when its content is identical on both sides; every row that
    ends up without an embedding gets one from ``update_embeddings``.

    :param dataset: Processed dataset; its ``'train'`` split is read and then
        replaced in place with the merged result.
    :param original_dataset: Original dataset; only its ``'train'`` split is read.
    :return: Tuple ``(dataset, updated_row_count)``. The count is the number of
        merged rows whose content differs between the two datasets, which
        includes rows that are absent from the processed dataset. Rows whose
        content is unchanged but whose embedding was already missing are
        re-embedded as well, yet are not part of this count.
    """
    # Merge and figure out which rows need to be updated with embeddings
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Filter ODF in-case we missed any
    odf = remove_filtered_rows(odf, FILTER_IDS_URL)

    # Step 1: Merge df onto odf.
    # The original content arrives as 'content_odf'; 'content' and 'embedding'
    # come from the processed side and are missing for ids it does not contain.
    merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))

    # Step 2: Rows whose content differs (or is missing on the processed side)
    # lose their embedding so that it is recomputed below.
    content_changed = merged_df['content'] != merged_df['content_odf']
    updated_row_count = int(content_changed.sum())
    merged_df['embedding'] = np.where(content_changed, None, merged_df['embedding'])

    # Step 3: Cleanup - keep the original content and the (possibly reset) embedding.
    # errors='ignore' keeps the job running when the original dataset does not
    # carry the 'new' / 'updated' bookkeeping columns.
    merged_df = merged_df.drop(columns=['content', 'new', 'updated'], errors='ignore')
    merged_df = merged_df.rename(columns={'content_odf': 'content'})

    logger.info(f"Updating {updated_row_count} rows...")

    rows_to_embed = merged_df[merged_df['embedding'].isnull()]
    if not rows_to_embed.empty:
        # Connect to the embeddings Space only when there is work to do, so an
        # unavailable Space cannot fail a run that needs no new embeddings.
        client = Client(embeddings_space)
        for index, row in rows_to_embed.iterrows():
            merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
def remove_filtered_rows(df: pd.DataFrame, url: str, timeout: float = 30.0) -> pd.DataFrame:
    """
    Removes rows from the DataFrame where the 'id' is present in the JSON file at the given URL.

    :param df: Input DataFrame to be filtered. Its 'id' column is compared as strings.
    :param url: URL to the JSON file containing the filter IDs.
    :param timeout: Seconds to wait for the server before giving up, passed to ``requests.get``.
    :return: DataFrame with rows containing IDs present in the JSON file removed.
    :raises requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status.
    :raises requests.exceptions.Timeout: If the server does not answer within ``timeout``.
    """
    # Load filter IDs from JSON file at the URL.
    # Without a timeout requests may block forever on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of trying to parse an error page as JSON.
    response.raise_for_status()
    filter_ids = response.json()
    logger.info(f"Loaded {len(filter_ids)} IDs from {url}")

    # Remove the rows with IDs present in filter_ids
    filtered_df = df[~df['id'].astype(str).isin(filter_ids)]
    logger.info(f"Filtered {len(df) - len(filtered_df)} rows from the DataFrame")
    return filtered_df
def update_embeddings(content, client):
    """Request an embedding for ``content`` and return it as a NumPy array.

    :param content: Text to embed; it is sent prefixed with ``'search_document: '``.
    :param client: Object exposing ``predict(text, api_name=...)``; the
        ``"/embed"`` endpoint is called.
    :return: ``np.ndarray`` built from whatever the endpoint returned.
    """
    prefixed_text = 'search_document: ' + content
    raw_vector = client.predict(prefixed_text, api_name="/embed")
    return np.array(raw_vector)