import json import logging import sqlite3 from contextlib import asynccontextmanager from typing import List import numpy as np from cashews import NOT_NONE, cache from fastapi import FastAPI, HTTPException, Query from pandas import Timestamp from pydantic import BaseModel from starlette.responses import RedirectResponse from data_loader import refresh_data cache.setup("mem://?check_interval=10&size=10000") logger = logging.getLogger(__name__) def get_db_connection(): conn = sqlite3.connect("datasets.db") conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode = WAL") conn.execute("PRAGMA synchronous = NORMAL") return conn def setup_database(): conn = get_db_connection() c = conn.cursor() c.execute( """CREATE TABLE IF NOT EXISTS datasets (hub_id TEXT PRIMARY KEY, likes INTEGER, downloads INTEGER, tags JSON, created_at INTEGER, last_modified INTEGER, license JSON, language JSON, config_name TEXT, column_names JSON, features JSON)""" ) c.execute( """ CREATE INDEX IF NOT EXISTS idx_column_names ON datasets(column_names) """ ) c.execute( """ CREATE INDEX IF NOT EXISTS idx_downloads_likes ON datasets(downloads DESC, likes DESC) """ ) conn.commit() c.execute("ANALYZE") conn.close() def serialize_numpy(obj): if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): return float(obj) if isinstance(obj, Timestamp): return int(obj.timestamp()) logger.error(f"Object of type {type(obj)} is not JSON serializable") raise TypeError(f"Object of type {type(obj)} is not JSON serializable") @asynccontextmanager async def lifespan(app: FastAPI): setup_database() logger.info("Creating database connection") conn = get_db_connection() logger.info("Refreshing data") datasets = refresh_data() c = conn.cursor() c.executemany( """ INSERT OR REPLACE INTO datasets (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features) VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?)) """, [ ( data["hub_id"], data.get("likes", 0), data.get("downloads", 0), json.dumps(data.get("tags", []), default=serialize_numpy), int(data["created_at"].timestamp()) if isinstance(data["created_at"], Timestamp) else data.get("created_at", 0), int(data["last_modified"].timestamp()) if isinstance(data["last_modified"], Timestamp) else data.get("last_modified", 0), json.dumps(data.get("license", []), default=serialize_numpy), json.dumps(data.get("language", []), default=serialize_numpy), data.get("config_name", ""), json.dumps(data.get("column_names", []), default=serialize_numpy), json.dumps(data.get("features", []), default=serialize_numpy), ) for data in datasets ], ) conn.commit() conn.close() logger.info("Data refreshed") yield app = FastAPI(lifespan=lifespan) @app.get("/", include_in_schema=False) def root(): return RedirectResponse(url="/docs") class SearchResponse(BaseModel): total: int page: int page_size: int results: List[dict] @cache(ttl="1h", condition=NOT_NONE) @app.get("/search", response_model=SearchResponse) async def search_datasets( columns: List[str] = Query(...), match_all: bool = Query(False), page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=1000), ): offset = (page - 1) * page_size conn = get_db_connection() c = conn.cursor() try: if match_all: query = """ SELECT *, ( SELECT COUNT(*) FROM json_each(column_names) WHERE json_each.value IN ({}) ) as match_count FROM datasets WHERE match_count = ? ORDER BY downloads DESC, likes DESC LIMIT ? OFFSET ? """.format(",".join("?" * len(columns))) c.execute(query, (*columns, len(columns), page_size, offset)) else: query = """ SELECT * FROM datasets WHERE EXISTS ( SELECT 1 FROM json_each(column_names) WHERE json_each.value IN ({}) ) ORDER BY downloads DESC, likes DESC LIMIT ? OFFSET ? """.format(",".join("?" * len(columns))) c.execute(query, (*columns, page_size, offset)) results = [dict(row) for row in c.fetchall()] # Get total count if match_all: count_query = """ SELECT COUNT(*) as total FROM datasets WHERE ( SELECT COUNT(*) FROM json_each(column_names) WHERE json_each.value IN ({}) ) = ? """.format(",".join("?" * len(columns))) c.execute(count_query, (*columns, len(columns))) else: count_query = """ SELECT COUNT(*) as total FROM datasets WHERE EXISTS ( SELECT 1 FROM json_each(column_names) WHERE json_each.value IN ({}) ) """.format(",".join("?" * len(columns))) c.execute(count_query, columns) total = c.fetchone()["total"] for result in results: result["tags"] = json.loads(result["tags"]) result["license"] = json.loads(result["license"]) result["language"] = json.loads(result["language"]) result["column_names"] = json.loads(result["column_names"]) result["features"] = json.loads(result["features"]) return SearchResponse( total=total, page=page, page_size=page_size, results=results ) except sqlite3.Error as e: logger.error(f"Database error: {str(e)}") raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e finally: conn.close() if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)