dobinyim committed
Commit 27e491c
1 Parent(s): 9bfa017

Update app.py

Files changed (1)
  1. app.py +33 -82
app.py CHANGED
@@ -13,31 +13,12 @@ from langchain.schema.runnable import RunnablePassthrough
 from langchain.schema.runnable.config import RunnableConfig
 
 # GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
-# ---- ENV VARIABLES ---- #
-"""
-This function will load our environment file (.env) if it is present.
-
-NOTE: Make sure that .env is in your .gitignore file - it is by default, but please ensure it remains there.
-"""
 load_dotenv()
 
-"""
-We will load our environment variables here.
-"""
 HF_LLM_ENDPOINT = os.environ["HF_LLM_ENDPOINT"]
 HF_EMBED_ENDPOINT = os.environ["HF_EMBED_ENDPOINT"]
 HF_TOKEN = os.environ["HF_TOKEN"]
 
-# ---- GLOBAL DECLARATIONS ---- #
-
-# -- RETRIEVAL -- #
-"""
-1. Load Documents from Text File
-2. Split Documents into Chunks
-3. Load HuggingFace Embeddings (remember to use the URL we set above)
-4. Index Files if they do not exist, otherwise load the vectorstore
-"""
-
 vectorstore_path = "./data/vectorstore"
 index_file = os.path.join(vectorstore_path, "index.faiss")
 hf_embeddings = HuggingFaceEndpointEmbeddings(
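
Note (reviewer sketch, not part of the commit): with the docstrings gone, the three endpoint values are read straight from os.environ, so a missing variable now fails at import time with a bare KeyError. A minimal fail-fast check under that assumption; the variable names come from the diff, the RuntimeError wrapper is invented here:

# Sketch only - not from the commit. Same env names as app.py;
# the explicit RuntimeError is this sketch's assumption.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env if present; existing env vars are not overridden

REQUIRED = ("HF_LLM_ENDPOINT", "HF_EMBED_ENDPOINT", "HF_TOKEN")
missing = [name for name in REQUIRED if name not in os.environ]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
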
@@ -47,42 +28,26 @@ hf_embeddings = HuggingFaceEndpointEmbeddings(
 )
 
 vectorstore = FAISS.load_local(
-    vectorstore_path,
-    hf_embeddings,
-    allow_dangerous_deserialization=True
+    vectorstore_path,
+    hf_embeddings,
+    allow_dangerous_deserialization=True
 )
 hf_retriever = vectorstore.as_retriever()
 print("Loaded Vectorstore")
-
 
-# -- AUGMENTED -- #
-"""
-1. Define a String Template
-2. Create a Prompt Template from the String Template
-"""
-### 1. DEFINE STRING TEMPLATE
 RAG_PROMPT_TEMPLATE = """\
-<|start_header_id|>system<|end_header_id|>
-You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.<|eot_id|>
-
-<|start_header_id|>user<|end_header_id|>
+system
+You are a helpful assistant. You answer user questions based on provided context. If you can't answer the question with the provided context, say you don't know.
+user
 User Query:
 {query}
-
 Context:
-{context}<|eot_id|>
-
-<|start_header_id|>assistant<|end_header_id|>
+{context}
+assistant
 """
 
-### 2. CREATE PROMPT TEMPLATE
 rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
 
-# -- GENERATION -- #
-"""
-1. Create a HuggingFaceEndpoint for the LLM
-"""
-### 1. CREATE HUGGINGFACE ENDPOINT FOR LLM
 hf_llm = HuggingFaceEndpoint(
     endpoint_url=HF_LLM_ENDPOINT,
     max_new_tokens=512,
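
Note (reviewer sketch, not part of the commit): the new template drops the Llama 3 special tokens (<|start_header_id|>, <|eot_id|>) in favour of bare system/user/assistant labels, which assumes the serving endpoint tolerates an untokenized chat layout. Rendering the template locally shows the exact string the endpoint will receive; the shortened template text and the query/context values below are placeholders:

# Sketch only - not from the commit. Checks how the plain-label template
# renders; the import path matches the langchain version app.py already uses.
from langchain.prompts import PromptTemplate

template = PromptTemplate.from_template(
    "system\nYou are a helpful assistant.\nuser\nUser Query:\n{query}\nContext:\n{context}\nassistant\n"
)
print(template.format(
    query="What did Paul Graham work on before college?",
    context="(retrieved essay chunks would go here)",
))
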
@@ -95,51 +60,37 @@ hf_llm = HuggingFaceEndpoint(
 
 @cl.author_rename
 def rename(original_author: str):
-    """
-    This function can be used to rename the 'author' of a message.
-
-    In this case, we're overriding the 'Assistant' author to be 'Paul Graham Essay Bot'.
-    """
     rename_dict = {
-        "Assistant" : "Paul Graham Essay Bot"
+        "Assistant": "Paul Graham Essay Bot"
     }
     return rename_dict.get(original_author, original_author)
 
 @cl.on_chat_start
 async def start_chat():
-    """
-    This function will be called at the start of every user session.
-
-    We will build our LCEL RAG chain here, and store it in the user session.
-
-    The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
-    """
-
-    ### BUILD LCEL RAG CHAIN THAT ONLY RETURNS TEXT
-    lcel_rag_chain = (
-        {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
-        | rag_prompt | hf_llm
-    )
-
-    cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
+    try:
+        lcel_rag_chain = (
+            {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
+            | rag_prompt | hf_llm
+        )
+        cl.user_session.set("lcel_rag_chain", lcel_rag_chain)
+    except KeyError as e:
+        print(f"Session error on start: {e}")
 
 @cl.on_message
 async def main(message: cl.Message):
-    """
-    This function will be called every time a message is recieved from a session.
-
-    We will use the LCEL RAG chain to generate a response to the user query.
-
-    The LCEL RAG chain is stored in the user session, and is unique to each user session - this is why we can access it here.
-    """
-    lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
-
-    msg = cl.Message(content="")
-
-    async for chunk in lcel_rag_chain.astream(
-        {"query": message.content},
-        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
-    ):
-        await msg.stream_token(chunk)
-
-    await msg.send()
+    try:
+        lcel_rag_chain = cl.user_session.get("lcel_rag_chain")
+        if lcel_rag_chain is None:
+            await cl.Message(content="Session has expired. Please restart the chat.").send()
+            return
+
+        msg = cl.Message(content="")
+        async for chunk in lcel_rag_chain.astream(
+            {"query": message.content},
+            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+        ):
+            await msg.stream_token(chunk)
+        await msg.send()
+    except KeyError as e:
+        await cl.Message(content="An error occurred. Please restart the chat.").send()
+        print(f"Session error: {e}")
 