"""Helpers for building a history-aware, conversational RAG chain with LangChain."""
import os

from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
def get_model():
    """Create a ChatOpenAI client using an API key from the environment.

    Historically this project read the key from ``OPEN_API_KEY`` (a common
    misspelling of the standard ``OPENAI_API_KEY``). To stay backward
    compatible, ``OPEN_API_KEY`` is checked first, then the standard
    ``OPENAI_API_KEY`` is used as a fallback.

    Returns:
        ChatOpenAI: a chat model client with default model settings.
    """
    api_key = os.getenv("OPEN_API_KEY") or os.getenv("OPENAI_API_KEY")
    return ChatOpenAI(api_key=api_key)
def create_contextualize_q_prompt():
    """Build the prompt that rewrites a follow-up question as standalone.

    The system instruction asks the model to reformulate the latest user
    question so it can be understood without the chat history, and
    explicitly forbids answering it.

    Returns:
        ChatPromptTemplate: system instruction + chat history placeholder
        + the latest user input.
    """
    system_text = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    messages = [
        ("system", system_text),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
    return ChatPromptTemplate.from_messages(messages)
def create_qa_prompt():
    """Build the answer-generation prompt for the QA step.

    The system message constrains answers to the retrieved textbook
    chunks (injected via the ``{context}`` variable), caps responses at
    three sentences, and instructs the model to reply 'No data found'
    when the chunks contain nothing relevant.

    Returns:
        ChatPromptTemplate: system instruction + chat history placeholder
        + the latest user input.
    """
    system_text = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. \
The retrieved content belongs to subject textbooks present. You will receive different chunks, each of which belongs to a single page of a textbook. \
Using the chunk given, think logically and answer the questions from the user. \
If you are not able to identify the relevant information regarding to the user's question in the retrieved chunks, then just return 'No data found'.\
Use three sentences maximum and keep the answer concise. \
{context}"""
    messages = [
        ("system", system_text),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
    return ChatPromptTemplate.from_messages(messages)
def create_rag_chain(model, retriever):
    """Assemble the history-aware retrieval-augmented QA chain.

    Pipeline: the retriever is first wrapped so that follow-up questions
    are rewritten into standalone form before retrieval, then retrieved
    documents are stuffed into the QA prompt and answered by *model*.

    Args:
        model: the chat model used both for question rewriting and answering.
        retriever: the document retriever to wrap.

    Returns:
        A retrieval chain combining the rewriting retriever and the
        document-stuffing answer chain.
    """
    rephrasing_retriever = create_history_aware_retriever(
        model, retriever, create_contextualize_q_prompt()
    )
    answer_chain = create_stuff_documents_chain(model, create_qa_prompt())
    return create_retrieval_chain(rephrasing_retriever, answer_chain)
def get_conversational_rag_chain(rag_chain):
    """Wrap *rag_chain* with per-session, in-memory chat history.

    Each call builds its own session store, so histories are scoped to
    the returned runnable. A ChatMessageHistory is created lazily the
    first time a given session_id is seen.

    Args:
        rag_chain: the retrieval chain to make conversational.

    Returns:
        RunnableWithMessageHistory reading "input", injecting
        "chat_history", and recording the "answer" output.
    """
    sessions = {}

    def _history_for(session_id: str):
        # Create the history only on first use of this session id.
        if session_id not in sessions:
            sessions[session_id] = ChatMessageHistory()
        return sessions[session_id]

    return RunnableWithMessageHistory(
        rag_chain,
        _history_for,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )