Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
•
1d74113
1
Parent(s):
79269c4
refactor
Browse files- load_card_data.py +7 -32
load_card_data.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
-
import platform
|
4 |
from datetime import datetime
|
5 |
-
from typing import List,
|
6 |
|
7 |
-
import chromadb
|
8 |
import polars as pl
|
9 |
import requests
|
10 |
import stamina
|
@@ -12,6 +10,7 @@ from chromadb.utils import embedding_functions
|
|
12 |
from dotenv import load_dotenv
|
13 |
from huggingface_hub import InferenceClient
|
14 |
from tqdm.contrib.concurrent import thread_map
|
|
|
15 |
|
16 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
17 |
# Set up logging
|
@@ -27,30 +26,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
27 |
EMBEDDING_MODEL_NAME = "Alibaba-NLP/gte-large-en-v1.5"
|
28 |
EMBEDDING_MODEL_REVISION = "104333d6af6f97649377c2afbde10a7704870c7b"
|
29 |
INFERENCE_MODEL_URL = (
|
30 |
-
"https://
|
31 |
)
|
32 |
DATASET_PARQUET_URL = (
|
33 |
"hf://datasets/librarian-bots/dataset_cards_with_metadata/data/train-*.parquet"
|
34 |
)
|
35 |
COLLECTION_NAME = "dataset_cards"
|
36 |
-
MAX_EMBEDDING_LENGTH = 8192
|
37 |
-
|
38 |
-
|
39 |
-
def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
|
40 |
-
path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
|
41 |
-
logger.info(f"Using save path: {path}")
|
42 |
-
return path
|
43 |
-
|
44 |
-
|
45 |
-
SAVE_PATH = get_save_path()
|
46 |
-
|
47 |
|
48 |
-
|
49 |
-
logger.info("Initializing Chroma client")
|
50 |
-
return chromadb.PersistentClient(path=SAVE_PATH)
|
51 |
|
52 |
|
53 |
-
def
|
54 |
logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
|
55 |
return embedding_functions.SentenceTransformerEmbeddingFunction(
|
56 |
model_name=EMBEDDING_MODEL_NAME,
|
@@ -59,16 +45,6 @@ def get_embedding_function():
|
|
59 |
)
|
60 |
|
61 |
|
62 |
-
def get_collection(chroma_client, embedding_function):
|
63 |
-
logger.info(f"Getting or creating collection: {COLLECTION_NAME}")
|
64 |
-
return chroma_client.create_collection(
|
65 |
-
name=COLLECTION_NAME,
|
66 |
-
get_or_create=True,
|
67 |
-
embedding_function=embedding_function,
|
68 |
-
metadata={"hnsw:space": "cosine"},
|
69 |
-
)
|
70 |
-
|
71 |
-
|
72 |
def get_last_modified_in_collection(collection) -> datetime | None:
|
73 |
logger.info("Fetching last modified date from collection")
|
74 |
try:
|
@@ -188,9 +164,8 @@ def get_inference_client():
|
|
188 |
def refresh_card_data(min_len: int = 250, min_likes: Optional[int] = None):
|
189 |
logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
|
190 |
chroma_client = get_chroma_client()
|
191 |
-
embedding_function =
|
192 |
-
collection = get_collection(chroma_client, embedding_function)
|
193 |
-
|
194 |
most_recent = get_last_modified_in_collection(collection)
|
195 |
|
196 |
if data := load_cards(
|
|
|
1 |
import logging
|
2 |
import os
|
|
|
3 |
from datetime import datetime
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
|
|
|
6 |
import polars as pl
|
7 |
import requests
|
8 |
import stamina
|
|
|
10 |
from dotenv import load_dotenv
|
11 |
from huggingface_hub import InferenceClient
|
12 |
from tqdm.contrib.concurrent import thread_map
|
13 |
+
from utils import get_chroma_client, get_collection
|
14 |
|
15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
16 |
# Set up logging
|
|
|
26 |
EMBEDDING_MODEL_NAME = "Alibaba-NLP/gte-large-en-v1.5"
|
27 |
EMBEDDING_MODEL_REVISION = "104333d6af6f97649377c2afbde10a7704870c7b"
|
28 |
INFERENCE_MODEL_URL = (
|
29 |
+
"https://spwy1g6626yhjhjhpr.us-east-1.aws.endpoints.huggingface.cloud"
|
30 |
)
|
31 |
DATASET_PARQUET_URL = (
|
32 |
"hf://datasets/librarian-bots/dataset_cards_with_metadata/data/train-*.parquet"
|
33 |
)
|
34 |
COLLECTION_NAME = "dataset_cards"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
MAX_EMBEDDING_LENGTH = 8192
|
|
|
|
|
37 |
|
38 |
|
39 |
+
def card_embedding_function():
|
40 |
logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
|
41 |
return embedding_functions.SentenceTransformerEmbeddingFunction(
|
42 |
model_name=EMBEDDING_MODEL_NAME,
|
|
|
45 |
)
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def get_last_modified_in_collection(collection) -> datetime | None:
|
49 |
logger.info("Fetching last modified date from collection")
|
50 |
try:
|
|
|
164 |
def refresh_card_data(min_len: int = 250, min_likes: Optional[int] = None):
|
165 |
logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
|
166 |
chroma_client = get_chroma_client()
|
167 |
+
embedding_function = card_embedding_function()
|
168 |
+
collection = get_collection(chroma_client, embedding_function, COLLECTION_NAME)
|
|
|
169 |
most_recent = get_last_modified_in_collection(collection)
|
170 |
|
171 |
if data := load_cards(
|