"""FastAPI service exposing similarity search over dataset-card embeddings.

Embeddings live in a persistent ChromaDB collection ("dataset_cards");
GET /query returns the n dataset ids most similar to a given dataset_id.
"""

import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from load_data import get_embedding_function, get_save_path, refresh_data

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up in-memory response caching (bounded size, periodic eviction check)
cache.setup("mem://?check_interval=10&size=10000")

# Initialize the persistent Chroma client; the collection itself is created
# during application startup in lifespan().
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
collection = None


class QueryResult(BaseModel):
    # One similar dataset: its id and a similarity score computed as
    # 1 - distance (higher is more similar).
    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    results: List[QueryResult]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create/open the Chroma collection and refresh data on startup."""
    global collection
    logger.info("Starting up the application")
    try:
        # Create or get the collection
        embedding_function = get_embedding_function()
        collection = client.get_or_create_collection(
            name="dataset_cards", embedding_function=embedding_function
        )
        logger.info("Collection initialized successfully")
        # Refresh data
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}")
        raise
    yield  # Here the app is running and handling requests
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.get("/query", response_model=Optional[QueryResponse])
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return the n datasets most similar to *dataset_id*.

    Raises 404 if the dataset has no stored embedding; returns None when the
    similarity query yields no ids; 500 on any unexpected backend error.
    """
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id.
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        # Explicit None/length check instead of truthiness: recent chromadb
        # versions return numpy arrays here, whose truth value is ambiguous.
        if result["embeddings"] is None or len(result["embeddings"]) == 0:
            logger.info(f"Dataset not found: {dataset_id}")
            raise HTTPException(status_code=404, detail="Dataset not found")
        embedding = result["embeddings"][0]
        # Query the collection for similar datasets.
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not query_result["ids"]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None
        # Prepare the response (similarity = 1 - distance).
        results = [
            QueryResult(dataset_id=ds_id, similarity=1 - distance)
            for ds_id, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # BUG FIX: re-raise intentional HTTP errors (e.g. the 404 above)
        # unchanged; previously the blanket handler below converted them
        # into 500 responses.
        raise
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)