"""Incrementally embed Hugging Face dataset cards into a Chroma collection.

Running this module:

1. opens (or creates) a persistent Chroma collection named ``dataset_cards``,
2. finds the newest ``last_modified`` value already stored there,
3. loads dataset cards from a parquet file on the Hugging Face Hub and keeps
   only those modified after that value,
4. embeds each card through a remote inference endpoint, and
5. upserts the embeddings, keyed by dataset ID, back into the collection.

All of this happens at import time; there is no ``__main__`` guard.
"""

import os
import platform
from datetime import datetime, timedelta
from typing import List, Literal, Optional, Tuple

import chromadb
import polars as pl
import requests
import stamina
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from huggingface_hub import InferenceClient, login
from tqdm.contrib.concurrent import thread_map

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")

# Matches a leading ``--- ... ---`` block (dot matches newlines) and captures
# everything after it.
FRONT_MATTER_PATTERN = r"(?s)^---.*?---\s*(.*)"

# Cards are cut to this many characters before being sent for embedding.
MAX_CARD_CHARS = 8192


def get_save_path() -> Literal["chroma/", "/data/chroma/"]:
    """Return the directory used for the persistent Chroma store.

    ``"chroma/"`` when ``platform.system()`` reports ``"Darwin"``, otherwise
    ``"/data/chroma/"``.
    """
    return "chroma/" if platform.system() == "Darwin" else "/data/chroma/"


save_path = get_save_path()

chroma_client = chromadb.PersistentClient(
    path=save_path,
)

sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="Snowflake/snowflake-arctic-embed-m-long", trust_remote_code=True
)

# The upsert at the bottom of this module passes embeddings explicitly, so the
# embedding function is not what produces the stored vectors here.
collection = chroma_client.create_collection(
    name="dataset_cards", get_or_create=True, embedding_function=sentence_transformer_ef
)


def get_last_modified_in_collection() -> Optional[datetime]:
    """Return the newest ``last_modified`` timestamp stored in the collection.

    Reads the metadata of every item in ``collection`` and parses each
    ``"last_modified"`` value with ``datetime.fromisoformat``. Items without
    metadata, or whose metadata lacks that key, are skipped instead of raising.

    Returns:
        The maximum parsed timestamp, or ``None`` when no item carries one.
    """
    all_items = collection.get(include=["metadatas"])
    timestamps = [
        datetime.fromisoformat(metadata["last_modified"])
        for metadata in all_items["metadatas"] or []
        if metadata and "last_modified" in metadata
    ]
    return max(timestamps, default=None)


def parse_markdown_column(
    df: pl.DataFrame, markdown_column: str, dataset_id_column: str
) -> pl.DataFrame:
    """Add cleaned and ID-prefixed versions of a markdown column.

    Args:
        df: Frame containing ``markdown_column`` and ``dataset_id_column``.
        markdown_column: Name of the column holding the raw card text.
        dataset_id_column: Name of the column holding the dataset identifier.

    Returns:
        ``df`` with two extra columns:

        * ``parsed_markdown`` - the card text after a leading ``--- ... ---``
          block, or the original text when the pattern does not match,
          stripped of surrounding whitespace.
        * ``prepended_markdown`` - ``"Dataset ID "``, the dataset ID, a blank
          line, then the same cleaned text.
    """
    # Build the cleaning expression once; both output columns share it.
    card_body = (
        pl.col(markdown_column)
        .str.extract(FRONT_MATTER_PATTERN, group_index=1)
        # extract() yields null when the pattern does not match; fall back to
        # the untouched card text in that case.
        .fill_null(pl.col(markdown_column))
        .str.strip_chars()
    )
    return df.with_columns(
        parsed_markdown=card_body,
        prepended_markdown=pl.concat_str(
            [
                pl.lit("Dataset ID "),
                pl.col(dataset_id_column).cast(pl.Utf8),
                pl.lit("\n\n"),
                card_body,
            ]
        ),
    )


def load_cards(
    min_len: int = 50,
    min_likes: Optional[int] = None,
    last_modified: Optional[datetime] = None,
) -> Optional[Tuple[List[str], List[str], List[datetime]]]:
    """Load dataset cards from the Hub parquet file and filter them.

    Args:
        min_len: Keep only cards whose cleaned text is strictly longer than
            this many characters.
        min_likes: When not ``None``, keep only rows with strictly more likes
            than this value. ``0`` is a real threshold, not "no filter".
        last_modified: When not ``None``, keep only rows whose
            ``last_modified`` is strictly later than this value.

    Returns:
        ``None`` when no row survives the filters; otherwise a tuple of
        ``(cards, dataset_ids, last_modifieds)`` as parallel lists, where
        ``cards`` holds the ID-prefixed card text.
    """
    df = pl.read_parquet(
        "hf://datasets/librarian-bots/dataset_cards_with_metadata_with_embeddings/data/train-00000-of-00001.parquet"
    )
    df = parse_markdown_column(df, "card", "datasetId")
    df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
    print(df)
    df = df.filter(pl.col("card_len") > min_len)
    print(df)

    # Compare against None explicitly so that min_likes=0 still filters.
    if min_likes is not None:
        df = df.filter(pl.col("likes") > min_likes)
    if last_modified is not None:
        df = df.filter(pl.col("last_modified") > last_modified)

    if df.is_empty():
        return None

    cards = df.get_column("prepended_markdown").to_list()
    dataset_ids = df.get_column("datasetId").to_list()
    last_modifieds = df.get_column("last_modified").to_list()
    return cards, dataset_ids, last_modifieds


client = InferenceClient(
    model="https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud",
    token=HF_TOKEN,
)


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text: str):
    """Embed one card through the remote inference endpoint.

    The text is cut to its first ``MAX_CARD_CHARS`` characters before being
    sent. The call is retried by ``stamina`` on ``requests.HTTPError``, up to
    three attempts.

    Args:
        text: Card text to embed.

    Returns:
        Whatever ``client.feature_extraction`` returns for the truncated text.
    """
    return client.feature_extraction(text[:MAX_CARD_CHARS])


most_recent = get_last_modified_in_collection()

if data := load_cards(min_len=200, min_likes=None, last_modified=most_recent):
    cards, model_ids, last_modifieds = data
    print("mapping...")
    results = thread_map(embed_card, cards)
    # NOTE(review): everything is upserted in a single call; confirm the
    # collection accepts a batch of this size.
    collection.upsert(
        ids=model_ids,
        # assumes each result has a leading axis of size 1, so [0] selects the
        # embedding vector itself - TODO confirm against the endpoint output
        embeddings=[embedding.tolist()[0] for embedding in results],
        metadatas=[{"last_modified": str(lm)} for lm in last_modifieds],
    )
    print("done")
else:
    print("no new data")