Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
•
4d185df
1
Parent(s):
a5c714c
try adding missing cards
Browse files
main.py
CHANGED
@@ -7,8 +7,9 @@ from cashews import cache
|
|
7 |
from fastapi import FastAPI, HTTPException, Query
|
8 |
from pydantic import BaseModel
|
9 |
from starlette.responses import RedirectResponse
|
10 |
-
|
11 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
|
|
12 |
|
13 |
# Set up logging
|
14 |
logging.basicConfig(
|
@@ -24,6 +25,10 @@ SAVE_PATH = get_save_path()
|
|
24 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
25 |
collection = None
|
26 |
|
|
|
|
|
|
|
|
|
27 |
|
28 |
class QueryResult(BaseModel):
|
29 |
dataset_id: str
|
@@ -69,6 +74,19 @@ def root():
|
|
69 |
return RedirectResponse(url="/docs")
|
70 |
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
@app.get("/query", response_model=Optional[QueryResponse])
|
73 |
@cache(ttl="1h")
|
74 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
@@ -76,10 +94,20 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
76 |
logger.info(f"Querying dataset: {dataset_id}")
|
77 |
# Get the embedding for the given dataset_id
|
78 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
79 |
-
|
80 |
-
if not result["embeddings"]:
|
81 |
logger.info(f"Dataset not found: {dataset_id}")
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
embedding = result["embeddings"][0]
|
85 |
|
|
|
7 |
from fastapi import FastAPI, HTTPException, Query
|
8 |
from pydantic import BaseModel
|
9 |
from starlette.responses import RedirectResponse
|
10 |
+
from httpx import AsyncClient
|
11 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
12 |
+
from huggingface_hub import DatasetCard
|
13 |
|
14 |
# Set up logging
|
15 |
logging.basicConfig(
|
|
|
25 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
26 |
collection = None
|
27 |
|
28 |
+
async_client = AsyncClient(
|
29 |
+
follow_redirects=True,
|
30 |
+
)
|
31 |
+
|
32 |
|
33 |
class QueryResult(BaseModel):
|
34 |
dataset_id: str
|
|
|
74 |
return RedirectResponse(url="/docs")
|
75 |
|
76 |
|
77 |
+
async def try_get_card(hub_id: str) -> Optional[str]:
|
78 |
+
try:
|
79 |
+
response = await async_client.get(
|
80 |
+
f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
|
81 |
+
)
|
82 |
+
if response.status_code == 200:
|
83 |
+
card = DatasetCard(response.text)
|
84 |
+
return card.text
|
85 |
+
except Exception as e:
|
86 |
+
logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
|
87 |
+
return None
|
88 |
+
|
89 |
+
|
90 |
@app.get("/query", response_model=Optional[QueryResponse])
|
91 |
@cache(ttl="1h")
|
92 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
94 |
logger.info(f"Querying dataset: {dataset_id}")
|
95 |
# Get the embedding for the given dataset_id
|
96 |
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
97 |
+
if not result.get("embeddings"):
|
|
|
98 |
logger.info(f"Dataset not found: {dataset_id}")
|
99 |
+
try:
|
100 |
+
embedding_function = get_embedding_function()
|
101 |
+
card = await try_get_card(dataset_id)
|
102 |
+
embeddings = embedding_function(card)
|
103 |
+
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
104 |
+
logger.info(f"Dataset {dataset_id} added to collection")
|
105 |
+
result = collection.get(ids=[dataset_id], include=["embeddings"])
|
106 |
+
except Exception as e:
|
107 |
+
logger.error(
|
108 |
+
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
109 |
+
)
|
110 |
+
raise HTTPException(status_code=404, detail="Dataset not found") from e
|
111 |
|
112 |
embedding = result["embeddings"][0]
|
113 |
|