|
import os |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from datasets import Dataset, DownloadMode, load_dataset |
|
from gradio_client import Client |
|
|
|
from src.my_logger import setup_logger |
|
|
|
# Required configuration, read from the environment at import time.
# A missing variable raises KeyError as soon as the module is imported.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
PROCESSED_DATASET = os.environ["PROCESSED_DATASET"]

# Derived Hub identifiers: the original (unprocessed) dataset loaded by
# load_datasets(), and the Space queried for embeddings via gradio_client.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
embeddings_space = f"{USERNAME}/nomic-embeddings"

logger = setup_logger(__name__)
|
|
|
|
|
def load_datasets():
    """Download fresh copies of the processed and the original datasets.

    Both repositories are loaded with ``DownloadMode.FORCE_REDOWNLOAD``, so
    locally cached copies are not reused.

    Returns:
        Tuple ``(dataset, original_dataset)``: the dataset loaded from
        ``PROCESSED_DATASET`` followed by the one loaded from ``OG_DATASET``.
    """
    fetched = []
    # Order matters: the processed dataset is fetched (and logged) first.
    for repo_id in (PROCESSED_DATASET, OG_DATASET):
        logger.info(f"Trying to download {repo_id}")
        fetched.append(
            load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        )
        logger.info(f"Loaded {repo_id}")

    processed, original = fetched
    return processed, original
|
|
|
|
|
def merge_and_update_datasets(dataset, original_dataset):
    """Merge the original data into the processed dataset and refresh stale embeddings.

    Rows are matched on ``id`` with a left join, so every row of
    ``original_dataset['train']`` is kept. Rows whose ``content`` differs from
    the processed copy have their stored embedding discarded. This includes
    rows that do not exist in the processed dataset yet. Every row left
    without an embedding is then re-embedded through the embeddings Space.

    Args:
        dataset: Processed dataset (presumably a ``DatasetDict``) whose
            ``'train'`` split has at least ``id``, ``content`` and
            ``embedding`` columns. Its ``'train'`` split is replaced in place.
        original_dataset: Original dataset whose ``'train'`` split has at
            least ``id`` and ``content`` columns. Its ``new`` and ``updated``
            columns, if present, are not carried over.

    Returns:
        Tuple ``(dataset, updated_row_count)``. ``updated_row_count`` is the
        number of rows whose content changed or that are new.
    """
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Left join keeps every original row. Clashing columns from the original
    # side get the '_odf' suffix; the processed side keeps the plain names.
    merged_df = pd.merge(
        odf,
        df[['id', 'content', 'embedding']],
        on='id',
        how='left',
        suffixes=('_odf', ''),
    )

    # A plain `!=` reports missing-vs-missing as "changed", because NaN/None
    # never compare equal. Treat two missing values as unchanged instead.
    both_missing = merged_df['content_odf'].isna() & merged_df['content'].isna()
    changed = (merged_df['content_odf'] != merged_df['content']) & ~both_missing
    updated_row_count = int(changed.sum())

    # Invalidate embeddings of changed or new rows so they are recomputed below.
    merged_df['embedding'] = np.where(changed, None, merged_df['embedding'])

    # Keep the original content. Drop the processed copy and the bookkeeping
    # columns; tolerate their absence instead of raising KeyError.
    merged_df = merged_df.drop(columns=['content'])
    merged_df = merged_df.drop(columns=['new', 'updated'], errors='ignore')
    merged_df = merged_df.rename(columns={'content_odf': 'content'})

    logger.info(f"Updating {updated_row_count} rows...")

    # Connect to the Space lazily, so runs with nothing to embed make no call.
    client = None
    for index in merged_df.index[merged_df['embedding'].isnull()]:
        content = merged_df.at[index, 'content']
        if pd.isna(content):
            # Nothing to embed; concatenating a missing value would raise TypeError.
            logger.warning(f"Skipping embedding for row {index}: content is missing")
            continue
        if client is None:
            client = Client(embeddings_space)
        merged_df.at[index, 'embedding'] = update_embeddings(content=content, client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
|
|
|
|
|
def update_embeddings(content, client):
    """Return the embedding of ``content`` as a NumPy array.

    The text is sent to the Space's ``/embed`` endpoint with a
    ``search_document: `` prefix. This is presumably the task prefix expected
    by the nomic embedding model behind the Space; confirm against the Space.

    Args:
        content: Text to embed. It must be a ``str``; other types raise
            ``TypeError`` during concatenation.
        client: Connected ``gradio_client.Client`` (anything with a
            compatible ``predict`` method).

    Returns:
        ``np.ndarray`` built from the value returned by ``client.predict``.
    """
    document = 'search_document: ' + content
    raw_vector = client.predict(document, api_name="/embed")
    return np.array(raw_vector)
|
|