from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.vectorstores.base import VectorStore

from app_modules.llm_inference import LLMInference


class QAChain(LLMInference):
    """Builds a ConversationalRetrievalChain over a vector store, optionally
    selecting a chat-specific store via doc_id_to_vectorstore_mapping."""

    vectorstore: VectorStore

    def __init__(self, vectorstore, llm_loader, doc_id_to_vectorstore_mapping=None):
        super().__init__(llm_loader)
        self.vectorstore = vectorstore
        # Optional mapping from chat ID to a per-conversation vector store;
        # when a chat_id is present in the inputs, it overrides self.vectorstore.
        self.doc_id_to_vectorstore_mapping = doc_id_to_vectorstore_mapping

    def get_chain(self, inputs) -> Chain:
        return self.create_chain(inputs)

    def create_chain(self, inputs) -> Chain:
        # Default to the shared store; switch to the chat-specific store when
        # the inputs carry a chat_id we have a mapping for. The truthiness
        # check guards against the default doc_id_to_vectorstore_mapping=None,
        # which would otherwise raise a TypeError on the membership test.
        vectorstore = self.vectorstore
        if "chat_id" in inputs and self.doc_id_to_vectorstore_mapping:
            if inputs["chat_id"] in self.doc_id_to_vectorstore_mapping:
                vectorstore = self.doc_id_to_vectorstore_mapping[inputs["chat_id"]]

        qa = ConversationalRetrievalChain.from_llm(
            self.llm_loader.llm,
            vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
            max_tokens_limit=self.llm_loader.max_tokens_limit,
            return_source_documents=True,
        )

        return qa
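

# Usage (a minimal sketch, not executed by this module): `llm_loader` is
# assumed to expose .llm, .search_kwargs, and .max_tokens_limit, as
# create_chain requires, and `vectorstore` is any populated LangChain
# VectorStore. The returned chain takes "question" and "chat_history" keys
# and, with return_source_documents=True, yields "answer" and
# "source_documents":
#
#   qa_chain = QAChain(vectorstore, llm_loader)
#   chain = qa_chain.get_chain({"chat_id": "chat-123"})
#   result = chain({"question": "What does the document cover?", "chat_history": []})
#   print(result["answer"])
#   print(result["source_documents"])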