import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from load_data import get_embedding_function, get_save_path, refresh_data

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up in-memory response caching (used by the /similar endpoint)
cache.setup("mem://?check_interval=10&size=1000")

# Initialize persistent Chroma client; the collection itself is created
# during application startup (see `lifespan`).
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
collection = None
async_client = AsyncClient(follow_redirects=True)


class QueryResult(BaseModel):
    """A single similar-dataset hit."""

    dataset_id: str
    # Computed as 1 - distance, so higher means more similar.
    similarity: float


class QueryResponse(BaseModel):
    """Response body for /similar.

    Either ``results`` is populated with hits, or ``message`` explains why
    no recommendations are available.
    """

    # BUG FIX: the original declared only a required `results` field, yet the
    # endpoint also constructs QueryResponse(message=...) without results —
    # which raised a pydantic ValidationError at runtime. Both fields are now
    # optional with safe defaults (pydantic deep-copies mutable defaults per
    # instance, so the shared-mutable-default pitfall does not apply here).
    results: List[QueryResult] = []
    message: Optional[str] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: build the Chroma collection and refresh data.

    Raises:
        Any exception from collection creation or data refresh — re-raised
        so the application fails fast instead of serving with no index.
    """
    global collection
    logger.info("Starting up the application")
    try:
        # Create or get the collection
        embedding_function = get_embedding_function()
        collection = client.get_or_create_collection(
            name="dataset_cards", embedding_function=embedding_function
        )
        logger.info("Collection initialized successfully")
        # Refresh data
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}")
        raise
    yield  # Here the app is running and handling requests
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


async def try_get_card(hub_id: str) -> Optional[str]:
    """Fetch the dataset card (README) text for ``hub_id``.

    Returns the card text on HTTP 200, or None if the request fails,
    returns a non-200 status, or the card cannot be parsed.
    """
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code == 200:
            card = DatasetCard(response.text)
            return card.text
    except Exception as e:
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
    return None


def _has_embeddings(result) -> bool:
    """Return True if a ``collection.get`` result holds >= 1 embedding.

    Deliberately avoids ``if not result["embeddings"]``: recent chromadb
    versions return numpy arrays, whose truthiness raises ValueError.
    """
    embeddings = result.get("embeddings")
    return embeddings is not None and len(embeddings) > 0


@app.get("/similar", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets similar to ``dataset_id``.

    If the dataset is not yet indexed, its card is fetched and embedded on
    the fly; when no card is available, a message-only response is returned.

    Raises:
        HTTPException: 500 with the underlying error on unexpected failure.
    """
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not _has_embeddings(result):
            logger.info(f"Dataset not found: {dataset_id}")
            try:
                embedding_function = get_embedding_function()
                card = await try_get_card(dataset_id)
                if card is None:
                    return QueryResponse(
                        message="No dataset card available for recommendations."
                    )
                embeddings = embedding_function(card)
                # BUG FIX: upsert expects one embedding per id; wrap the flat
                # vector so the ids/embeddings lengths match explicitly.
                collection.upsert(ids=[dataset_id], embeddings=[embeddings[0]])
                logger.info(f"Dataset {dataset_id} added to collection")
                result = collection.get(ids=[dataset_id], include=["embeddings"])
            except Exception as e:
                logger.error(
                    f"Error adding dataset {dataset_id} to collection: {str(e)}"
                )
                return QueryResponse(
                    message="No dataset card available for recommendations."
                )
        embedding = result["embeddings"][0]
        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not query_result["ids"]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return QueryResponse(message="No similar datasets found.")
        # Prepare the response (avoid shadowing the builtin `id`)
        results = [
            QueryResult(dataset_id=matched_id, similarity=1 - distance)
            for matched_id, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)