# Spaces: Runtime error
"""Serves the Lilac server.""" | |
import logging | |
import os | |
import shutil | |
from typing import Any | |
from fastapi import APIRouter, FastAPI | |
from fastapi.responses import ORJSONResponse | |
from fastapi.routing import APIRoute | |
from fastapi.staticfiles import StaticFiles | |
from huggingface_hub import snapshot_download | |
from . import router_concept, router_data_loader, router_dataset, router_signal, router_tasks | |
from .config import CONFIG, data_path | |
from .router_utils import RouteErrorHandler | |
from .tasks import task_manager | |
from .utils import get_datasets_dir | |
# Absolute path of the built web UI. The relative path is resolved against the current working
# directory at import time, so the server must be started from the directory containing `web/`.
DIST_PATH = os.path.abspath(os.path.join('web', 'blueprint', 'build'))

# OpenAPI tag metadata (passed to FastAPI as `openapi_tags`). There is one entry for every tag
# used with `include_router`, including 'tasks', which previously had no description.
tags_metadata: list[dict[str, Any]] = [
    {
        'name': 'datasets',
        'description': 'API for querying a dataset.',
    },
    {
        'name': 'concepts',
        'description': 'API for managing concepts.',
    },
    {
        'name': 'data_loaders',
        'description': 'API for loading data.',
    },
    {
        'name': 'signals',
        'description': 'API for managing signals.',
    },
    {
        'name': 'tasks',
        'description': 'API for managing tasks.',
    },
]
def custom_generate_unique_id(route: APIRoute) -> str:
    """Use the route's ``name`` verbatim as its OpenAPI operation ID.

    Passed to FastAPI as ``generate_unique_id_function``. Whatever is returned here becomes the
    route's identifier in the generated OpenAPI schema.
    """
    operation_id: str = route.name
    return operation_id
app = FastAPI(
    default_response_class=ORJSONResponse,
    generate_unique_id_function=custom_generate_unique_id,
    openapi_tags=tags_metadata,
)

# Every JSON API router is mounted under /api/v1 and built with RouteErrorHandler as its route class.
v1_router = APIRouter(route_class=RouteErrorHandler)
v1_router.include_router(router_dataset.router, prefix='/datasets', tags=['datasets'])
v1_router.include_router(router_concept.router, prefix='/concepts', tags=['concepts'])
v1_router.include_router(router_data_loader.router, prefix='/data_loaders', tags=['data_loaders'])
v1_router.include_router(router_signal.router, prefix='/signals', tags=['signals'])
v1_router.include_router(router_tasks.router, prefix='/tasks', tags=['tasks'])
app.include_router(v1_router, prefix='/api/v1')

# Serve static files in production mode. The mount is added after the API routes, so those routes
# are matched first. check_dir=False skips the existence check on DIST_PATH when the app is built.
# DIST_PATH is already absolute, so it is passed directly (no os.path.join needed).
app.mount('/', StaticFiles(directory=DIST_PATH, html=True, check_dir=False))
@app.on_event('startup')
def startup() -> None:
    """Download dataset files from the HF space that is uploaded before building the image.

    Registered as a FastAPI startup handler. The space snapshot is downloaded to
    ``<data_path()>/.hf-spaces/<repo_id>``. It contains code as well as datasets, so only its
    datasets directory is moved into ``get_datasets_dir(data_path())``. If that directory already
    exists, the snapshot's contents are merged into it.
    """
    # NOTE(review): hard-coded staging space; consider making this configurable.
    repo_id = 'lilacai/nikhil_staging'

    # Download the huggingface space data. This includes code and datasets, so we move the
    # datasets alone to the data directory.
    tmpdir = os.path.join(data_path(), '.hf-spaces', repo_id)
    snapshot_download(
        repo_id=repo_id,
        repo_type='space',
        local_dir=tmpdir,
        local_dir_use_symlinks=False,
        token=CONFIG['HF_ACCESS_TOKEN'],
    )

    src_datasets_dir = get_datasets_dir(tmpdir)
    if not os.path.isdir(src_datasets_dir):
        # A space without datasets should not abort server startup.
        logging.getLogger(__name__).warning(
            'Space %s has no datasets directory at %s; nothing to move.', repo_id, src_datasets_dir)
        return

    dst_datasets_dir = get_datasets_dir(data_path())
    if os.path.exists(dst_datasets_dir):
        # shutil.move refuses to move a directory into one that already holds an entry with the
        # same name. This happens on every restart once the datasets dir exists, so merge the
        # trees instead; files present in both are overwritten by the snapshot's copy.
        shutil.copytree(src_datasets_dir, dst_datasets_dir, dirs_exist_ok=True)
        shutil.rmtree(src_datasets_dir)
    else:
        # Move straight to the target path. Joining '..' onto a not-yet-existing directory does
        # not resolve, so the parent is created explicitly instead.
        os.makedirs(os.path.dirname(os.path.abspath(dst_datasets_dir)), exist_ok=True)
        shutil.move(src_datasets_dir, dst_datasets_dir)
@app.on_event('shutdown')
async def shutdown_event() -> None:
    """Kill the task manager when FastAPI shuts down.

    Registered as a FastAPI shutdown handler. Without registration it is never invoked, so it
    awaits ``task_manager().stop()`` only once the decorator has hooked it into the app.
    """
    await task_manager().stop()
class GetTasksFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``/api/v1/tasks/``.

    It is attached to the ``uvicorn.access`` logger right after this class, so requests to the
    tasks API are left out of the access log.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Filters out /api/v1/tasks/ from the logs.

        Returns:
          False (drop the record) when the formatted message contains ``/api/v1/tasks/``;
          True (keep it) otherwise.
        """
        # Plain substring membership; `str.find` is only needed when the position matters.
        return '/api/v1/tasks/' not in record.getMessage()


logging.getLogger('uvicorn.access').addFilter(GetTasksFilter())