davanstrien HF staff committed on
Commit
b5f94b5
1 Parent(s): 1d74113
Files changed (1) hide show
  1. main.py +8 -19
main.py CHANGED
@@ -16,8 +16,9 @@ from starlette.status import (
16
  HTTP_500_INTERNAL_SERVER_ERROR,
17
  )
18
 
19
- from load_card_data import get_embedding_function, get_save_path, refresh_card_data
20
  from load_viewer_data import refresh_viewer_data
 
21
 
22
  # Set up logging
23
  logging.basicConfig(
@@ -31,7 +32,7 @@ cache.setup("mem://?check_interval=10&size=1000")
31
  # Initialize Chroma client
32
  SAVE_PATH = get_save_path()
33
  client = chromadb.PersistentClient(path=SAVE_PATH)
34
- collection = None
35
 
36
  async_client = AsyncClient(
37
  follow_redirects=True,
@@ -40,33 +41,20 @@ async_client = AsyncClient(
40
 
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
43
- global collection
44
  # Startup: refresh data and initialize collection
45
  logger.info("Starting up the application")
46
  try:
47
- # Create or get the collection
48
- logger.info("Initializing embedding function")
49
- embedding_function = get_embedding_function()
50
- logger.info("Creating or getting collection")
51
- collection = client.get_or_create_collection(
52
- name="dataset_cards", embedding_function=embedding_function
53
- )
54
- logger.info("Collection initialized successfully")
55
-
56
  # Refresh data
57
  logger.info("Starting refresh of card data")
58
  refresh_card_data()
59
  logger.info("Card data refresh completed")
60
-
61
  logger.info("Starting refresh of viewer data")
62
  await refresh_viewer_data()
63
  logger.info("Viewer data refresh completed")
64
-
65
  logger.info("Data refresh completed successfully")
66
  except Exception as e:
67
  logger.error(f"Error during startup: {str(e)}")
68
  logger.warning("Application starting with potential data issues")
69
-
70
  yield
71
 
72
  # Shutdown: perform any cleanup
@@ -123,6 +111,8 @@ class DatasetNotForAllAudiencesError(HTTPException):
123
  @app.get("/similar", response_model=QueryResponse)
124
  @cache(ttl="1h")
125
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
 
 
126
  try:
127
  logger.info(f"Querying dataset: {dataset_id}")
128
  # Get the embedding for the given dataset_id
@@ -130,7 +120,6 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
130
  if not result.get("embeddings"):
131
  logger.info(f"Dataset not found: {dataset_id}")
132
  try:
133
- embedding_function = get_embedding_function()
134
  card = await try_get_card(dataset_id)
135
  if card is None:
136
  raise DatasetCardNotFoundError(dataset_id)
@@ -182,13 +171,13 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
182
  ) from e
183
 
184
 
185
- @app.post("/similar-text", response_model=QueryResponse)
186
  @cache(ttl="1h")
187
  async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
188
  try:
189
  logger.info(f"Querying datasets by text: {query}")
190
  collection = client.get_collection(
191
- name="dataset_cards", embedding_function=get_embedding_function()
192
  )
193
  print(query)
194
  query_result = collection.query(
@@ -220,7 +209,7 @@ async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)
220
  ) from e
221
 
222
 
223
- @app.post("/search-viewer", response_model=QueryResponse)
224
  @cache(ttl="1h")
225
  async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
226
  try:
 
16
  HTTP_500_INTERNAL_SERVER_ERROR,
17
  )
18
 
19
+ from load_card_data import card_embedding_function, refresh_card_data
20
  from load_viewer_data import refresh_viewer_data
21
+ from utils import get_save_path, get_collection
22
 
23
  # Set up logging
24
  logging.basicConfig(
 
32
  # Initialize Chroma client
33
  SAVE_PATH = get_save_path()
34
  client = chromadb.PersistentClient(path=SAVE_PATH)
35
+
36
 
37
  async_client = AsyncClient(
38
  follow_redirects=True,
 
41
 
42
  @asynccontextmanager
43
  async def lifespan(app: FastAPI):
 
44
  # Startup: refresh data and initialize collection
45
  logger.info("Starting up the application")
46
  try:
 
 
 
 
 
 
 
 
 
47
  # Refresh data
48
  logger.info("Starting refresh of card data")
49
  refresh_card_data()
50
  logger.info("Card data refresh completed")
 
51
  logger.info("Starting refresh of viewer data")
52
  await refresh_viewer_data()
53
  logger.info("Viewer data refresh completed")
 
54
  logger.info("Data refresh completed successfully")
55
  except Exception as e:
56
  logger.error(f"Error during startup: {str(e)}")
57
  logger.warning("Application starting with potential data issues")
 
58
  yield
59
 
60
  # Shutdown: perform any cleanup
 
111
  @app.get("/similar", response_model=QueryResponse)
112
  @cache(ttl="1h")
113
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
114
+ embedding_function = card_embedding_function()
115
+ collection = get_collection(client, embedding_function, "dataset_cards")
116
  try:
117
  logger.info(f"Querying dataset: {dataset_id}")
118
  # Get the embedding for the given dataset_id
 
120
  if not result.get("embeddings"):
121
  logger.info(f"Dataset not found: {dataset_id}")
122
  try:
 
123
  card = await try_get_card(dataset_id)
124
  if card is None:
125
  raise DatasetCardNotFoundError(dataset_id)
 
171
  ) from e
172
 
173
 
174
+ @app.get("/similar-text", response_model=QueryResponse)
175
  @cache(ttl="1h")
176
  async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
177
  try:
178
  logger.info(f"Querying datasets by text: {query}")
179
  collection = client.get_collection(
180
+ name="dataset_cards", embedding_function=card_embedding_function()
181
  )
182
  print(query)
183
  query_result = collection.query(
 
209
  ) from e
210
 
211
 
212
+ @app.get("/search-viewer", response_model=QueryResponse)
213
  @cache(ttl="1h")
214
  async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
215
  try: