"""Callback handlers used in the app."""
from typing import Any, Dict, List

from langchain.callbacks.base import AsyncCallbackHandler

from schemas import ChatResponse


class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    def __init__(self, websocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward each newly generated token to the client as a "stream" message."""
        resp = ChatResponse(sender="bot", message=token, type="stream")
        await self.websocket.send_json(resp.dict())
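
# Note: `on_llm_new_token` only fires when the underlying LLM runs in
# streaming mode. A minimal sketch of how this handler might be attached
# (the model choice here is an assumption, not prescribed by this module):
#
#     from langchain.chat_models import ChatOpenAI
#
#     stream_handler = StreamingLLMCallbackHandler(websocket)
#     streaming_llm = ChatOpenAI(streaming=True, callbacks=[stream_handler])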


class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation."""

    def __init__(self, websocket):
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Notify the client that question synthesis has started."""
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        await self.websocket.send_json(resp.dict())
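
# Usage sketch: both handlers are intended to be passed to the chain that
# serves a websocket connection. The endpoint below is a hypothetical
# example; `make_chain` is an assumed helper, not defined in this module.
#
#     from fastapi import FastAPI, WebSocket
#
#     app = FastAPI()
#
#     @app.websocket("/chat")
#     async def chat_endpoint(websocket: WebSocket):
#         await websocket.accept()
#         question_handler = QuestionGenCallbackHandler(websocket)
#         stream_handler = StreamingLLMCallbackHandler(websocket)
#         qa_chain = make_chain(question_handler, stream_handler)
#         ...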