from app.schemas.message_schema import IChatResponse
from app.utils.adaptive_cards.cards import create_adaptive_card, create_image_card
from app.utils.chains import get_suggestions_questions
from langchain.callbacks.base import AsyncCallbackHandler
from app.utils.utils import generate_uuid
from fastapi import WebSocket
from uuid import UUID
from typing import Any
from langchain.schema.agent import AgentFinish
from langchain.schema.output import LLMResult
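# Sliding-window prefix used to detect the start of the final answer: the handlers
# keep the last len(answer_prefix_tokens) tokens and compare them to this list.
# The leading space in " Answer" reflects how many tokenizers emit the word as a
# single token with a preceding space.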
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", " Answer", ":"]
class CustomAsyncCallbackHandler(AsyncCallbackHandler):
def append_to_last_tokens(self, token: str) -> None:
self.last_tokens.append(token)
self.last_tokens_stripped.append(token.strip())
if len(self.last_tokens) > len(self.answer_prefix_tokens):
self.last_tokens.pop(0)
self.last_tokens_stripped.pop(0)
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
    def update_message_id(self, message_id: str | None = None) -> None:
        # Generate the UUID at call time: a generate_uuid() default would be
        # evaluated only once, when the method is defined.
        self.message_id = message_id or generate_uuid()
def __init__(
self,
websocket: WebSocket,
*,
        message_id: str | None = None,
answer_prefix_tokens: list[str] | None = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
"""Instantiate FinalStreamingStdOutCallbackHandler.
Args:
answer_prefix_tokens: Token sequence that prefixes the answer.
Default is ["Final", "Answer", ":"]
strip_tokens: Ignore white spaces and new lines when comparing
answer_prefix_tokens to last tokens? (to determine if answer has been
reached)
stream_prefix: Should answer prefix itself also be streamed?
"""
self.websocket: WebSocket = websocket
        self.message_id: str = message_id or generate_uuid()
self.text: str = ""
self.started: bool = False
self.loading_card = create_image_card(
"https://res.cloudinary.com/dnv0qwkrk/image/upload/v1691005682/Alita/Ellipsis-2.4s-81px_1_nja8hq.gif"
)
self.adaptive_card = self.loading_card
if answer_prefix_tokens is None:
self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
else:
self.answer_prefix_tokens = answer_prefix_tokens
if strip_tokens:
self.answer_prefix_tokens_stripped = [
token.strip() for token in self.answer_prefix_tokens
]
else:
self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
self.last_tokens = [""] * len(self.answer_prefix_tokens)
self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
self.strip_tokens = strip_tokens
self.stream_prefix = stream_prefix
self.answer_reached = False
async def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.loading_card.to_dict(),
type="start",
)
        await self.websocket.send_json(resp.model_dump())
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
# Remember the last n tokens, where n = len(answer_prefix_tokens)
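        # (Unlike the final-answer handler below, this handler forwards every token;
        # the prefix window is maintained but not consulted here.)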
self.append_to_last_tokens(token)
self.text += f"{token}"
self.adaptive_card = create_adaptive_card(self.text)
resp = IChatResponse(
# id=generate_uuid(),
id="",
message_id=self.message_id,
sender="bot",
message=self.adaptive_card.to_dict(),
type="stream",
)
await self.websocket.send_json(resp.model_dump())
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.adaptive_card.to_dict(),
type="end",
)
await self.websocket.send_json(resp.model_dump())
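# A minimal wiring sketch (assumption, not part of this module): attach
# CustomAsyncCallbackHandler to a streaming chat model inside a FastAPI WebSocket
# endpoint. The endpoint path, the ChatOpenAI model, and the surrounding app
# object are illustrative only.
#
#   @app.websocket("/chat")
#   async def chat(websocket: WebSocket):
#       await websocket.accept()
#       handler = CustomAsyncCallbackHandler(websocket)
#       llm = ChatOpenAI(streaming=True, callbacks=[handler])
#       while True:
#           question = await websocket.receive_text()
#           handler.update_message_id(generate_uuid())
#           await llm.apredict(question)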
class CustomFinalStreamingStdOutCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming in agents.
Only works with agents using LLMs that support streaming.
Only the final output of the agent will be streamed.
"""
def append_to_last_tokens(self, token: str) -> None:
self.last_tokens.append(token)
self.last_tokens_stripped.append(token.strip())
if len(self.last_tokens) > len(self.answer_prefix_tokens):
self.last_tokens.pop(0)
self.last_tokens_stripped.pop(0)
def check_if_answer_reached(self) -> bool:
if self.strip_tokens:
return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
else:
return self.last_tokens == self.answer_prefix_tokens
    def update_message_id(self, message_id: str | None = None) -> None:
        # Generate the UUID at call time: a generate_uuid() default would be
        # evaluated only once, when the method is defined.
        self.message_id = message_id or generate_uuid()
def __init__(
self,
websocket: WebSocket,
*,
        message_id: str | None = None,
answer_prefix_tokens: list[str] | None = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
"""Instantiate FinalStreamingStdOutCallbackHandler.
Args:
answer_prefix_tokens: Token sequence that prefixes the answer.
Default is ["Final", "Answer", ":"]
strip_tokens: Ignore white spaces and new lines when comparing
answer_prefix_tokens to last tokens? (to determine if answer has been
reached)
stream_prefix: Should answer prefix itself also be streamed?
"""
self.websocket: WebSocket = websocket
        self.message_id: str = message_id or generate_uuid()
self.text: str = ""
self.started: bool = False
self.loading_card = create_image_card(
"https://res.cloudinary.com/dnv0qwkrk/image/upload/v1691005682/Alita/Ellipsis-2.4s-81px_1_nja8hq.gif"
)
self.adaptive_card = self.loading_card
if answer_prefix_tokens is None:
self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
else:
self.answer_prefix_tokens = answer_prefix_tokens
if strip_tokens:
self.answer_prefix_tokens_stripped = [
token.strip() for token in self.answer_prefix_tokens
]
else:
self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
self.last_tokens = [""] * len(self.answer_prefix_tokens)
self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
self.strip_tokens = strip_tokens
self.stream_prefix = stream_prefix
self.answer_reached = False
async def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
        if not self.started:
self.started = True
resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.loading_card.to_dict(),
type="start",
)
await self.websocket.send_json(resp.model_dump())
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
message: str = (
self.text
if self.text != ""
# else "😕 Lo siento no he podido hallar lo que buscabas"
else finish.return_values["output"]
)
self.adaptive_card = create_adaptive_card(message)
resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.adaptive_card.to_dict(),
type="stream",
)
        await self.websocket.send_json(resp.model_dump())
suggested_responses = await get_suggestions_questions(message)
        if suggested_responses:
self.adaptive_card = create_adaptive_card(
answer=message,
)
medium_resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.adaptive_card.to_dict(),
type="end",
suggested_responses=suggested_responses,
)
await self.websocket.send_json(medium_resp.model_dump())
# Reset values
self.text = ""
self.answer_reached = False
self.started = False
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)
# Check if the last n tokens match the answer_prefix_tokens list ...
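        # e.g. with the default prefix, answer_reached flips to True once the window
        # holds ["Final", " Answer", ":"], and every later token is streamed below.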
if self.check_if_answer_reached():
self.answer_reached = True
return
# ... if yes, then print tokens from now on
if self.answer_reached:
self.text += f"{token}"
self.adaptive_card = create_adaptive_card(self.text)
resp = IChatResponse(
id="",
message_id=self.message_id,
sender="bot",
message=self.adaptive_card.to_dict(),
type="stream",
)
await self.websocket.send_json(resp.model_dump())
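# A minimal wiring sketch (assumption): register the final-answer handler on an
# agent run so only the text after the "Final Answer:" prefix reaches the client.
# The `agent_executor` object and the exact invocation API depend on the LangChain
# version in use and are illustrative only.
#
#   handler = CustomFinalStreamingStdOutCallbackHandler(websocket)
#   result = await agent_executor.arun(user_question, callbacks=[handler])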