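"""Refresh the dataset-viewer-descriptions ChromaDB collection.

Prepares dataset viewer rows, embeds each row's formatted prompt with the
query-to-dataset-viewer-descriptions model served from the Inference
Endpoint at INFERENCE_MODEL_URL, and upserts the vectors into ChromaDB.
"""
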
import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data
from utils import get_chroma_client

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)

def initialize_clients():
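    """Create the ChromaDB client and the InferenceClient for the embedding endpoint."""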
logger.info("Initializing clients")
chroma_client = get_chroma_client()
inference_client = InferenceClient(
INFERENCE_MODEL_URL,
)
return chroma_client, inference_client
def create_collection(chroma_client):
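    """Get or create the dataset-viewer-descriptions collection.

    Uses cosine distance and a SentenceTransformer embedding function pinned
    to EMBEDDING_MODEL_REVISION, so query-time embeddings come from the same
    model revision as the stored vectors.
    """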
logger.info("Creating or getting collection")
embedding_function = SentenceTransformerEmbeddingFunction(
model_name=EMBEDDING_MODEL_NAME,
trust_remote_code=True,
revision=EMBEDDING_MODEL_REVISION,
)
logger.info(f"Embedding function: {embedding_function}")
logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
return chroma_client.create_collection(
name="dataset-viewer-descriptions",
get_or_create=True,
embedding_function=embedding_function,
metadata={"hnsw:space": "cosine"},
)
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
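    """Embed a single document via the Inference Endpoint.

    Text is truncated to 8,192 characters before embedding; HTTP errors are
    retried up to three times with backoff via stamina.
    """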
    text = text[:8192]
    return client.feature_extraction(text)

def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
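    """Embed documents in batches and upsert them into the collection.

    Each document is sent with a "HUB_DATASET_PREVIEW: " prefix, and the
    embedding requests within a batch run concurrently via thread_map.
    """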
    logger.info(
        f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
    )
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = []
        documents = []
        for item in batch:
            ids.append(item["dataset_id"])
            documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
        logger.info(f"Results: {len(results)}")
        collection.upsert(
            ids=ids,
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug(f"Processed batch {i//batch_size + 1}")

async def refresh_viewer_data(sample_size=200_000, min_likes=2):
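    """Rebuild the viewer-data collection end to end.

    Prepares up to sample_size datasets with at least min_likes likes,
    writes the prepared rows to viewer_data.parquet, then embeds and
    upserts them into ChromaDB.
    """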
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    if df is not None:
        # Only write the parquet snapshot once we know prep_data succeeded;
        # calling write_parquet on None would raise before the check.
        df.write_parquet("viewer_data.parquet")
        logger.info("Data prepared successfully")
        logger.info(f"Data: {df}")
        dataset_rows_and_ids = df.to_dicts()
        logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
        embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())