# callback.py — author: nickmuchi; commit 7fa9a42 ("Create callback.py"); 1.06 kB.
"""Callback handlers used in the app."""
from typing import Any, Dict, List
from langchain.callbacks.base import AsyncCallbackHandler
from schemas import ChatResponse
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses.

    Forwards every new token produced by the LLM to a websocket client as a
    ``ChatResponse`` with ``sender="bot"`` and ``type="stream"``.
    """

    def __init__(self, websocket: Any) -> None:
        """Store the websocket that tokens will be pushed to.

        Args:
            websocket: Connection object exposing an awaitable ``send_json``
                method (presumably a FastAPI/Starlette ``WebSocket`` —
                confirm at the call site).
        """
        # Chain to the base handler so any initialization it defines runs.
        super().__init__()
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Send a single newly generated token to the client.

        Args:
            token: The token text just emitted by the LLM.
            **kwargs: Extra callback arguments supplied by LangChain; unused.
        """
        resp = ChatResponse(sender="bot", message=token, type="stream")
        # NOTE(review): ``.dict()`` is deprecated under pydantic v2 in favour
        # of ``.model_dump()``; confirm which pydantic version ``schemas`` uses.
        await self.websocket.send_json(resp.dict())
class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation.

    When the LLM starts, notifies the websocket client with an informational
    ``ChatResponse`` (``sender="bot"``, ``type="info"``).
    """

    def __init__(self, websocket: Any) -> None:
        """Store the websocket that status messages will be pushed to.

        Args:
            websocket: Connection object exposing an awaitable ``send_json``
                method (presumably a FastAPI/Starlette ``WebSocket`` —
                confirm at the call site).
        """
        # Chain to the base handler so any initialization it defines runs.
        super().__init__()
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running.

        Sends a fixed "Synthesizing question..." info message to the client.

        Args:
            serialized: Serialized LLM description from LangChain; unused.
            prompts: Prompts the LLM is about to process; unused.
            **kwargs: Extra callback arguments supplied by LangChain; unused.
        """
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        # NOTE(review): ``.dict()`` is deprecated under pydantic v2 in favour
        # of ``.model_dump()``; confirm which pydantic version ``schemas`` uses.
        await self.websocket.send_json(resp.dict())