aie3-e2e-rag / app.py
dobinyim's picture
Update app.py
3bd1b51 verified
raw
history blame contribute delete
No virus
3.84 kB
import os
import chainlit as cl
from dotenv import load_dotenv
from operator import itemgetter
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
import logging
# Configure root logging at INFO level and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Merge a local .env file (if any) into os.environ, then read the three
# required settings. Any missing variable raises KeyError at import time.
load_dotenv()
HF_LLM_ENDPOINT, HF_EMBED_ENDPOINT, HF_TOKEN = (
    os.environ[var_name]
    for var_name in ("HF_LLM_ENDPOINT", "HF_EMBED_ENDPOINT", "HF_TOKEN")
)

# Location of the prebuilt FAISS vectorstore (relative to the working directory).
vectorstore_path = "./data/vectorstore"
index_file = os.path.join(vectorstore_path, "index.faiss")

# Embedding client backed by a Hugging Face inference endpoint.
# NOTE(review): presumably this must be the same embedding model that built
# the saved index, or similarity search will be meaningless -- confirm.
hf_embeddings = HuggingFaceEndpointEmbeddings(
    huggingfacehub_api_token=HF_TOKEN,
    model=HF_EMBED_ENDPOINT,
    task="feature-extraction",
)
# Load the prebuilt FAISS vectorstore from disk and expose it as a retriever.
#
# SECURITY: allow_dangerous_deserialization=True opts in to unpickling the
# stored docstore. That is only acceptable because the data is read from this
# app's own vectorstore_path. Never point it at files from an untrusted source.
try:
    if not os.path.isfile(index_file):
        # Fail fast with an actionable message rather than whatever
        # lower-level read error FAISS would raise for a missing index.
        raise FileNotFoundError(
            f"FAISS index not found at {index_file!r}; "
            "build the vectorstore before starting the app."
        )
    vectorstore = FAISS.load_local(
        vectorstore_path,
        hf_embeddings,
        allow_dangerous_deserialization=True,
    )
    hf_retriever = vectorstore.as_retriever()
    logger.info("Loaded Vectorstore")
except Exception as e:
    # Keep the traceback in the log, then re-raise: the RAG chain needs
    # hf_retriever, so the app cannot run without the index.
    logger.exception("Error loading Vectorstore: %s", e)
    raise
# Prompt for the RAG chain. The leading backslash after the opening quotes
# suppresses the template's initial newline. PromptTemplate.from_template
# infers the input variables from the placeholders:
#   {query}   - the user's question
#   {context} - the retrieved passages supplied by the chain
# NOTE(review): the bare "system" / "user" / "assistant" lines look like role
# headers of a chat template (e.g. Llama 3's <|start_header_id|>...<|end_header_id|>
# markers) whose special tokens may have been lost. As written, they are sent
# to the model as plain text. Confirm against the target model's prompt format.
RAG_PROMPT_TEMPLATE = """\
system
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.
user
User Query:
{query}
Context:
{context}
assistant
"""
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
# Text-generation client for the Hugging Face inference endpoint.
# The low temperature keeps answers close to the retrieved context.
# Any construction failure is logged and re-raised, so the app refuses to
# start without an LLM.
try:
    hf_llm = HuggingFaceEndpoint(
        endpoint_url=HF_LLM_ENDPOINT,
        huggingfacehub_api_token=HF_TOKEN,
        # Generation / sampling settings.
        max_new_tokens=512,
        temperature=0.1,
        top_k=10,
        top_p=0.95,
        repetition_penalty=1.0,
    )
except Exception as init_error:
    logger.error(f"Error initializing HuggingFace LLM endpoint: {init_error}")
    raise
else:
    logger.info("Initialized HuggingFace LLM endpoint")
@cl.author_rename
def rename(original_author: str):
    """Translate Chainlit author labels into the names shown in the UI.

    The default "Assistant" label is displayed as "Paul Graham Essay Bot";
    any other author name passes through unchanged.
    """
    if original_author == "Assistant":
        return "Paul Graham Essay Bot"
    return original_author
def _format_docs(docs):
    """Join the page text of retrieved documents into one context string.

    Without this step, the retriever's list of Document objects would be
    interpolated into the prompt via its repr. That repr includes metadata,
    quoting and escaped newlines, which wastes tokens and degrades answers.
    """
    return "\n\n".join(doc.page_content for doc in docs)


@cl.on_chat_start
async def start_chat():
    """Build the LCEL RAG chain for a new chat and store it in the user session.

    Data flow through the chain:

    * Input is ``{"query": str}``.
    * The query is sent to ``hf_retriever``, and the retrieved documents are
      formatted into plain text.
    * ``rag_prompt`` is filled with the query and that context, then passed
      to ``hf_llm``.
    * The output is normalised to ``str`` chunks by ``StrOutputParser``, so
      the message handler can stream them directly.

    On failure, the error is logged with its traceback and the user is told
    to restart, instead of failing silently.
    """
    try:
        lcel_rag_chain = (
            {
                "context": itemgetter("query") | hf_retriever | _format_docs,
                "query": itemgetter("query"),
            }
            | rag_prompt
            | hf_llm
            | StrOutputParser()
        )
        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
        logger.info("Started chat and set LCEL RAG chain in user session")
    except Exception as e:
        # Without notifying the user here, their next message would only
        # report a misleading "Session has expired".
        logger.exception("Error during chat start: %s", e)
        await cl.Message(
            content="Failed to start the chat. Please restart the chat."
        ).send()
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message by streaming the session's RAG chain output.

    The chain built at chat start is looked up in the user session and run
    on ``{"query": message.content}``. Each chunk is streamed into a single
    reply message.

    Failure handling:

    * A missing chain (session expired) or a KeyError from the session lookup
      produces a "restart the chat" message.
    * Any error while running the chain is logged with its traceback and
      reported as a generic error.
    """
    try:
        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
    except KeyError as e:
        # Only the session lookup counts as a "disconnected session". A
        # KeyError raised inside the chain is a different failure, handled
        # by the generic handler below.
        logger.exception("Session error: %s", e)
        await cl.Message(content="Session is disconnected. Please restart the chat.").send()
        return

    if lcel_rag_chain is None:
        logger.warning("Session has expired. Asking user to restart the chat.")
        await cl.Message(content="Session has expired. Please restart the chat.").send()
        return

    try:
        msg = cl.Message(content="")
        async for chunk in lcel_rag_chain.astream(
            {"query": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)
        await msg.send()
        logger.info("Message processed successfully")
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        await cl.Message(content="An unexpected error occurred. Please try again.").send()