import logging
import os
import platform
from datetime import datetime
from typing import List, Literal, Optional, Tuple

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER once, when it is first
# imported, so the flag has to be set before the third-party imports below.
# setdefault() lets an explicit HF_HUB_ENABLE_HF_TRANSFER=0 in the environment
# opt out (the `hf_transfer` package must be installed when this is enabled).
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

import chromadb  # noqa: E402
import polars as pl  # noqa: E402
import requests  # noqa: E402
import stamina  # noqa: E402
from chromadb.utils import embedding_functions  # noqa: E402
from dotenv import load_dotenv  # noqa: E402
from huggingface_hub import InferenceClient  # noqa: E402
from tqdm.contrib.concurrent import thread_map  # noqa: E402

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
EMBEDDING_MODEL_NAME = "Snowflake/snowflake-arctic-embed-m-long"
EMBEDDING_MODEL_REVISION = "ac9d0cb43661ee1f7d67b3aa63614d65a6c86463"
INFERENCE_MODEL_URL = (
    "https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud"
)
DATASET_PARQUET_URL = (
    "hf://datasets/librarian-bots/dataset_cards_with_metadata/data/train-*.parquet"
)
COLLECTION_NAME = "dataset_cards"
# Cards are truncated to this many *characters* (not tokens) before embedding.
MAX_EMBEDDING_LENGTH = 8192
# Rows per collection.upsert() call. Chroma rejects a single write larger than
# its maximum batch size, so large refreshes are written in chunks.
UPSERT_BATCH_SIZE = 1000

# Matches a leading "---" ... "---" front-matter block and captures the
# remainder of the card in group 1.
FRONT_MATTER_PATTERN = r"(?s)^---.*?---\s*(.*)"

# Phrases found in the stock (unedited) dataset card template.
TEMPLATE_INDICATORS = (
    "# Dataset Card for Dataset Name",
    "This dataset card aims to be a base template for new datasets",
    "[More Information Needed]",
)


def get_save_path() -> Literal["chroma/", "/data/chroma/"]:
    """Return the Chroma persistence directory for the current platform.

    macOS ("Darwin") uses the relative ``chroma/`` directory; every other
    platform uses ``/data/chroma/``.
    """
    path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
    logger.info("Using save path: %s", path)
    return path


SAVE_PATH = get_save_path()


def get_chroma_client():
    """Create a persistent Chroma client rooted at ``SAVE_PATH``."""
    logger.info("Initializing Chroma client")
    return chromadb.PersistentClient(path=SAVE_PATH)


def get_embedding_function():
    """Build the local SentenceTransformer embedding function.

    The model is pinned to ``EMBEDDING_MODEL_REVISION``; ``trust_remote_code``
    is required because the model ships custom modelling code.
    """
    logger.info("Initializing embedding function with model: %s", EMBEDDING_MODEL_NAME)
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )


def get_collection(chroma_client, embedding_function):
    """Return the ``COLLECTION_NAME`` collection, creating it if absent."""
    logger.info("Getting or creating collection: %s", COLLECTION_NAME)
    return chroma_client.get_or_create_collection(
        name=COLLECTION_NAME, embedding_function=embedding_function
    )


def get_last_modified_in_collection(collection) -> datetime | None:
    """Return the newest ``last_modified`` timestamp stored in *collection*.

    Entries whose metadata is missing or lacks ``last_modified`` are skipped.
    Returns ``None`` when no timestamp is found or when reading/parsing fails;
    callers treat ``None`` as "refresh everything".
    """
    logger.info("Fetching last modified date from collection")
    try:
        all_items = collection.get(include=["metadatas"])
        timestamps = [
            datetime.fromisoformat(metadata["last_modified"])
            for metadata in all_items["metadatas"] or []
            if metadata and metadata.get("last_modified")
        ]
        if not timestamps:
            logger.info("No last modified date found")
            return None
        # max() stays inside the try: mixing naive and aware datetimes raises.
        last_mod = max(timestamps)
    except Exception:
        # Best-effort: a failure here degrades to a full refresh, not a crash.
        logger.exception("Error fetching last modified date")
        return None
    logger.info("Last modified date: %s", last_mod)
    return last_mod


def parse_markdown_column(
    df: pl.DataFrame, markdown_column: str, dataset_id_column: str
) -> pl.DataFrame:
    """Add ``parsed_markdown`` and ``prepended_markdown`` columns to *df*.

    ``parsed_markdown`` is the card with any leading front-matter block
    removed (or the whole card when there is none), whitespace-stripped.
    ``prepended_markdown`` is ``"Dataset ID <id>\\n\\n"`` followed by that body.
    """
    logger.info("Parsing markdown column")
    body = (
        pl.col(markdown_column)
        .str.extract(FRONT_MATTER_PATTERN, group_index=1)
        # No front matter -> extract yields null -> fall back to the raw card.
        .fill_null(pl.col(markdown_column))
        .str.strip_chars()
    )
    return df.with_columns(
        parsed_markdown=body,
        prepended_markdown=pl.concat_str(
            [
                pl.lit("Dataset ID "),
                pl.col(dataset_id_column).cast(pl.Utf8),
                pl.lit("\n\n"),
                body,
            ]
        ),
    )


def is_unmodified_template(card: str) -> bool:
    """Return True if *card* looks like an unedited template card.

    A card is flagged when it contains at least two of
    ``TEMPLATE_INDICATORS``, or seven or more "[More Information Needed]"
    placeholders.
    """
    indicator_count = sum(indicator in card for indicator in TEMPLATE_INDICATORS)
    more_info_needed_count = card.count("[More Information Needed]")
    return indicator_count >= 2 or more_info_needed_count >= 7


def load_cards(
    min_len: int = 50,
    min_likes: int | None = None,
    last_modified: Optional[datetime] = None,
) -> Optional[Tuple[List[str], List[str], List[datetime]]]:
    """Load and filter dataset cards from ``DATASET_PARQUET_URL``.

    Args:
        min_len: Keep cards whose parsed body is longer than this (chars).
        min_likes: If not None, keep datasets with strictly more likes.
        last_modified: If not None, keep datasets modified strictly after it.

    Returns:
        ``(cards, dataset_ids, last_modifieds)`` as parallel lists, where each
        card is the ``prepended_markdown`` text, or ``None`` if no rows match.
    """
    logger.info(
        "Loading cards with min_len=%s, min_likes=%s, last_modified=%s",
        min_len,
        min_likes,
        last_modified,
    )
    df = pl.read_parquet(DATASET_PARQUET_URL)
    # list.contains() is null for a null tag list, and filter() drops nulls;
    # treat "no tags" as "not flagged" rather than silently discarding the row.
    df = df.filter(
        ~pl.col("tags").list.contains("not-for-all-audiences").fill_null(False)
    )
    df = parse_markdown_column(df, "card", "datasetId")
    df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
    df = df.filter(pl.col("card_len") > min_len)
    # `is not None` so that min_likes=0 still applies the filter.
    if min_likes is not None:
        df = df.filter(pl.col("likes") > min_likes)
    if last_modified is not None:
        df = df.filter(pl.col("last_modified") > last_modified)
    # Filter out unmodified template cards
    df = df.filter(
        ~pl.col("prepended_markdown").map_elements(
            is_unmodified_template, return_dtype=pl.Boolean
        )
    )
    if df.is_empty():
        logger.info("No cards found matching criteria")
        return None
    cards = df.get_column("prepended_markdown").to_list()
    dataset_ids = df.get_column("datasetId").to_list()
    last_modifieds = df.get_column("last_modified").to_list()
    logger.info("Loaded %d cards", len(cards))
    return cards, dataset_ids, last_modifieds


# NOTE(review): retries only fire if the client's HTTP errors subclass
# requests.HTTPError; confirm for the installed huggingface_hub version.
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Embed *text* (truncated to ``MAX_EMBEDDING_LENGTH`` chars) remotely."""
    text = text[:MAX_EMBEDDING_LENGTH]
    return client.feature_extraction(text)


def get_inference_client():
    """Create an ``InferenceClient`` for the dedicated embedding endpoint."""
    logger.info("Initializing inference client with model: %s", INFERENCE_MODEL_URL)
    return InferenceClient(
        model=INFERENCE_MODEL_URL,
        token=HF_TOKEN,
    )


def refresh_data(min_len: int = 200, min_likes: Optional[int] = None):
    """Embed cards newer than the collection's latest entry and upsert them.

    Args:
        min_len: Forwarded to :func:`load_cards`.
        min_likes: Forwarded to :func:`load_cards`.
    """
    logger.info("Starting data refresh with min_len=%s, min_likes=%s", min_len, min_likes)
    chroma_client = get_chroma_client()
    embedding_function = get_embedding_function()
    collection = get_collection(chroma_client, embedding_function)
    most_recent = get_last_modified_in_collection(collection)
    if data := load_cards(
        min_len=min_len, min_likes=min_likes, last_modified=most_recent
    ):
        _create_and_upsert_embeddings(data, collection)
    else:
        logger.info("No new data to refresh")


def _create_and_upsert_embeddings(
    data: Tuple[List[str], List[str], List[datetime]],
    collection,
    batch_size: int = UPSERT_BATCH_SIZE,
):
    """Embed every card concurrently, then upsert into *collection* in chunks.

    Args:
        data: ``(cards, dataset_ids, last_modifieds)`` from :func:`load_cards`.
        collection: Target Chroma collection.
        batch_size: Maximum rows per ``upsert`` call (must be >= 1).
    """
    cards, dataset_ids, last_modifieds = data
    logger.info("Embedding cards...")
    inference_client = get_inference_client()
    results = thread_map(lambda card: embed_card(card, inference_client), cards)
    # Each result's first row is kept; presumably the endpoint returns a
    # (1, dim) array per input -- TODO confirm against the endpoint config.
    embeddings = [embedding.tolist()[0] for embedding in results]
    metadatas = [{"last_modified": str(lm)} for lm in last_modifieds]
    logger.info("Upserting %d items to collection", len(dataset_ids))
    for start in range(0, len(dataset_ids), batch_size):
        stop = start + batch_size
        collection.upsert(
            ids=dataset_ids[start:stop],
            embeddings=embeddings[start:stop],
            metadatas=metadatas[start:stop],
        )
    logger.info("Data refresh completed successfully")


if __name__ == "__main__":
    refresh_data()