Spaces: Running on CPU Upgrade
import logging | |
from contextlib import asynccontextmanager | |
from typing import List, Optional | |
import chromadb | |
from cashews import cache | |
from fastapi import FastAPI, HTTPException, Query | |
from pydantic import BaseModel | |
from starlette.responses import RedirectResponse | |
from httpx import AsyncClient | |
from load_data import get_embedding_function, get_save_path, refresh_data | |
from huggingface_hub import DatasetCard | |
# --- Logging ------------------------------------------------------------------
# Root logging at INFO with a timestamped single-line format; this module logs
# through its own named logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Caching ------------------------------------------------------------------
# cashews backend selected by the "mem://" URL scheme.
# NOTE(review): no `@cache(...)` usage is visible in this file — confirm it is
# still needed.
cache.setup("mem://?check_interval=10&size=1000")

# --- Vector store -------------------------------------------------------------
# Persistent Chroma client rooted at the path supplied by load_data.
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
# Assigned by the application lifespan during startup.
collection = None

# --- HTTP ---------------------------------------------------------------------
# Shared async HTTP client (redirects followed) used to download dataset cards.
async_client = AsyncClient(follow_redirects=True)
# One neighbouring dataset returned by a similarity query.
# (Documented with comments rather than a docstring: pydantic would publish a
# class docstring as the schema description.)
class QueryResult(BaseModel):
    # Identifier of the similar dataset, as stored in the Chroma collection.
    dataset_id: str
    # Similarity score; the query endpoint fills it as ``1 - distance``.
    similarity: float
# Response payload for a similarity query.
class QueryResponse(BaseModel):
    # Similar datasets, in the order the Chroma query returned them.
    results: List[QueryResult]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown around FastAPI's serving phase.

    Startup opens (or creates) the ``dataset_cards`` Chroma collection with the
    project embedding function, stores it in the module-level ``collection``,
    and then calls ``refresh_data()``. A startup failure is logged with its
    traceback and re-raised, so the application does not start.

    Shutdown (also reached if serving ends with an exception) closes the shared
    ``async_client`` HTTP client.

    Args:
        app: The FastAPI application being started (unused here).
    """
    global collection
    logger.info("Starting up the application")
    try:
        embedding_function = get_embedding_function()
        collection = client.get_or_create_collection(
            name="dataset_cards", embedding_function=embedding_function
        )
        logger.info("Collection initialized successfully")
        # refresh_data() is called without ``await``: it runs to completion on
        # the event-loop thread before the app starts serving requests.
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("Error during startup: %s", e)
        raise
    try:
        yield  # The application handles requests while suspended here.
    finally:
        logger.info("Shutting down the application")
        # Release the connection pool held by the module-level HTTP client.
        await async_client.aclose()
app = FastAPI(lifespan=lifespan)


# Without a route decorator this handler was never reachable; register it on
# "/" and keep the redirect out of the generated OpenAPI schema.
@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to FastAPI's interactive docs at ``/docs``."""
    return RedirectResponse(url="/docs")
async def try_get_card(hub_id: str) -> Optional[str]:
    """Download a dataset's README from the Hub and return its card text.

    The README is requested from the dataset's ``main`` branch and parsed with
    ``DatasetCard``; the parsed card's ``text`` attribute is returned.

    Args:
        hub_id: Dataset identifier on huggingface.co (interpolated into the URL).

    Returns:
        The card text, or ``None`` when the response status is not 200 or any
        exception occurs (exceptions are logged, never propagated).
    """
    readme_url = f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
    try:
        resp = await async_client.get(readme_url)
        if resp.status_code != 200:
            return None
        return DatasetCard(resp.text).text
    except Exception as err:
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(err)}")
    return None
def _has_embeddings(result) -> bool:
    """Return True when a Chroma ``get`` result holds at least one embedding."""
    embeddings = result.get("embeddings")
    # Depending on the chromadb version this is a list or a numpy array; the
    # truth value of a numpy array is ambiguous, so test the length explicitly.
    return embeddings is not None and len(embeddings) > 0


async def _add_dataset_to_collection(dataset_id: str) -> None:
    """Fetch, embed and upsert the dataset card for ``dataset_id``.

    Raises:
        ValueError: if no (non-empty) card text could be fetched.
    """
    card = await try_get_card(dataset_id)
    if not card:
        raise ValueError(f"No dataset card available for {dataset_id}")
    embedding_function = get_embedding_function()
    # This function is passed to Chroma as a collection embedding function,
    # whose contract takes a *list* of documents — never a bare string.
    embeddings = embedding_function([card])
    collection.upsert(ids=[dataset_id], embeddings=embeddings)


async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return the ``n`` datasets nearest to ``dataset_id`` in the Chroma collection.

    If ``dataset_id`` has no stored embedding, its Hub dataset card is fetched,
    embedded and upserted first.

    Args:
        dataset_id: Hub dataset identifier to look up.
        n: Number of neighbours requested from Chroma (``Query`` constraint
            ge=1, le=100).

    Returns:
        ``QueryResponse`` with one ``QueryResult`` per neighbour (similarity is
        ``1 - distance``; since the query uses the dataset's own stored
        embedding, the dataset itself may appear among the results), or
        ``None`` when the query returns no ids.

    Raises:
        HTTPException: 404 when the dataset is unknown and cannot be added;
            500 for any other failure.

    NOTE(review): no route decorator is attached to this function in the
    visible source, so FastAPI does not expose it — confirm how it is meant to
    be registered.
    """
    try:
        logger.info("Querying dataset: %s", dataset_id)
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not _has_embeddings(result):
            logger.info("Dataset not found: %s", dataset_id)
            try:
                await _add_dataset_to_collection(dataset_id)
                logger.info("Dataset %s added to collection", dataset_id)
                result = collection.get(ids=[dataset_id], include=["embeddings"])
            except Exception as e:
                logger.error(
                    "Error adding dataset %s to collection: %s", dataset_id, e
                )
                raise HTTPException(status_code=404, detail="Dataset not found") from e

        embedding = result["embeddings"][0]
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # ids is a list with one inner list per query embedding; an empty
        # result is [[]], which is truthy, so inspect the inner list too.
        ids = query_result["ids"]
        if not ids or not ids[0]:
            logger.info("No similar datasets found for: %s", dataset_id)
            return None

        results = [
            QueryResult(dataset_id=similar_id, similarity=1 - distance)
            for similar_id, distance in zip(ids[0], query_result["distances"][0])
        ]
        logger.info("Found %d similar datasets for: %s", len(results), dataset_id)
        return QueryResponse(results=results)
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        logger.exception("Error querying dataset %s: %s", dataset_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")