Spaces:
Runtime error
Runtime error
import json | |
import logging | |
import sqlite3 | |
from contextlib import asynccontextmanager | |
from typing import List | |
import numpy as np | |
from fastapi import FastAPI, HTTPException, Query | |
from pandas import Timestamp | |
from pydantic import BaseModel | |
from starlette.responses import RedirectResponse | |
from data_loader import refresh_data | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_db_connection(db_path: str = "datasets.db") -> sqlite3.Connection:
    """Open a SQLite connection whose rows can be accessed by column name.

    Args:
        db_path: Database to open. Defaults to ``"datasets.db"``, the path
            this function always used before the parameter existed; pass
            ``":memory:"`` for a throwaway in-memory database.

    Returns:
        A new connection with ``row_factory`` set to ``sqlite3.Row``. The
        caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    # sqlite3.Row lets callers use row["column"] and dict(row).
    conn.row_factory = sqlite3.Row
    return conn
def setup_database():
    """Create the ``datasets`` table and its ``column_names`` index if absent.

    Both statements use ``IF NOT EXISTS``, so calling this against an
    already-initialised database is a no-op. The connection is closed even
    if one of the statements raises.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(
            """CREATE TABLE IF NOT EXISTS datasets
            (hub_id TEXT PRIMARY KEY,
            likes INTEGER,
            downloads INTEGER,
            tags TEXT,
            created_at INTEGER,
            last_modified INTEGER,
            license TEXT,
            language TEXT,
            config_name TEXT,
            column_names TEXT,
            features TEXT)"""
        )
        # NOTE(review): this indexes the raw JSON text of column_names; the
        # json_each()-based lookups in this file likely cannot use it --
        # confirm with EXPLAIN QUERY PLAN before relying on it.
        c.execute(
            "CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)"
        )
        conn.commit()
    finally:
        # Previously a failing statement skipped close() and leaked the handle.
        conn.close()
def serialize_numpy(obj):
    """``json.dumps`` ``default=`` hook for NumPy and pandas values.

    Converts, in this order of precedence:
      * ``np.ndarray``  -> nested Python list
      * ``np.integer``  -> ``int``
      * ``np.floating`` -> ``float``
      * ``Timestamp``   -> whole seconds since the epoch (``int``)

    Raises:
        TypeError: for any other type, after logging the same message.
    """
    converters = (
        (np.ndarray, lambda value: value.tolist()),
        (np.integer, int),
        (np.floating, float),
        (Timestamp, lambda value: int(value.timestamp())),
    )
    for handled_type, convert in converters:
        if isinstance(obj, handled_type):
            return convert(obj)

    message = f"Object of type {type(obj)} is not JSON serializable"
    logger.error(message)
    raise TypeError(message)
def insert_data(conn, data):
    """Insert or replace one dataset record in the ``datasets`` table.

    Args:
        conn: Open SQLite connection; the row is committed before returning.
        data: Mapping describing one dataset. ``hub_id`` is required; every
            other key is optional and falls back to ``0``, ``""`` or ``[]``
            as shown below.

    ``created_at`` / ``last_modified`` are stored as whole epoch seconds when
    they are pandas ``Timestamp`` objects and stored unchanged otherwise.
    ``tags``, ``license``, ``language``, ``column_names`` and ``features``
    are stored as JSON text.
    """
    cursor = conn.cursor()

    def epoch_seconds(key):
        # Only pandas Timestamps are converted; other values pass through.
        value = data.get(key, 0)
        return int(value.timestamp()) if isinstance(value, Timestamp) else value

    def as_json(key):
        # Missing keys are stored as the JSON text "[]".
        return json.dumps(data.get(key, []), default=serialize_numpy)

    created_at = epoch_seconds("created_at")
    last_modified = epoch_seconds("last_modified")

    statement = """
        INSERT OR REPLACE INTO datasets
        (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    row = (
        data["hub_id"],
        data.get("likes", 0),
        data.get("downloads", 0),
        as_json("tags"),
        created_at,
        last_modified,
        as_json("license"),
        as_json("language"),
        data.get("config_name", ""),
        as_json("column_names"),
        as_json("features"),
    )
    cursor.execute(statement, row)
    conn.commit()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the SQLite database before the application starts serving.

    Startup (everything before ``yield``): ensure the schema exists, call
    ``refresh_data()`` and upsert every record it returns. Nothing runs on
    shutdown.

    FastAPI expects ``lifespan`` to be an async context manager, hence the
    ``@asynccontextmanager`` wrapper around this async generator.

    Args:
        app: The FastAPI application being started (unused here).
    """
    # Startup: Load data into the database
    setup_database()
    logger.info("Creating database connection")
    conn = get_db_connection()
    try:
        logger.info("Refreshing data")
        datasets = refresh_data()
        for data in datasets:
            insert_data(conn, data)
    finally:
        # Close the connection even if the refresh or an insert fails.
        conn.close()
    logger.info("Data refreshed")
    yield
    # Shutdown: You can add any cleanup operations here if needed
    # For example, closing database connections, clearing caches, etc.
# ASGI application; `lifespan` loads the dataset table before requests are served.
app = FastAPI(lifespan=lifespan)
# NOTE(review): no route decorator was present, so this handler was never
# reachable. "/" is the assumed path given the function's name -- confirm.
@app.get("/")
def root():
    """Redirect to the interactive API documentation at ``/docs``."""
    return RedirectResponse(url="/docs")
class SearchResponse(BaseModel):
    """One page of dataset search results, as returned by ``search_datasets``."""

    # Number of datasets matching the query across all pages.
    total: int
    # Page number that was requested (1-based).
    page: int
    # Maximum number of results per page that was requested.
    page_size: int
    # One dict per matching ``datasets`` row, with JSON columns decoded.
    results: List[dict]
# Columns of ``datasets`` that hold JSON text and are decoded before returning.
_JSON_RESULT_FIELDS = ("tags", "license", "language", "column_names", "features")


def _column_filter(columns, match_all):
    """Return ``(where_sql, params)`` selecting datasets by column name.

    Only ``?`` placeholders are interpolated into the SQL text; the column
    names themselves are always passed as bound parameters.
    """
    # De-duplicate (order-preserving): with match_all, a repeated name would
    # otherwise demand more matches than distinct names can ever produce.
    wanted = list(dict.fromkeys(columns))
    placeholders = ",".join("?" * len(wanted))
    if match_all:
        # COUNT(DISTINCT ...) so a name stored twice in column_names cannot
        # stand in for a different name that is missing.
        where_sql = (
            "(SELECT COUNT(DISTINCT value) FROM json_each(column_names) "
            f"WHERE value IN ({placeholders})) = ?"
        )
        return where_sql, [*wanted, len(wanted)]
    where_sql = (
        "EXISTS (SELECT 1 FROM json_each(column_names) "
        f"WHERE value IN ({placeholders}))"
    )
    return where_sql, wanted


# NOTE(review): no route decorator was present, so this handler was never
# reachable. The "/search" path is an assumption -- confirm against clients.
@app.get("/search", response_model=SearchResponse)
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    """Search datasets by the column names they contain.

    Args:
        columns: Column names to look for (at least one is required).
        match_all: When true, a dataset must contain every requested name;
            otherwise any single match is enough.
        page: 1-based page number.
        page_size: Results per page (1-1000).

    Returns:
        A ``SearchResponse`` with the total match count and the requested
        page of rows, ordered by ``hub_id`` so paging is deterministic.

    Raises:
        HTTPException: status 500 if SQLite reports an error.
    """
    offset = (page - 1) * page_size
    where_sql, params = _column_filter(columns, match_all)
    conn = get_db_connection()
    try:
        c = conn.cursor()

        c.execute(
            f"SELECT COUNT(*) AS total FROM datasets WHERE {where_sql}",
            params,
        )
        total = c.fetchone()["total"]

        # Without ORDER BY, LIMIT/OFFSET pages have no guaranteed order and
        # may overlap or skip rows between requests.
        c.execute(
            f"SELECT * FROM datasets WHERE {where_sql} "
            "ORDER BY hub_id LIMIT ? OFFSET ?",
            [*params, page_size, offset],
        )
        results = [dict(row) for row in c.fetchall()]
        for result in results:
            for field in _JSON_RESULT_FIELDS:
                result[field] = json.loads(result[field])

        return SearchResponse(
            total=total, page=page, page_size=page_size, results=results
        )
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)