pycui's picture
Add RealChar deployment for HuggingFace (V0)
babeaf6
raw
history blame contribute delete
No virus
1.93 kB
import os
from typing import List
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, HumanMessage
from realtime_ai_character.database.chroma import get_chroma
from realtime_ai_character.llm.base import AsyncCallbackAudioHandler, AsyncCallbackTextHandler, LLM
from realtime_ai_character.logger import get_logger
from realtime_ai_character.utils import Character
# Module-level logger, obtained from the project's logger helper and named
# after this module.
logger = get_logger(__name__)
class AnthropicLlm(LLM):
    """Chat backend that answers through LangChain's ``ChatAnthropic`` model.

    Every turn is augmented with context retrieved from the store returned by
    ``get_chroma()``, keeping only documents tagged with the active
    character's name.
    """

    def __init__(self, model):
        """Create the streaming chat client and obtain the document store.

        Args:
            model: Model identifier forwarded unchanged to ``ChatAnthropic``.
        """
        self.chat_anthropic = ChatAnthropic(
            model=model,
            temperature=0.5,
            streaming=True,
        )
        self.db = get_chroma()

    async def achat(self,
                    history: List[BaseMessage],
                    user_input: str,
                    user_input_template: str,
                    callback: AsyncCallbackTextHandler,
                    audioCallback: AsyncCallbackAudioHandler,
                    character: Character) -> str:
        """Generate the model's reply to ``user_input``.

        Args:
            history: Conversation so far. It is mutated in place: the
                formatted user turn is appended before the model is called.
            user_input: Raw text of the user's message; also used as the
                retrieval query.
            user_input_template: ``str.format`` template that is filled with
                the keyword fields ``context`` and ``query``.
            callback: Text callback handed to the model call.
            audioCallback: Audio callback handed to the model call.
            character: Character whose documents are used as context.

        Returns:
            The text of the first generation returned by the model.
        """
        # 1. Generate context
        context = self._generate_context(user_input, character)

        # 2. Add user input to history
        prompt = user_input_template.format(context=context, query=user_input)
        history.append(HumanMessage(content=prompt))

        # 3. Generate response
        handlers = [callback, audioCallback, StreamingStdOutCallbackHandler()]
        response = await self.chat_anthropic.agenerate(
            [history], callbacks=handlers)
        logger.info(f'Response: {response}')
        return response.generations[0][0].text

    def _generate_context(self, query, character: Character) -> str:
        """Return the newline-joined content of documents for ``character``.

        Args:
            query: Text passed to the store's similarity search.
            character: Only documents whose ``character_name`` metadata
                equals ``character.name`` are kept.

        Returns:
            The ``page_content`` of the kept documents joined with newlines;
            an empty string when none match.
        """
        candidates = self.db.similarity_search(query)
        # NOTE(review): the character filter runs after the search, so hits
        # that belong to other characters can use up the result slots. If the
        # store supports a metadata filter on the search call, pushing the
        # filter down may improve recall -- confirm before changing.
        #
        # Documents without a 'character_name' tag are skipped instead of
        # raising KeyError, so one untagged entry cannot abort the chat turn.
        docs = [
            doc for doc in candidates
            if 'character_name' in doc.metadata
            and doc.metadata['character_name'] == character.name
        ]
        logger.info(f'Found {len(docs)} documents')
        return '\n'.join(doc.page_content for doc in docs)