File size: 3,838 Bytes
6acb7ee
 
 
 
 
 
 
 
 
 
 
 
 
e3eb2ab
6acb7ee
e3eb2ab
 
 
 
 
6acb7ee
 
 
 
 
 
c827140
 
6acb7ee
 
 
 
 
7da07a5
e3eb2ab
 
 
 
 
 
 
 
 
 
 
6acb7ee
 
27e491c
 
 
6acb7ee
 
 
27e491c
 
6acb7ee
 
 
 
e3eb2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
6acb7ee
 
 
 
27e491c
6acb7ee
 
 
 
 
27e491c
 
 
 
 
 
e3eb2ab
 
 
6acb7ee
 
 
27e491c
 
 
e3eb2ab
27e491c
 
 
 
 
 
 
 
 
 
e3eb2ab
27e491c
e3eb2ab
3bd1b51
e3eb2ab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import chainlit as cl
from dotenv import load_dotenv
from operator import itemgetter
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
import logging

# Logging: configure the root logger at INFO and create a logger named after
# this module for the messages emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment: load_dotenv() runs before the lookups so they can see whatever
# it loads. Every variable is required; a missing one raises KeyError here.
load_dotenv()

HF_LLM_ENDPOINT, HF_EMBED_ENDPOINT, HF_TOKEN = (
    os.environ[_name]
    for _name in ("HF_LLM_ENDPOINT", "HF_EMBED_ENDPOINT", "HF_TOKEN")
)

# Folder holding the persisted FAISS index. `index_file` is kept as a
# module-level name; nothing else in the visible code reads it.
vectorstore_path = "./data/vectorstore"
index_file = os.path.join(vectorstore_path, "index.faiss")

# Embedding client; created first because FAISS.load_local takes it as an
# argument.
hf_embeddings = HuggingFaceEndpointEmbeddings(
    model=HF_EMBED_ENDPOINT,
    task="feature-extraction",
    huggingfacehub_api_token=HF_TOKEN,
)

try:
    # SECURITY NOTE(review): allow_dangerous_deserialization=True opts in to
    # pickle-based loading of the stored index metadata. Unpickling can run
    # arbitrary code, so the files under `vectorstore_path` must come from a
    # trusted source. Flagged for review, intentionally not changed.
    vectorstore = FAISS.load_local(
        vectorstore_path,
        hf_embeddings,
        allow_dangerous_deserialization=True,
    )
    hf_retriever = vectorstore.as_retriever()
    logger.info("Loaded Vectorstore")
except Exception as e:
    # logger.exception keeps the traceback alongside the message; the error
    # is re-raised so the app does not come up without a retriever.
    logger.exception("Error loading Vectorstore: %s", e)
    raise

# Prompt text for the RAG chain, one literal per output line. It has two
# placeholders: {query} (the user's question) and {context} (the retrieved
# context).
# NOTE(review): the bare `system` / `user` / `assistant` lines look like
# chat-role markers; confirm they match the format the LLM endpoint expects.
RAG_PROMPT_TEMPLATE = (
    "system\n"
    "You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.\n"
    "user\n"
    "User Query:\n"
    "{query}\n"
    "Context:\n"
    "{context}\n"
    "assistant\n"
)

# Prompt object built from RAG_PROMPT_TEMPLATE (placeholders: {query} and
# {context}).
rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

try:
    # LLM client for the endpoint named by HF_LLM_ENDPOINT. The keyword
    # arguments below are the generation settings handed to the client.
    hf_llm = HuggingFaceEndpoint(
        endpoint_url=HF_LLM_ENDPOINT,
        max_new_tokens=512,
        top_k=10,
        top_p=0.95,
        temperature=0.1,
        repetition_penalty=1.0,
        huggingfacehub_api_token=HF_TOKEN,
    )
    logger.info("Initialized HuggingFace LLM endpoint")
except Exception as e:
    # logger.exception keeps the traceback alongside the message; the error
    # is re-raised so the app does not come up without an LLM client.
    logger.exception("Error initializing HuggingFace LLM endpoint: %s", e)
    raise

@cl.author_rename
def rename(original_author: str):
    """Return the display name to use for a message author.

    "Assistant" is shown as "Paul Graham Essay Bot"; any other author name is
    returned unchanged.
    """
    display_names = {"Assistant": "Paul Graham Essay Bot"}
    if original_author in display_names:
        return display_names[original_author]
    return original_author

@cl.on_chat_start
async def start_chat():
    """Build the RAG chain and store it in the user session.

    The chain takes a dict with a "query" key. It sends the query to
    ``hf_retriever`` to obtain the "context", passes the query through
    unchanged, fills ``rag_prompt`` with both, and sends the result to
    ``hf_llm``. The chain is stored under the session key "lcel_rag_chain".

    Errors are logged with their traceback and not re-raised, so a failure
    here leaves the session without a chain.
    """
    try:
        # Fan the incoming dict out into the two prompt variables.
        prompt_inputs = {
            "context": itemgetter("query") | hf_retriever,
            "query": itemgetter("query"),
        }
        lcel_rag_chain = prompt_inputs | rag_prompt | hf_llm
        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
        logger.info("Started chat and set LCEL RAG chain in user session")
    except Exception as e:
        # logger.exception records the traceback, which a plain error log of
        # str(e) would drop.
        logger.exception("Error during chat start: %s", e)

@cl.on_message
async def main(message: cl.Message):
    """Answer one user message by streaming the RAG chain's output.

    The chain is read from the user session key "lcel_rag_chain". If it is
    missing, the user is asked to restart the chat. Otherwise the chain is
    run on ``message.content`` and each chunk is streamed into a reply
    message, which is sent once streaming finishes.

    Args:
        message: The incoming Chainlit message; only ``message.content`` is
            read.

    No exception escapes this handler: failures are logged with their
    traceback and reported to the user as a short chat message.
    """
    try:
        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
        if lcel_rag_chain is None:
            logger.warning("Session has expired. Asking user to restart the chat.")
            await cl.Message(content="Session has expired. Please restart the chat.").send()
            return

        # Start with an empty reply and append chunks as the chain yields them.
        msg = cl.Message(content="")
        async for chunk in lcel_rag_chain.astream(
            {"query": message.content},
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            await msg.stream_token(chunk)
        await msg.send()
        logger.info("Message processed successfully")
    except KeyError as e:
        # logger.exception keeps the traceback so the source of the KeyError
        # can be found in the logs.
        logger.exception("Session error: %s", e)
        await cl.Message(content="Session is disconnected. Please restart the chat.").send()
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        await cl.Message(content="An unexpected error occurred. Please try again.").send()