Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
•
b5f94b5
1
Parent(s):
1d74113
refactor
Browse files
main.py
CHANGED
@@ -16,8 +16,9 @@ from starlette.status import (
|
|
16 |
HTTP_500_INTERNAL_SERVER_ERROR,
|
17 |
)
|
18 |
|
19 |
-
from load_card_data import
|
20 |
from load_viewer_data import refresh_viewer_data
|
|
|
21 |
|
22 |
# Set up logging
|
23 |
logging.basicConfig(
|
@@ -31,7 +32,7 @@ cache.setup("mem://?check_interval=10&size=1000")
|
|
31 |
# Initialize Chroma client
|
32 |
SAVE_PATH = get_save_path()
|
33 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
34 |
-
|
35 |
|
36 |
async_client = AsyncClient(
|
37 |
follow_redirects=True,
|
@@ -40,33 +41,20 @@ async_client = AsyncClient(
|
|
40 |
|
41 |
@asynccontextmanager
|
42 |
async def lifespan(app: FastAPI):
|
43 |
-
global collection
|
44 |
# Startup: refresh data and initialize collection
|
45 |
logger.info("Starting up the application")
|
46 |
try:
|
47 |
-
# Create or get the collection
|
48 |
-
logger.info("Initializing embedding function")
|
49 |
-
embedding_function = get_embedding_function()
|
50 |
-
logger.info("Creating or getting collection")
|
51 |
-
collection = client.get_or_create_collection(
|
52 |
-
name="dataset_cards", embedding_function=embedding_function
|
53 |
-
)
|
54 |
-
logger.info("Collection initialized successfully")
|
55 |
-
|
56 |
# Refresh data
|
57 |
logger.info("Starting refresh of card data")
|
58 |
refresh_card_data()
|
59 |
logger.info("Card data refresh completed")
|
60 |
-
|
61 |
logger.info("Starting refresh of viewer data")
|
62 |
await refresh_viewer_data()
|
63 |
logger.info("Viewer data refresh completed")
|
64 |
-
|
65 |
logger.info("Data refresh completed successfully")
|
66 |
except Exception as e:
|
67 |
logger.error(f"Error during startup: {str(e)}")
|
68 |
logger.warning("Application starting with potential data issues")
|
69 |
-
|
70 |
yield
|
71 |
|
72 |
# Shutdown: perform any cleanup
|
@@ -123,6 +111,8 @@ class DatasetNotForAllAudiencesError(HTTPException):
|
|
123 |
@app.get("/similar", response_model=QueryResponse)
|
124 |
@cache(ttl="1h")
|
125 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
|
|
|
|
126 |
try:
|
127 |
logger.info(f"Querying dataset: {dataset_id}")
|
128 |
# Get the embedding for the given dataset_id
|
@@ -130,7 +120,6 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
130 |
if not result.get("embeddings"):
|
131 |
logger.info(f"Dataset not found: {dataset_id}")
|
132 |
try:
|
133 |
-
embedding_function = get_embedding_function()
|
134 |
card = await try_get_card(dataset_id)
|
135 |
if card is None:
|
136 |
raise DatasetCardNotFoundError(dataset_id)
|
@@ -182,13 +171,13 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
182 |
) from e
|
183 |
|
184 |
|
185 |
-
@app.
|
186 |
@cache(ttl="1h")
|
187 |
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
188 |
try:
|
189 |
logger.info(f"Querying datasets by text: {query}")
|
190 |
collection = client.get_collection(
|
191 |
-
name="dataset_cards", embedding_function=
|
192 |
)
|
193 |
print(query)
|
194 |
query_result = collection.query(
|
@@ -220,7 +209,7 @@ async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)
|
|
220 |
) from e
|
221 |
|
222 |
|
223 |
-
@app.
|
224 |
@cache(ttl="1h")
|
225 |
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
|
226 |
try:
|
|
|
16 |
HTTP_500_INTERNAL_SERVER_ERROR,
|
17 |
)
|
18 |
|
19 |
+
from load_card_data import card_embedding_function, refresh_card_data
|
20 |
from load_viewer_data import refresh_viewer_data
|
21 |
+
from utils import get_save_path, get_collection
|
22 |
|
23 |
# Set up logging
|
24 |
logging.basicConfig(
|
|
|
32 |
# Initialize Chroma client
|
33 |
SAVE_PATH = get_save_path()
|
34 |
client = chromadb.PersistentClient(path=SAVE_PATH)
|
35 |
+
|
36 |
|
37 |
async_client = AsyncClient(
|
38 |
follow_redirects=True,
|
|
|
41 |
|
42 |
@asynccontextmanager
|
43 |
async def lifespan(app: FastAPI):
|
|
|
44 |
# Startup: refresh data and initialize collection
|
45 |
logger.info("Starting up the application")
|
46 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# Refresh data
|
48 |
logger.info("Starting refresh of card data")
|
49 |
refresh_card_data()
|
50 |
logger.info("Card data refresh completed")
|
|
|
51 |
logger.info("Starting refresh of viewer data")
|
52 |
await refresh_viewer_data()
|
53 |
logger.info("Viewer data refresh completed")
|
|
|
54 |
logger.info("Data refresh completed successfully")
|
55 |
except Exception as e:
|
56 |
logger.error(f"Error during startup: {str(e)}")
|
57 |
logger.warning("Application starting with potential data issues")
|
|
|
58 |
yield
|
59 |
|
60 |
# Shutdown: perform any cleanup
|
|
|
111 |
@app.get("/similar", response_model=QueryResponse)
|
112 |
@cache(ttl="1h")
|
113 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
114 |
+
embedding_function = card_embedding_function()
|
115 |
+
collection = get_collection(client, embedding_function, "dataset_cards")
|
116 |
try:
|
117 |
logger.info(f"Querying dataset: {dataset_id}")
|
118 |
# Get the embedding for the given dataset_id
|
|
|
120 |
if not result.get("embeddings"):
|
121 |
logger.info(f"Dataset not found: {dataset_id}")
|
122 |
try:
|
|
|
123 |
card = await try_get_card(dataset_id)
|
124 |
if card is None:
|
125 |
raise DatasetCardNotFoundError(dataset_id)
|
|
|
171 |
) from e
|
172 |
|
173 |
|
174 |
+
@app.get("/similar-text", response_model=QueryResponse)
|
175 |
@cache(ttl="1h")
|
176 |
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
|
177 |
try:
|
178 |
logger.info(f"Querying datasets by text: {query}")
|
179 |
collection = client.get_collection(
|
180 |
+
name="dataset_cards", embedding_function=card_embedding_function()
|
181 |
)
|
182 |
print(query)
|
183 |
query_result = collection.query(
|
|
|
209 |
) from e
|
210 |
|
211 |
|
212 |
+
@app.get("/search-viewer", response_model=QueryResponse)
|
213 |
@cache(ttl="1h")
|
214 |
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
|
215 |
try:
|