derek-thomas's picture
derek-thomas HF staff
Upgrading to nomic 1.5
a6deb48
raw
history blame
No virus
2.76 kB
import os
import numpy as np
import pandas as pd
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client
from src.my_logger import setup_logger
# Runtime configuration. Each os.environ[...] lookup raises KeyError at import
# time if the corresponding environment variable is not set.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Repo id of the original (unprocessed) dataset, built from the user and subreddit names.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Repo id of the processed dataset (the one carrying the 'embedding' column used below).
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
# Identifier handed to gradio_client.Client to reach the embedding service.
embeddings_space = f"{USERNAME}/nomic-embeddings"
# Module-level logger shared by every function in this file.
logger = setup_logger(__name__)
def load_datasets():
    """Download fresh copies of the processed and original datasets.

    Both repos are fetched with ``DownloadMode.FORCE_REDOWNLOAD`` so a
    locally cached copy is never reused. The processed dataset is fetched
    first, then the original one, with a log line before and after each.

    Returns:
        A ``(dataset, original_dataset)`` tuple: the result of loading
        ``PROCESSED_DATASET`` followed by the result of loading ``OG_DATASET``.
    """
    fetched = []
    # Order matters: callers unpack the result as (processed, original).
    for repo_id in (PROCESSED_DATASET, OG_DATASET):
        logger.info(f"Trying to download {repo_id}")
        fetched.append(
            load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        )
        logger.info(f"Loaded {repo_id}")

    dataset, original_dataset = fetched
    return dataset, original_dataset
def merge_and_update_datasets(dataset, original_dataset):
    """Re-embed rows whose content is new or differs from the processed copy.

    The processed ``train`` split is left-joined onto the original ``train``
    split by ``id``. Wherever the two ``content`` values differ, the stored
    embedding is cleared; rows with no match in the processed split get NaN
    content from the join and therefore also compare as different. Every row
    whose embedding is null afterwards is sent to ``update_embeddings``.

    Args:
        dataset: Processed dataset; its ``train`` split must provide the
            ``id``, ``content`` and ``embedding`` columns. It is mutated in
            place: ``dataset['train']`` is replaced with the merged result.
        original_dataset: Original dataset; its ``train`` split must provide
            ``id``, ``content``, ``new`` and ``updated`` columns (the last
            two are dropped from the result).

    Returns:
        A ``(dataset, updated_row_count)`` tuple, where ``updated_row_count``
        is the number of merged rows whose content differed.
    """
    # The client is created up front, before any data is inspected.
    embed_client = Client(embeddings_space)

    source_frame = original_dataset['train'].to_pandas()
    processed_frame = dataset['train'].to_pandas()

    # Overlapping columns from the original side get the '_odf' suffix, so
    # 'content_odf' is the original text and 'content' the processed text.
    combined = source_frame.merge(
        processed_frame[['id', 'content', 'embedding']],
        on='id',
        how='left',
        suffixes=('_odf', ''),
    )

    # One mask drives both the reported count and the embedding reset.
    content_changed = combined['content_odf'] != combined['content']
    updated_row_count = int(content_changed.sum())
    combined['embedding'] = np.where(content_changed, None, combined['embedding'])

    # Keep the original text under the plain 'content' name and discard the
    # processed text plus the 'new' / 'updated' columns.
    combined = combined.drop(columns=['content', 'new', 'updated']).rename(
        columns={'content_odf': 'content'}
    )

    logger.info(f"Updating {updated_row_count} rows...")

    # Snapshot the labels that need an embedding before mutating the frame.
    pending_labels = combined.index[combined['embedding'].isnull()]
    for row_label in pending_labels:
        combined.at[row_label, 'embedding'] = update_embeddings(
            content=combined.at[row_label, 'content'],
            client=embed_client,
        )

    dataset['train'] = Dataset.from_pandas(combined)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
def update_embeddings(content, client):
    """Request an embedding for ``content`` and return it as a NumPy array.

    Args:
        content: Text to embed; it is prefixed with ``'search_document: '``
            before being sent.
        client: Object exposing ``predict(text, api_name=...)``; the
            ``/embed`` endpoint is called.

    Returns:
        ``numpy.ndarray`` built from whatever the endpoint returned.
    """
    prefixed_text = 'search_document: ' + content
    raw_vector = client.predict(prefixed_text, api_name="/embed")
    return np.array(raw_vector)