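"""Chainlit app: a RAG chatbot over Paul Graham's essays.

Loads a local FAISS vectorstore, retrieves context for each user query, and
streams answers from a Hugging Face inference endpoint into the Chainlit UI.
"""
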
import os
import logging
from operator import itemgetter

import chainlit as cl
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEndpointEmbeddings

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables. All three are required, so a missing one fails
# fast with a KeyError at startup.
load_dotenv()
HF_LLM_ENDPOINT = os.environ["HF_LLM_ENDPOINT"]      # text-generation endpoint URL
HF_EMBED_ENDPOINT = os.environ["HF_EMBED_ENDPOINT"]  # feature-extraction (embedding) endpoint URL
HF_TOKEN = os.environ["HF_TOKEN"]                    # Hugging Face API token

VECTORSTORE_PATH = "./data/vectorstore"
INDEX_FILE = os.path.join(VECTORSTORE_PATH, "index.faiss")

hf_embeddings = HuggingFaceEndpointEmbeddings(
    model=HF_EMBED_ENDPOINT,
    task="feature-extraction",
    huggingfacehub_api_token=HF_TOKEN,
)

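# If the index has not been built yet, it can be created once from the source
# text and saved to VECTORSTORE_PATH. A minimal sketch, commented out because
# the source file path and chunking parameters below are assumptions, not part
# of this app:
#
#   from langchain_community.document_loaders import TextLoader
#   from langchain_text_splitters import RecursiveCharacterTextSplitter
#
#   documents = TextLoader("./data/paul_graham_essays.txt").load()
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1000, chunk_overlap=30
#   ).split_documents(documents)
#   FAISS.from_documents(chunks, hf_embeddings).save_local(VECTORSTORE_PATH)
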
if not os.path.exists(INDEX_FILE):
    raise FileNotFoundError(
        f"No FAISS index found at {INDEX_FILE}; build the vectorstore before starting the app."
    )

try:
    vectorstore = FAISS.load_local(
        VECTORSTORE_PATH,
        hf_embeddings,
        # The FAISS docstore is pickled on disk, so loading requires opting in;
        # only load index files you trust.
        allow_dangerous_deserialization=True,
    )
    hf_retriever = vectorstore.as_retriever()
    logger.info("Loaded vectorstore from %s", VECTORSTORE_PATH)
except Exception as e:
    logger.error(f"Error loading vectorstore: {e}")
    raise

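# Optional sanity check that retrieval works (commented out; the query text is
# illustrative only):
#
#   docs = hf_retriever.invoke("What did Paul Graham work on?")
#   logger.info("Retrieved %d chunks", len(docs))
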
# Prompt in the Llama 3 instruct chat format. The <|start_header_id|> /
# <|end_header_id|> / <|eot_id|> special tokens assume HF_LLM_ENDPOINT serves a
# Llama 3-style instruct model; swap in the matching template for other models.
RAG_PROMPT_TEMPLATE = """\
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
User Query:
{query}

Context:
{context}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

try:
    hf_llm = HuggingFaceEndpoint(
        endpoint_url=HF_LLM_ENDPOINT,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        temperature=0.1,  # low temperature keeps answers close to the retrieved context
        repetition_penalty=1.0,
        huggingfacehub_api_token=HF_TOKEN,
    )
    logger.info("Initialized HuggingFace LLM endpoint")
except Exception as e:
    logger.error(f"Error initializing HuggingFace LLM endpoint: {e}")
    raise

@cl.author_rename
def rename(original_author: str):
    """Display Chainlit's default "Assistant" author as the bot's name."""
    rename_dict = {"Assistant": "Paul Graham Essay Bot"}
    return rename_dict.get(original_author, original_author)

@cl.on_chat_start
async def start_chat():
    """Build the LCEL RAG chain and store it in the user session."""
    try:
        # The dict is coerced to a RunnableParallel: the user query is routed
        # both through the retriever (to produce the context) and straight to
        # the prompt as-is.
        lcel_rag_chain = (
            {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
            | rag_prompt
            | hf_llm
        )
        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
        logger.info("Started chat and set LCEL RAG chain in user session")
    except Exception as e:
        logger.error(f"Error during chat start: {e}")

@cl.on_message
async def main(message: cl.Message):
    try:
        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
        if lcel_rag_chain is None:
            logger.warning("Session has expired. Asking user to restart the chat.")
            await cl.Message(content="Session has expired. Please restart the chat.").send()
            return

        # Stream tokens into the UI as the chain produces them.
        msg = cl.Message(content="")
        async for chunk in lcel_rag_chain.astream(
            {"query": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)

        await msg.send()
        logger.info("Message processed successfully")
    except KeyError as e:
        logger.error(f"Session error: {e}")
        await cl.Message(content="Session is disconnected. Please restart the chat.").send()
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await cl.Message(content="An unexpected error occurred. Please try again.").send()

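# Run locally (assuming this file is saved as app.py) with:
#   chainlit run app.py -w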