import json
import logging
import sqlite3
from contextlib import asynccontextmanager
from typing import List

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from pandas import Timestamp
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from data_loader import refresh_data

logger = logging.getLogger(__name__)


def get_db_connection():
    conn = sqlite3.connect("datasets.db")
    conn.row_factory = sqlite3.Row
    return conn


def setup_database():
    conn = get_db_connection()
    c = conn.cursor()
    c.execute(
        """CREATE TABLE IF NOT EXISTS datasets
           (hub_id TEXT PRIMARY KEY,
            likes INTEGER,
            downloads INTEGER,
            tags TEXT,
            created_at INTEGER,
            last_modified INTEGER,
            license TEXT,
            language TEXT,
            config_name TEXT,
            column_names TEXT,
            features TEXT)"""
    )
    c.execute("CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)")
    conn.commit()
    conn.close()


def serialize_numpy(obj):
    """JSON fallback serializer for numpy scalars/arrays and pandas Timestamps."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, Timestamp):
        return int(obj.timestamp())
    logger.error(f"Object of type {type(obj)} is not JSON serializable")
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


def insert_data(conn, data):
    c = conn.cursor()
    # Normalize pandas Timestamps to Unix epoch seconds before storing.
    created_at = data.get("created_at", 0)
    if isinstance(created_at, Timestamp):
        created_at = int(created_at.timestamp())
    last_modified = data.get("last_modified", 0)
    if isinstance(last_modified, Timestamp):
        last_modified = int(last_modified.timestamp())
    c.execute(
        """
        INSERT OR REPLACE INTO datasets
        (hub_id, likes, downloads, tags, created_at, last_modified,
         license, language, config_name, column_names, features)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            data["hub_id"],
            data.get("likes", 0),
            data.get("downloads", 0),
            json.dumps(data.get("tags", []), default=serialize_numpy),
            created_at,
            last_modified,
            json.dumps(data.get("license", []), default=serialize_numpy),
            json.dumps(data.get("language", []), default=serialize_numpy),
            data.get("config_name", ""),
            json.dumps(data.get("column_names", []), default=serialize_numpy),
            json.dumps(data.get("features", []), default=serialize_numpy),
        ),
    )
    conn.commit()
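# A sketch of the record shape insert_data() expects; the real records come
# from refresh_data() in data_loader, and the field values below are
# hypothetical:
#
#   insert_data(conn, {
#       "hub_id": "user/dataset-name",            # required
#       "likes": 10,
#       "downloads": 100,
#       "tags": ["language:en"],
#       "created_at": Timestamp("2024-01-01"),    # converted to epoch seconds
#       "column_names": ["question", "answer"],
#       "features": [{"name": "question", "dtype": "string"}],
#   })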
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load data into the database
    setup_database()
    logger.info("Creating database connection")
    conn = get_db_connection()
    logger.info("Refreshing data")
    datasets = refresh_data()
    for data in datasets:
        insert_data(conn, data)
    conn.close()
    logger.info("Data refreshed")
    yield
    # Shutdown: add any cleanup operations here if needed,
    # e.g. closing database connections or clearing caches.


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    return RedirectResponse(url="/docs")


class SearchResponse(BaseModel):
    total: int
    page: int
    page_size: int
    results: List[dict]


@app.get("/search", response_model=SearchResponse)
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    offset = (page - 1) * page_size
    conn = get_db_connection()
    c = conn.cursor()

    try:
        # Count matching rows. With match_all, a dataset matches only if every
        # requested column appears in its column_names; otherwise a single
        # matching column suffices.
        if match_all:
            query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE (SELECT COUNT(*) FROM json_each(column_names)
                       WHERE value IN ({})) = ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, len(columns)))
        else:
            query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE EXISTS (
                    SELECT 1 FROM json_each(column_names)
                    WHERE value IN ({})
                )
            """.format(",".join("?" * len(columns)))
            c.execute(query, columns)
        total = c.fetchone()["total"]

        # Fetch the requested page using the same filter.
        if match_all:
            query = """
                SELECT * FROM datasets
                WHERE (SELECT COUNT(*) FROM json_each(column_names)
                       WHERE value IN ({})) = ?
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, len(columns), page_size, offset))
        else:
            query = """
                SELECT * FROM datasets
                WHERE EXISTS (
                    SELECT 1 FROM json_each(column_names)
                    WHERE value IN ({})
                )
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, page_size, offset))

        results = [dict(row) for row in c.fetchall()]
        # List-valued columns are stored as JSON strings; decode them for the response.
        for result in results:
            result["tags"] = json.loads(result["tags"])
            result["license"] = json.loads(result["license"])
            result["language"] = json.loads(result["language"])
            result["column_names"] = json.loads(result["column_names"])
            result["features"] = json.loads(result["features"])

        return SearchResponse(
            total=total, page=page, page_size=page_size, results=results
        )
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
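# Example request (a sketch, assuming the server is running on localhost:8000):
#
#   curl "http://localhost:8000/search?columns=question&columns=answer&match_all=true&page=1&page_size=10"
#
# The response is a SearchResponse JSON object with `total`, `page`,
# `page_size`, and `results`; the JSON-encoded columns in each result are
# decoded back into lists before being returned.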