from typing import Any
from uuid import UUID

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema.agent import AgentFinish
from langchain.schema.output import LLMResult

from app.schemas.message_schema import IChatResponse
from app.utils.adaptive_cards.cards import create_adaptive_card, create_image_card
from app.utils.chains import get_suggestions_questions
from app.utils.utils import generate_uuid

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", " Answer", ":"]
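# Note: the leading space in " Answer" matches how LLM tokenizers typically
# emit mid-sentence words, with the preceding space attached to the token.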


class CustomAsyncCallbackHandler(AsyncCallbackHandler):
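    """Stream every LLM token to the connected client over a WebSocket.

    Each new token re-renders the accumulated text as an adaptive card and
    pushes it as a "stream" message; a loading card is sent on "start" and
    the final card on "end". Unlike CustomFinalStreamingStdOutCallbackHandler
    below, all tokens are forwarded, not just those after the answer prefix.
    """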

    def append_to_last_tokens(self, token: str) -> None:
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

    def check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        return self.last_tokens == self.answer_prefix_tokens

    def update_message_id(self, message_id: str | None = None) -> None:
        # Generate the UUID at call time; `generate_uuid()` as a parameter
        # default would be evaluated once at definition time and then reused
        # for every call.
        self.message_id = message_id or generate_uuid()

    def __init__(
        self,
        websocket: WebSocket,
        *,
        message_id: str | None = None,
        answer_prefix_tokens: list[str] | None = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate CustomAsyncCallbackHandler.

        Args:
            message_id: Identifier of the message being streamed; a fresh
                UUID is generated when omitted.
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ["Final", " Answer", ":"].
            strip_tokens: Whether to ignore whitespace and newlines when
                comparing answer_prefix_tokens to the last tokens (to
                determine whether the answer has been reached).
            stream_prefix: Whether the answer prefix itself should also be
                streamed.
        """
        self.websocket: WebSocket = websocket
        self.message_id: str = message_id or generate_uuid()
        self.text: str = ""
        self.started: bool = False
        self.loading_card = create_image_card(
            "https://res.cloudinary.com/dnv0qwkrk/image/upload/v1691005682/Alita/Ellipsis-2.4s-81px_1_nja8hq.gif"
        )
        self.adaptive_card = self.loading_card
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False

    async def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        resp = IChatResponse(
            id="",
            message_id=self.message_id,
            sender="bot",
            message=self.loading_card.to_dict(),
            type="start",
        )
        await self.websocket.send_json(resp.model_dump())

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)
        # Rebuild the full card from the accumulated text so the client
        # replaces the rendered card instead of appending deltas.
        self.text += token
        self.adaptive_card = create_adaptive_card(self.text)
        resp = IChatResponse(
            id="",
            message_id=self.message_id,
            sender="bot",
            message=self.adaptive_card.to_dict(),
            type="stream",
        )
        await self.websocket.send_json(resp.model_dump())

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM ends running."""
        resp = IChatResponse(
            id="",
            message_id=self.message_id,
            sender="bot",
            message=self.adaptive_card.to_dict(),
            type="end",
        )
        await self.websocket.send_json(resp.model_dump())
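
# Usage sketch (illustrative, not part of this module): inside a FastAPI
# WebSocket endpoint, attach the handler to a streaming-capable chat model so
# each token is pushed to the connected client. `ChatOpenAI` and `apredict`
# are from the classic LangChain API that the imports above assume;
# `user_message` is a placeholder.
#
#   from langchain.chat_models import ChatOpenAI
#
#   handler = CustomAsyncCallbackHandler(websocket)
#   llm = ChatOpenAI(streaming=True, callbacks=[handler])
#   await llm.apredict(user_message)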


class CustomFinalStreamingStdOutCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming the final answer of an agent.

    Only works with agents using LLMs that support streaming; only the final
    output of the agent (the tokens after the answer prefix) is streamed.
    """

    def append_to_last_tokens(self, token: str) -> None:
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

    def check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        return self.last_tokens == self.answer_prefix_tokens

    def update_message_id(self, message_id: str | None = None) -> None:
        # Generate the UUID at call time; a `generate_uuid()` parameter
        # default would be evaluated only once, at definition time.
        self.message_id = message_id or generate_uuid()

    def __init__(
        self,
        websocket: WebSocket,
        *,
        message_id: str | None = None,
        answer_prefix_tokens: list[str] | None = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate CustomFinalStreamingStdOutCallbackHandler.

        Args:
            message_id: Identifier of the message being streamed; a fresh
                UUID is generated when omitted.
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ["Final", " Answer", ":"].
            strip_tokens: Whether to ignore whitespace and newlines when
                comparing answer_prefix_tokens to the last tokens (to
                determine whether the answer has been reached).
            stream_prefix: Whether the answer prefix itself should also be
                streamed.
        """
        self.websocket: WebSocket = websocket
        self.message_id: str = message_id or generate_uuid()
        self.text: str = ""
        self.started: bool = False
        self.loading_card = create_image_card(
            "https://res.cloudinary.com/dnv0qwkrk/image/upload/v1691005682/Alita/Ellipsis-2.4s-81px_1_nja8hq.gif"
        )
        self.adaptive_card = self.loading_card
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False

    async def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        # An agent may invoke the LLM several times per run; only the first
        # invocation should emit the "start" event with the loading card.
        if not self.started:
            self.started = True
            resp = IChatResponse(
                id="",
                message_id=self.message_id,
                sender="bot",
                message=self.loading_card.to_dict(),
                type="start",
            )
            await self.websocket.send_json(resp.model_dump())

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent end."""
        # Fall back to the agent's return value when nothing was streamed
        # (e.g. the answer prefix was never matched).
        message: str = (
            self.text
            if self.text != ""
            # else "😕 Sorry, I couldn't find what you were looking for"
            else finish.return_values["output"]
        )
        self.adaptive_card = create_adaptive_card(message)
        resp = IChatResponse(
            id="",
            message_id=self.message_id,
            sender="bot",
            message=self.adaptive_card.to_dict(),
            type="stream",
        )
        await self.websocket.send_json(resp.model_dump())
        # Always close the exchange with an "end" message so the client can
        # finalize, attaching follow-up suggestions when any were generated.
        suggested_responses = await get_suggestions_questions(message)
        end_resp = IChatResponse(
            id="",
            message_id=self.message_id,
            sender="bot",
            message=self.adaptive_card.to_dict(),
            type="end",
            suggested_responses=suggested_responses,
        )
        await self.websocket.send_json(end_resp.model_dump())
        # Reset values so the handler can serve the next message.
        self.text = ""
        self.answer_reached = False
        self.started = False

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)
        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            return
        # ... if yes, then stream tokens from now on
        if self.answer_reached:
            self.text += token
            self.adaptive_card = create_adaptive_card(self.text)
            resp = IChatResponse(
                id="",
                message_id=self.message_id,
                sender="bot",
                message=self.adaptive_card.to_dict(),
                type="stream",
            )
            await self.websocket.send_json(resp.model_dump())
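
# Usage sketch (illustrative): attach the handler to an agent run so that only
# the final answer is streamed. `initialize_agent` and `arun` reflect the
# classic LangChain agent API; `tools`, `llm`, and `user_message` are
# placeholders for objects built elsewhere.
#
#   from langchain.agents import AgentType, initialize_agent
#
#   handler = CustomFinalStreamingStdOutCallbackHandler(websocket)
#   agent = initialize_agent(
#       tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
#   )
#   await agent.arun(user_message, callbacks=[handler])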