nikhil_no_persistent / lilac /router_signal.py
nsthorat-lilac's picture
Duplicate from lilacai/nikhil_staging
bfc0ec6
raw
history blame
3.2 kB
"""Router for the signal registry."""
import math
from typing import Annotated, Any, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, validator
from .auth import UserInfo, get_session_user
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import Field, SignalInputType
from .signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal
from .signals.concept_scorer import ConceptSignal
# All signal endpoints share one router; RouteErrorHandler is used as its route class.
router = APIRouter(route_class=RouteErrorHandler)

# Display order for the /embeddings endpoint. Embeddings whose name is not listed here are
# placed after every listed one.
EMBEDDING_SORT_PRIORITIES: list[str] = ['gte-small', 'gte-base', 'openai', 'sbert']
class SignalInfo(BaseModel):
    """Describes one registered signal class for API listings."""

    # Registry name of the signal class.
    name: str
    # The signal class's declared input type.
    input_type: SignalInputType
    # JSON schema of the signal class's parameters, taken from its `schema()` method.
    json_schema: dict[str, Any]
@router.get('/', response_model_exclude_none=True)
def get_signals() -> list[SignalInfo]:
    """Return a SignalInfo for every registered signal that is not a text embedding.

    Registry iteration order is preserved.
    """
    infos: list[SignalInfo] = []
    for signal_cls in SIGNAL_REGISTRY.values():
        if issubclass(signal_cls, TextEmbeddingSignal):
            # Embedding signals are served by the /embeddings endpoint instead.
            continue
        infos.append(
            SignalInfo(
                name=signal_cls.name,
                input_type=signal_cls.input_type,
                json_schema=signal_cls.schema(),
            ))
    return infos
@router.get('/embeddings', response_model_exclude_none=True)
def get_embeddings() -> list[SignalInfo]:
    """Return a SignalInfo for every registered text-embedding signal, in priority order.

    Embeddings named in EMBEDDING_SORT_PRIORITIES come first, in that list's order. All
    others follow; because the sort is stable, they keep their registry order.
    """

    def _priority(info: SignalInfo) -> float:
        # Unlisted names rank after every listed one.
        if info.name not in EMBEDDING_SORT_PRIORITIES:
            return math.inf
        return EMBEDDING_SORT_PRIORITIES.index(info.name)

    infos = [
        SignalInfo(
            name=signal_cls.name,
            input_type=signal_cls.input_type,
            json_schema=signal_cls.schema(),
        )
        for signal_cls in SIGNAL_REGISTRY.values()
        if issubclass(signal_cls, TextEmbeddingSignal)
    ]
    infos.sort(key=_priority)
    return infos
class SignalComputeOptions(BaseModel):
    """Request body for the standalone /compute endpoint."""

    # The signal to run; converted by `parse_signal` below before field validation.
    signal: Signal
    # The raw text inputs the signal is computed over.
    inputs: list[str]

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Convert the incoming `signal` value into its specific Signal subclass instance.

        Runs with `pre=True`, i.e. before pydantic's own validation of the field. The
        conversion itself is delegated entirely to `resolve_signal`.
        """
        resolved = resolve_signal(signal)
        return resolved
class SignalComputeResponse(BaseModel):
    """Response body for the standalone /compute endpoint."""

    # Computed values collected from the signal; any element may be None.
    items: list[Optional[Any]]
@router.post('/compute', response_model_exclude_none=True)
def compute(
    options: SignalComputeOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)],
) -> SignalComputeResponse:
    """Run the requested signal over `options.inputs` and return the collected results.

    Concept signals go through `server_compute_concept`, which also receives the session
    user. Any other signal is set up and then computed directly here.
    """
    signal = options.signal
    if not isinstance(signal, ConceptSignal):
        signal.setup()
        # `compute` may yield lazily; materialize it so the response holds a concrete list.
        return SignalComputeResponse(items=list(signal.compute(options.inputs)))

    concept_items = server_compute_concept(signal, options.inputs, user)
    return SignalComputeResponse(items=concept_items)
class SignalSchemaOptions(BaseModel):
    """Request body for the /schema endpoint."""

    # The signal whose output fields are requested; converted by `parse_signal` below.
    signal: Signal

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Convert the incoming `signal` value into its specific Signal subclass instance.

        Runs with `pre=True`, i.e. before pydantic's own validation of the field. The
        conversion itself is delegated entirely to `resolve_signal`.
        """
        resolved = resolve_signal(signal)
        return resolved
class SignalSchemaResponse(BaseModel):
    """Response body for the /schema endpoint."""

    # The Field returned by the signal's `fields()` method.
    fields: Field
@router.post('/schema', response_model_exclude_none=True)
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
    """Return the output field description reported by the requested signal's `fields()`."""
    return SignalSchemaResponse(fields=options.signal.fields())