"""FastAPI service that returns Hugging Face datasets similar to a given one.

Dataset-card embeddings are kept in a persistent Chroma collection. A dataset
that is not yet in the collection is embedded on demand from its README card
fetched from the Hugging Face Hub.
"""

import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from load_data import get_embedding_function, get_save_path, refresh_data

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching (in-memory backend; the query string configures its
# check interval and size).
cache.setup("mem://?check_interval=10&size=1000")

# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
# Assigned in `lifespan` during application startup; None until then.
collection = None

# Shared HTTP client used to fetch dataset cards; closed on shutdown.
async_client = AsyncClient(
    follow_redirects=True,
)


class QueryResult(BaseModel):
    """One similar dataset and its similarity score (``1 - distance``)."""

    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    """Response body of the ``/similar`` endpoint."""

    results: List[QueryResult]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the Chroma collection and refresh data; clean up on shutdown.

    Startup creates (or opens) the ``dataset_cards`` collection with the
    project embedding function, then calls ``refresh_data()``.

    Raises:
        Exception: any startup error is logged and re-raised, aborting startup.
    """
    global collection
    logger.info("Starting up the application")
    try:
        # Create or get the collection
        embedding_function = get_embedding_function()
        collection = client.get_or_create_collection(
            name="dataset_cards", embedding_function=embedding_function
        )
        logger.info("Collection initialized successfully")

        # Refresh data
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.exception(f"Error during startup: {str(e)}")
        raise

    try:
        yield  # The app is running and handling requests
    finally:
        # Shutdown: release the shared HTTP client's connections even if the
        # app exits with an error.
        logger.info("Shutting down the application")
        await async_client.aclose()


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


async def try_get_card(hub_id: str) -> Optional[str]:
    """Fetch the README dataset card for ``hub_id`` from the Hugging Face Hub.

    Args:
        hub_id: dataset repository id (the path segment after ``/datasets/``).

    Returns:
        The card body (``DatasetCard.text``), or ``None`` when the response is
        not HTTP 200 or any error occurs while fetching/parsing.
    """
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code != 200:
            logger.info(
                f"Card for hub_id {hub_id} unavailable "
                f"(HTTP {response.status_code})"
            )
            return None
        return DatasetCard(response.text).text
    except Exception as e:
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
        return None


def _first_embedding(result):
    """Return the first embedding in a Chroma ``get`` result, or None if absent."""
    embeddings = result.get("embeddings")
    # Depending on the chromadb version this may be a list or a NumPy array;
    # a NumPy array has no unambiguous truth value, so test length explicitly.
    if embeddings is None or len(embeddings) == 0:
        return None
    return embeddings[0]


async def _embed_and_store(dataset_id: str):
    """Embed ``dataset_id``'s card, upsert it into the collection, return it.

    Raises:
        HTTPException: 404 if the card cannot be fetched, embedded or stored.
    """
    try:
        card = await try_get_card(dataset_id)
        if card is None:
            raise ValueError("dataset card could not be fetched")
        embedding_function = get_embedding_function()
        # Embedding functions take a *list* of documents; a bare string could
        # be treated as a sequence of single characters.
        embedding = embedding_function([card])[0]
        collection.upsert(ids=[dataset_id], embeddings=[embedding])
    except Exception as e:
        logger.error(f"Error adding dataset {dataset_id} to collection: {str(e)}")
        raise HTTPException(status_code=404, detail="Dataset not found") from e
    logger.info(f"Dataset {dataset_id} added to collection")
    return embedding


@app.get("/similar", response_model=Optional[QueryResponse])
@cache(ttl="1h")
async def api_query_dataset(
    dataset_id: str, n: int = Query(default=10, ge=1, le=100)
):
    """Return up to ``n`` datasets whose card embeddings are nearest ``dataset_id``'s.

    If ``dataset_id`` is not in the collection, its card is fetched, embedded
    and upserted first. Responses are cached for one hour.

    Args:
        dataset_id: Hub dataset id to find neighbours for.
        n: number of results to request from the collection (1-100).

    Returns:
        A ``QueryResponse`` with ``similarity = 1 - distance`` per hit, or
        ``None`` when the query returns no ids.

    Raises:
        HTTPException: 404 when an unknown dataset cannot be embedded;
            500 for any other error.
    """
    # NOTE(review): collection.get/query and the embedding function are
    # synchronous calls made from this async handler, so they block the event
    # loop while they run — consider offloading to a thread pool.
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        embedding = _first_embedding(result)
        if embedding is None:
            logger.info(f"Dataset not found: {dataset_id}")
            embedding = await _embed_and_store(dataset_id)

        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # `ids` holds one inner list per query embedding, so the outer list is
        # never empty; the inner list is what tells us whether there were hits.
        hit_ids = query_result["ids"][0] if query_result["ids"] else []
        if not hit_ids:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None

        # Prepare the response
        results = [
            QueryResult(dataset_id=hit_id, similarity=1 - distance)
            for hit_id, distance in zip(hit_ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through unchanged
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)