davanstrien HF staff commited on
Commit
cb13c5d
1 Parent(s): 557fe3e

global client

Browse files
Files changed (3) hide show
  1. load_card_data.py +2 -2
  2. main.py +2 -4
  3. utils.py +10 -7
load_card_data.py CHANGED
@@ -10,7 +10,7 @@ from chromadb.utils import embedding_functions
10
  from dotenv import load_dotenv
11
  from huggingface_hub import InferenceClient
12
  from tqdm.contrib.concurrent import thread_map
13
- from utils import get_chroma_client, get_collection
14
 
15
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
16
  # Set up logging
@@ -163,8 +163,8 @@ def get_inference_client():
163
 
164
  def refresh_card_data(min_len: int = 250, min_likes: Optional[int] = None):
165
  logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
166
- chroma_client = get_chroma_client()
167
  embedding_function = card_embedding_function()
 
168
  collection = get_collection(chroma_client, embedding_function, COLLECTION_NAME)
169
  most_recent = get_last_modified_in_collection(collection)
170
 
 
10
  from dotenv import load_dotenv
11
  from huggingface_hub import InferenceClient
12
  from tqdm.contrib.concurrent import thread_map
13
+ from utils import get_collection, get_chroma_client
14
 
15
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
16
  # Set up logging
 
163
 
164
  def refresh_card_data(min_len: int = 250, min_likes: Optional[int] = None):
165
  logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
 
166
  embedding_function = card_embedding_function()
167
+ chroma_client = get_chroma_client()
168
  collection = get_collection(chroma_client, embedding_function, COLLECTION_NAME)
169
  most_recent = get_last_modified_in_collection(collection)
170
 
main.py CHANGED
@@ -18,7 +18,7 @@ from starlette.status import (
18
 
19
  from load_card_data import card_embedding_function, refresh_card_data
20
  from load_viewer_data import refresh_viewer_data
21
- from utils import get_save_path, get_collection
22
 
23
  # Set up logging
24
  logging.basicConfig(
@@ -30,9 +30,7 @@ logger = logging.getLogger(__name__)
30
  cache.setup("mem://?check_interval=10&size=1000")
31
 
32
  # Initialize Chroma client
33
- SAVE_PATH = get_save_path()
34
- client = chromadb.PersistentClient(path=SAVE_PATH)
35
-
36
 
37
  async_client = AsyncClient(
38
  follow_redirects=True,
 
18
 
19
  from load_card_data import card_embedding_function, refresh_card_data
20
  from load_viewer_data import refresh_viewer_data
21
+ from utils import get_save_path, get_collection, get_chroma_client
22
 
23
  # Set up logging
24
  logging.basicConfig(
 
30
  cache.setup("mem://?check_interval=10&size=1000")
31
 
32
  # Initialize Chroma client
33
+ client = get_chroma_client()
 
 
34
 
35
  async_client = AsyncClient(
36
  follow_redirects=True,
utils.py CHANGED
@@ -26,6 +26,8 @@ logger = logging.getLogger(__name__)
26
 
27
  load_dotenv()
28
 
 
 
29
 
30
  def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
31
  path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
@@ -34,13 +36,14 @@ def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
34
 
35
 
36
  def get_chroma_client():
37
- logger.info("Initializing Chroma client")
38
- SAVE_PATH = get_save_path()
39
-
40
- return chromadb.PersistentClient(
41
- path=SAVE_PATH,
42
- settings=Settings(anonymized_telemetry=False, is_persistent=True),
43
- )
 
44
 
45
 
46
  def get_collection(chroma_client, embedding_function, collection_name):
 
26
 
27
  load_dotenv()
28
 
29
+ CHROMA_CLIENT = None
30
+
31
 
32
  def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
33
  path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
 
36
 
37
 
38
  def get_chroma_client():
39
+ global CHROMA_CLIENT
40
+ if CHROMA_CLIENT is None:
41
+ SAVE_PATH = get_save_path()
42
+ CHROMA_CLIENT = chromadb.PersistentClient(
43
+ path=SAVE_PATH,
44
+ settings=Settings(anonymized_telemetry=False, is_persistent=True),
45
+ )
46
+ return CHROMA_CLIENT
47
 
48
 
49
  def get_collection(chroma_client, embedding_function, collection_name):