|
from langchain import LLMChain, PromptTemplate |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.chains.base import Chain |
|
from langchain.memory import ConversationBufferMemory |
|
|
|
from app_modules.llm_inference import LLMInference |
|
|
|
|
|
class ChatChain(LLMInference):
    """Plain conversational chain: feeds buffered chat history plus the new
    question straight to the LLM (no retrieval step)."""

    def __init__(self, llm_loader):
        # Delegate all setup to the shared LLMInference base.
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        """Assemble and return an ``LLMChain`` wired with a conversation
        buffer memory keyed to the prompt's ``chat_history`` slot."""
        # Exact prompt text (blank separator lines are intentional parts of
        # the template string).
        prompt_text = (
            "You are a chatbot having a conversation with a human.\n"
            "\n"
            "{chat_history}\n"
            "\n"
            "Human: {question}\n"
            "\n"
            "Chatbot:"
        )

        conversation_prompt = PromptTemplate(
            template=prompt_text,
            input_variables=["chat_history", "question"],
        )

        # memory_key must match the {chat_history} placeholder above so the
        # buffer's transcript is injected into each prompt.
        history_buffer = ConversationBufferMemory(memory_key="chat_history")

        return LLMChain(
            llm=self.llm_loader.llm,
            prompt=conversation_prompt,
            memory=history_buffer,
            verbose=True,
        )
|
|