Spaces:
Runtime error
Runtime error
"""Serves the Lilac server.""" | |
import logging | |
import os | |
import subprocess | |
from typing import Any | |
from fastapi import APIRouter, FastAPI | |
from fastapi.responses import ORJSONResponse | |
from fastapi.routing import APIRoute | |
from fastapi.staticfiles import StaticFiles | |
from . import router_concept, router_data_loader, router_dataset, router_signal, router_tasks | |
from .config import CONFIG, data_path | |
from .router_utils import RouteErrorHandler | |
from .tasks import task_manager | |
# Absolute location of the built web client ('web/blueprint/build', resolved against the current
# working directory at import time). This module mounts it as static files at '/'.
DIST_PATH = os.path.abspath(os.path.join('web', 'blueprint', 'build'))

# Tag descriptions handed to FastAPI through `openapi_tags`. Each entry is a dict with a 'name'
# and a 'description' key, in the order listed here.
# NOTE(review): the 'tasks' tag used when including router_tasks has no entry here -- confirm
# whether that omission is intentional.
tags_metadata: list[dict[str, Any]] = [
    {'name': tag, 'description': blurb}
    for tag, blurb in (
        ('datasets', 'API for querying a dataset.'),
        ('concepts', 'API for managing concepts.'),
        ('data_loaders', 'API for loading data.'),
        ('signals', 'API for managing signals.'),
    )
]
def custom_generate_unique_id(route: APIRoute) -> str:
    """Return the unique id to use for an API endpoint.

    The id is simply the route's own ``name``; this function is passed to FastAPI as
    ``generate_unique_id_function``.

    Args:
      route: The route an id is being generated for.

    Returns:
      ``route.name``, unchanged.
    """
    unique_id = route.name
    return unique_id
# The application object: ORJSON is the default response class, unique ids come from
# custom_generate_unique_id, and tags_metadata is attached as the OpenAPI tag list.
app = FastAPI(
    openapi_tags=tags_metadata,
    generate_unique_id_function=custom_generate_unique_id,
    default_response_class=ORJSONResponse,
)

# Every feature router is attached to one v1 router whose route class is RouteErrorHandler.
v1_router = APIRouter(route_class=RouteErrorHandler)
v1_router.include_router(
    router_dataset.router, prefix='/datasets', tags=['datasets'])
v1_router.include_router(
    router_concept.router, prefix='/concepts', tags=['concepts'])
v1_router.include_router(
    router_data_loader.router, prefix='/data_loaders', tags=['data_loaders'])
v1_router.include_router(
    router_signal.router, prefix='/signals', tags=['signals'])
v1_router.include_router(
    router_tasks.router, prefix='/tasks', tags=['tasks'])

# Expose the whole v1 API under /api/v1.
app.include_router(v1_router, prefix='/api/v1')

# Serve static files in production mode. Keep this after the API routes: a mount at '/' matches
# every path. check_dir=False is passed so the build directory is not checked up front.
app.mount('/', StaticFiles(directory=DIST_PATH, html=True, check_dir=False))
# Startup diagnostics: show the resolved data path, the config, and the data directory contents.
# NOTE(review): CONFIG is printed verbatim -- confirm it cannot contain secrets.
print('data_path=', data_path(), CONFIG)
try:
    print('list:', os.listdir(data_path()))
except OSError as err:
    # This runs at import time. A data directory that is missing or unreadable is only a
    # diagnostic concern, so report it instead of preventing the server from starting.
    print('list: unavailable:', err)
def run(cmd: str) -> subprocess.CompletedProcess[bytes]:
    """Execute ``cmd`` through the system shell and return the completed process.

    Output is not captured, so the child writes to this process's stdout/stderr.

    Args:
      cmd: The command line, interpreted by the shell (``shell=True``). Do not build it from
        untrusted input.

    Returns:
      The ``subprocess.CompletedProcess`` describing the finished command.

    Raises:
      subprocess.CalledProcessError: If the command exits with a non-zero status
        (``check=True``).
    """
    completed = subprocess.run(cmd, check=True, shell=True)
    return completed
@app.on_event('shutdown')
async def shutdown_event() -> None:
    """Stop the task manager when the FastAPI application shuts down.

    The handler is registered on ``app`` for the 'shutdown' event. Without that registration
    the function is never called and the task manager is left running when the server exits.
    """
    await task_manager().stop()
class GetTasksFilter(logging.Filter):
    """Logging filter that drops records mentioning the /api/v1/tasks/ endpoint."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False for records whose message contains '/api/v1/tasks/', True otherwise.

        The check runs on the fully formatted message (``record.getMessage()``), so paths
        supplied through %-style arguments are matched as well.
        """
        message = record.getMessage()
        return '/api/v1/tasks/' not in message


# Attach the filter to the 'uvicorn.access' logger so matching records are left out.
logging.getLogger('uvicorn.access').addFilter(GetTasksFilter())