"""FastAPI service returning datasets whose stored embeddings are closest to a
given dataset's embedding in the Chroma collection ``dataset_cards``."""

import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel

from load_data import get_save_path, refresh_data

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching: in-memory backend used by the @cache decorator below.
cache.setup("mem://?check_interval=10&size=10000")

# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
# NOTE(review): this handle is obtained at import time, i.e. *before* the
# startup refresh_data() call in `lifespan`. If refresh_data() creates or
# recreates the collection, this either fails on a fresh store or may hold a
# stale handle -- confirm against load_data.refresh_data.
collection = client.get_collection("dataset_cards")


class QueryResult(BaseModel):
    """One neighbouring dataset and its similarity score."""

    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    """Response body of ``GET /query``."""

    results: List[QueryResult]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Refresh the data on startup and log shutdown.

    A failed refresh is logged (with traceback) but does not abort startup;
    the app then serves whatever is already in the Chroma store.
    """
    logger.info("Starting up the application")
    try:
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        # Deliberately best-effort: keep serving existing data.
        logger.exception(f"Error during data refresh: {str(e)}")
    yield  # Here the app is running and handling requests
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed


app = FastAPI(lifespan=lifespan)


@app.get("/query", response_model=Optional[QueryResponse])
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets nearest to ``dataset_id`` in embedding space.

    Args:
        dataset_id: Id of a dataset whose embedding is stored in ``collection``.
        n: Maximum number of neighbours to return (1-100).

    Returns:
        A ``QueryResponse`` with one ``QueryResult`` per neighbour, or ``None``
        when Chroma returns no matches.

    Raises:
        HTTPException: 404 if no embedding is stored for ``dataset_id``;
            500 (detail = the error text) for any other failure.
    """
    # NOTE(review): collection.get/query are synchronous calls made inside an
    # async endpoint, so they block the event loop while they run.
    try:
        logger.info(f"Querying dataset: {dataset_id}")

        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        embeddings = result["embeddings"]
        # Newer chromadb versions return a numpy array here, whose truth value
        # is ambiguous, so test emptiness explicitly instead of `not embeddings`.
        if embeddings is None or len(embeddings) == 0:
            logger.info(f"Dataset not found: {dataset_id}")
            raise HTTPException(status_code=404, detail="Dataset not found")
        embedding = embeddings[0]

        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # Results are nested one list per query embedding; we sent exactly one,
        # so emptiness must be checked on the inner list.
        matched_ids = query_result["ids"][0] if query_result["ids"] else []
        if not matched_ids:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None
        distances = query_result["distances"][0]

        # Prepare the response
        # NOTE(review): `1 - distance` is a meaningful similarity only for a
        # cosine-distance collection -- confirm the collection's space setting.
        results = [
            QueryResult(dataset_id=match_id, similarity=1 - distance)
            for match_id, distance in zip(matched_ids, distances)
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) reach the client
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)