davanstrien HF staff committed on
Commit
4d0e134
1 Parent(s): 2794b15
Files changed (1) hide show
  1. load_viewer_data.py +30 -7
load_viewer_data.py CHANGED
@@ -9,18 +9,26 @@ from huggingface_hub import InferenceClient
9
  from tqdm.auto import tqdm
10
  from tqdm.contrib.concurrent import thread_map
11
 
 
12
  from prep_viewer_data import prep_data
 
13
 
14
  # Set up logging
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.INFO)
17
 
 
 
 
 
 
 
18
 
19
  def initialize_clients():
20
  logger.info("Initializing clients")
21
- chroma_client = chromadb.PersistentClient()
22
  inference_client = InferenceClient(
23
- "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
24
  )
25
  return chroma_client, inference_client
26
 
@@ -28,9 +36,13 @@ def initialize_clients():
28
  def create_collection(chroma_client):
29
  logger.info("Creating or getting collection")
30
  embedding_function = SentenceTransformerEmbeddingFunction(
31
- model_name="davanstrien/dataset-viewer-descriptions-processed-st",
32
  trust_remote_code=True,
 
33
  )
 
 
 
34
  return chroma_client.create_collection(
35
  name="dataset-viewer-descriptions",
36
  get_or_create=True,
@@ -46,9 +58,14 @@ def embed_card(text, client):
46
 
47
 
48
  def embed_and_upsert_datasets(
49
- dataset_rows_and_ids, collection, inference_client, batch_size=10
 
 
 
50
  ):
51
- logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
 
 
52
  for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
53
  batch = dataset_rows_and_ids[i : i + batch_size]
54
  ids = []
@@ -59,6 +76,7 @@ def embed_and_upsert_datasets(
59
  results = thread_map(
60
  lambda doc: embed_card(doc, inference_client), documents, leave=False
61
  )
 
62
  collection.upsert(
63
  ids=ids,
64
  embeddings=[embedding.tolist()[0] for embedding in results],
@@ -66,15 +84,20 @@ def embed_and_upsert_datasets(
66
  logger.debug(f"Processed batch {i//batch_size + 1}")
67
 
68
 
69
- async def refresh_viewer_data(sample_size=100_000, min_likes=2):
70
  logger.info(
71
  f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
72
  )
73
  chroma_client, inference_client = initialize_clients()
74
  collection = create_collection(chroma_client)
75
-
76
  logger.info("Preparing data")
77
  df = await prep_data(sample_size=sample_size, min_likes=min_likes)
 
 
 
 
 
78
  dataset_rows_and_ids = df.to_dicts()
79
 
80
  logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
 
9
  from tqdm.auto import tqdm
10
  from tqdm.contrib.concurrent import thread_map
11
 
12
+
13
  from prep_viewer_data import prep_data
14
+ from utils import get_chroma_client
15
 
16
  # Set up logging
17
  logger = logging.getLogger(__name__)
18
  logger.setLevel(logging.INFO)
19
 
20
+ EMBEDDING_MODEL_NAME = "davanstrien/dataset-viewer-descriptions-processed-st"
21
+ EMBEDDING_MODEL_REVISION = "d09abf1227ac41c6955eb9dd53c21771b0984ade"
22
+ INFERENCE_MODEL_URL = (
23
+ "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
24
+ )
25
+
26
 
27
  def initialize_clients():
28
  logger.info("Initializing clients")
29
+ chroma_client = get_chroma_client()
30
  inference_client = InferenceClient(
31
+ INFERENCE_MODEL_URL,
32
  )
33
  return chroma_client, inference_client
34
 
 
36
  def create_collection(chroma_client):
37
  logger.info("Creating or getting collection")
38
  embedding_function = SentenceTransformerEmbeddingFunction(
39
+ model_name=EMBEDDING_MODEL_NAME,
40
  trust_remote_code=True,
41
+ revision=EMBEDDING_MODEL_REVISION,
42
  )
43
+ logger.info(f"Embedding function: {embedding_function}")
44
+ logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
45
+ logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
46
  return chroma_client.create_collection(
47
  name="dataset-viewer-descriptions",
48
  get_or_create=True,
 
58
 
59
 
60
  def embed_and_upsert_datasets(
61
+ dataset_rows_and_ids: list[dict[str, str]],
62
+ collection: chromadb.Collection,
63
+ inference_client: InferenceClient,
64
+ batch_size: int = 10,
65
  ):
66
+ logger.info(
67
+ f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
68
+ )
69
  for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
70
  batch = dataset_rows_and_ids[i : i + batch_size]
71
  ids = []
 
76
  results = thread_map(
77
  lambda doc: embed_card(doc, inference_client), documents, leave=False
78
  )
79
+ logger.info(f"Results: {len(results)}")
80
  collection.upsert(
81
  ids=ids,
82
  embeddings=[embedding.tolist()[0] for embedding in results],
 
84
  logger.debug(f"Processed batch {i//batch_size + 1}")
85
 
86
 
87
+ async def refresh_viewer_data(sample_size=200_000, min_likes=2):
88
  logger.info(
89
  f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
90
  )
91
  chroma_client, inference_client = initialize_clients()
92
  collection = create_collection(chroma_client)
93
+ logger.info("Collection created successfully")
94
  logger.info("Preparing data")
95
  df = await prep_data(sample_size=sample_size, min_likes=min_likes)
96
+ df.write_parquet("viewer_data.parquet")
97
+ if df is not None:
98
+ logger.info("Data prepared successfully")
99
+ logger.info(f"Data: {df}")
100
+
101
  dataset_rows_and_ids = df.to_dicts()
102
 
103
  logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")