davanstrien HF staff committed on
Commit
4d185df
1 Parent(s): a5c714c

try adding missing cards

Browse files
Files changed (1) hide show
  1. main.py +32 -4
main.py CHANGED
@@ -7,8 +7,9 @@ from cashews import cache
7
  from fastapi import FastAPI, HTTPException, Query
8
  from pydantic import BaseModel
9
  from starlette.responses import RedirectResponse
10
-
11
  from load_data import get_embedding_function, get_save_path, refresh_data
 
12
 
13
  # Set up logging
14
  logging.basicConfig(
@@ -24,6 +25,10 @@ SAVE_PATH = get_save_path()
24
  client = chromadb.PersistentClient(path=SAVE_PATH)
25
  collection = None
26
 
 
 
 
 
27
 
28
  class QueryResult(BaseModel):
29
  dataset_id: str
@@ -69,6 +74,19 @@ def root():
69
  return RedirectResponse(url="/docs")
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  @app.get("/query", response_model=Optional[QueryResponse])
73
  @cache(ttl="1h")
74
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
@@ -76,10 +94,20 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
76
  logger.info(f"Querying dataset: {dataset_id}")
77
  # Get the embedding for the given dataset_id
78
  result = collection.get(ids=[dataset_id], include=["embeddings"])
79
-
80
- if not result["embeddings"]:
81
  logger.info(f"Dataset not found: {dataset_id}")
82
- raise HTTPException(status_code=404, detail="Dataset not found")
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  embedding = result["embeddings"][0]
85
 
 
7
  from fastapi import FastAPI, HTTPException, Query
8
  from pydantic import BaseModel
9
  from starlette.responses import RedirectResponse
10
+ from httpx import AsyncClient
11
  from load_data import get_embedding_function, get_save_path, refresh_data
12
+ from huggingface_hub import DatasetCard
13
 
14
  # Set up logging
15
  logging.basicConfig(
 
25
  client = chromadb.PersistentClient(path=SAVE_PATH)
26
  collection = None
27
 
28
+ async_client = AsyncClient(
29
+ follow_redirects=True,
30
+ )
31
+
32
 
33
  class QueryResult(BaseModel):
34
  dataset_id: str
 
74
  return RedirectResponse(url="/docs")
75
 
76
 
77
async def try_get_card(hub_id: str) -> Optional[str]:
    """Best-effort fetch of a dataset card's text from the Hugging Face Hub.

    Requests ``README.md`` from the ``main`` revision of the dataset
    repository and parses it with ``DatasetCard``.

    Args:
        hub_id: Dataset repository id, e.g. ``"user/dataset"``.

    Returns:
        The ``text`` attribute of the parsed ``DatasetCard`` when the request
        returns HTTP 200; ``None`` for any other status code or when fetching
        or parsing raises.
    """
    # Function-scope import so this block does not depend on a change to the
    # file-level import section.
    from urllib.parse import quote

    # hub_id is caller-supplied (it is passed straight from the /query
    # parameter), so percent-encode it: characters such as "?", "#" or spaces
    # must not be able to alter the request target. "/" stays literal because
    # it separates namespace and repository name.
    safe_hub_id = quote(hub_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_hub_id}/raw/main/README.md"
    try:
        response = await async_client.get(url)
        if response.status_code == 200:
            card = DatasetCard(response.text)
            return card.text
        # Make the "no card" outcome visible instead of falling through silently.
        logger.info(
            f"No card fetched for hub_id {hub_id}: HTTP {response.status_code}"
        )
    except Exception as e:
        # Deliberately broad: a missing or malformed card must never fail the
        # caller, which treats None as "no card available".
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
    return None
88
+
89
+
90
  @app.get("/query", response_model=Optional[QueryResponse])
91
  @cache(ttl="1h")
92
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
 
94
  logger.info(f"Querying dataset: {dataset_id}")
95
  # Get the embedding for the given dataset_id
96
  result = collection.get(ids=[dataset_id], include=["embeddings"])
97
+ if not result.get("embeddings"):
 
98
  logger.info(f"Dataset not found: {dataset_id}")
99
+ try:
100
+ embedding_function = get_embedding_function()
101
+ card = await try_get_card(dataset_id)
102
+ embeddings = embedding_function(card)
103
+ collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
104
+ logger.info(f"Dataset {dataset_id} added to collection")
105
+ result = collection.get(ids=[dataset_id], include=["embeddings"])
106
+ except Exception as e:
107
+ logger.error(
108
+ f"Error adding dataset {dataset_id} to collection: {str(e)}"
109
+ )
110
+ raise HTTPException(status_code=404, detail="Dataset not found") from e
111
 
112
  embedding = result["embeddings"][0]
113