RubenAMtz committed on
Commit cdda8d7 · 1 Parent(s): 8166d2a

first try at a Chainlit app with RAQA, WaB and Chains

Files changed (6)
  1. .gitignore +4 -0
  2. app.py +211 -114
  3. requirements.txt +3 -5
  4. utils/__init__.py +0 -0
  5. utils/chain.py +71 -0
  6. utils/store.py +44 -0
.gitignore CHANGED
@@ -3,6 +3,10 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+# project
+cache/
+wandb/
+
 # C extensions
 *.so
 
app.py CHANGED
@@ -7,30 +7,37 @@ import chainlit as cl # importing chainlit for our app
 from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
 from chainlit.playground.providers import ChatOpenAI # importing ChatOpenAI tools
 from dotenv import load_dotenv
-from aimakerspace.text_utils import PDFFileLoader, CharacterTextSplitter
-from aimakerspace.vectordatabase import VectorDatabase
+import arxiv
+from langchain.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+import pinecone
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.embeddings import CacheBackedEmbeddings
+from langchain.storage import LocalFileStore
+from utils.store import index_documents
+from utils.chain import create_chain
+from langchain.vectorstores import Pinecone
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
+from langchain.prompts import PromptTemplate
+from operator import itemgetter
+from langchain.schema.runnable import RunnableSequence
+from langchain.schema import format_document
+from langchain.schema.output_parser import StrOutputParser
+from langchain.prompts.prompt import PromptTemplate
+from pprint import pprint
+from langchain_core.documents.base import Document
+from langchain_core.vectorstores import VectorStoreRetriever
+import langchain
+from langchain.cache import InMemoryCache
 
 load_dotenv()
-
-# ChatOpenAI Templates
-system_template = """You are a Wizzard and everything you say is a spell!
-"""
-
-user_template = """{input}
-Wizzard, think through your response step by step.
-"""
-
-assistant_template = """Use the following context, if any, to help you
-answer the user's input, if the answer is not in the context say you don't
-know the answer.
-CONTEXT:
-===============
-{context}
-===============
-
-Spell away Wizzard!
-"""
-
+YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
+YOUR_ENV = os.environ["PINECONE_ENV"]
+INDEX_NAME= 'arxiv-paper-index'
+WANDB_API_KEY=os.environ["WANDB_API_KEY"]
+WANDB_PROJECT=os.environ["WANDB_PROJECT"]
+first_run = False
 
 
 @cl.on_chat_start # marks a function that will be executed at the start of a user session
@@ -38,113 +45,203 @@ async def start_chat():
     settings = {
         "model": "gpt-3.5-turbo",
         "temperature": 0,
-        "max_tokens": 500,
-        "top_p": 1,
-        "frequency_penalty": 0,
-        "presence_penalty": 0,
+        "max_tokens": 500
     }
 
-    cl.user_session.set("settings", settings)
-
-    files = None
-    while files is None:
-        files = await cl.AskFileMessage(
-            content="Please upload a PDF file to begin",
-            accept=["application/pdf"],
-            max_files=10,
-            max_size_mb=10,
-            timeout=60
-        ).send()
-
-    # let the user know you are processing the file(s)
     await cl.Message(
-        content="Loading your files..."
+        content="What would you like to learn about today? 😊"
    ).send()
 
-    # decode the file
-    documents = PDFFileLoader(path="", files=files).load_documents()
+    # instantiate arXiv client (on start)
+    arxiv_client = arxiv.Client()
 
-    # split the text into chunks
-    chunks = CharacterTextSplitter(
-        chunk_size=1000,
-        chunk_overlap=200
-    ).split_texts(documents)
+    # create an embedder through a cache interface (locally) (on start)
+    store = LocalFileStore("./cache/")
 
-    print(chunks[0])
+    core_embeddings_model = OpenAIEmbeddings(
+        api_key=os.environ['OPENAI_API_KEY']
+    )
 
-    # create a vector store
-    # let the user know you are processing the document(s)
-    await cl.Message(
-        content="Creating vector store"
-    ).send()
+    embedder = CacheBackedEmbeddings.from_bytes_store(
+        underlying_embeddings=core_embeddings_model,
+        document_embedding_cache=store,
+        namespace=core_embeddings_model.model
+    )
 
-    vector_db = VectorDatabase()
-    vector_db = await vector_db.abuild_from_list(chunks)
+    # instantiate pinecone (on start)
+    pinecone.init(
+        api_key=YOUR_API_KEY,
+        environment=YOUR_ENV
+    )
 
-    await cl.Message(
-        content="Done. Ask away!"
-    ).send()
+    if INDEX_NAME not in pinecone.list_indexes():
+        pinecone.create_index(
+            name=INDEX_NAME,
+            metric='cosine',
+            dimension=1536
+        )
+    index = pinecone.GRPCIndex(INDEX_NAME)
 
-    cl.user_session.set("vector_db", vector_db)
+    # setup your ChatOpenAI model (on start)
+    llm = ChatOpenAI(
+        model=settings['model'],
+        temperature=settings['temperature'],
+        max_tokens=settings['max_tokens'],
+        api_key=os.environ["OPENAI_API_KEY"],
+        streaming=True
+    )
+
+    # create a prompt cache (locally) (on start)
+    langchain.llm_cache = InMemoryCache()
+
+    # log data in WaB (on start)
+    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
+
+    tools = {
+        "arxiv_client": arxiv_client,
+        "index": index,
+        "embedder": embedder,
+        "llm": llm
+    }
+    cl.user_session.set("tools", tools)
+    cl.user_session.set("settings", settings)
+    cl.user_session.set("first_run", False)
 
 
 @cl.on_message # marks a function that should be run each time the chatbot receives a message from a user
 async def main(message: cl.Message):
-    vector_db = cl.user_session.get("vector_db")
     settings = cl.user_session.get("settings")
-
-    client = AsyncOpenAI()
-
-    print(message.content)
-
-    results_list = vector_db.search_by_text(query_text=message.content, k=3, return_as_text=True)
-    if results_list:
-        results_string = "\n\n".join(results_list)
-    else:
-        results_string = ""
-
-    prompt = Prompt(
-        provider=ChatOpenAI.id,
-        messages=[
-            PromptMessage(
-                role="system",
-                template=system_template,
-                formatted=system_template,
-            ),
-            PromptMessage(
-                role="user",
-                template=user_template,
-                formatted=user_template.format(input=message.content),
-            ),
-            PromptMessage(
-                role="assistant",
-                template=assistant_template,
-                formatted=assistant_template.format(context=results_string)
-            )
-        ],
-        inputs={
-            "input": message.content,
-            "context": results_string
-        },
-        settings=settings,
-    )
-
-    print([m.to_openai() for m in prompt.messages])
-
+    tools = cl.user_session.get("tools")
+    first_run = cl.user_session.get("first_run")
+
+    if not first_run:
+
+        arxiv_client: arxiv.Client = tools['arxiv_client']
+        index: pinecone.GRPCIndex = tools['index']
+        embedder: CacheBackedEmbeddings = tools['embedder']
+        llm: ChatOpenAI = tools['llm']
+
+        # using query search for ArXiv documents (on message)
+
+        search = arxiv.Search(
+            query = message.content,
+            max_results = 10,
+            sort_by = arxiv.SortCriterion.Relevance
+        )
+        paper_urls = []
+
+        sys_message = cl.Message(content="")
+        await sys_message.send() # renders a loader
+        for result in arxiv_client.results(search):
+            paper_urls.append(result.pdf_url)
+        sys_message.content = """
+        I found some papers, let me study them real quick to help
+        you learn, don't worry it'll be a few seconds 😉"""
+        await sys_message.update()
+        await sys_message.send()
+
+        sys_message = cl.Message(content="")
+        await sys_message.send() # renders a loader
+        # load them and split them (on message)
+        docs = []
+        for paper_url in paper_urls:
+            try:
+                loader = PyPDFLoader(paper_url)
+                docs.append(loader.load())
+            except:
+                print(f"Error loading {paper_url}")
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size = 400,
+            chunk_overlap = 30,
+            length_function = len
+        )
+
+        # create an index using pinecone (on message)
+        index_documents(docs, text_splitter, embedder, index)
+        sys_message.content = "Done studying :)"
+        await sys_message.update()
+        await sys_message.send()
+
+        text_field = "source_document"
+        index = pinecone.Index(INDEX_NAME)
+        vectorstore = Pinecone(
+            index=index,
+            embedding=embedder.embed_query,
+            text_key=text_field
+        )
+        retriever: VectorStoreRetriever = vectorstore.as_retriever()
+
+        # create the chain (on message)
+        retrieval_augmented_qa_chain: RunnableSequence = create_chain(retriever=retriever, llm=llm)
+
+        # message.content = await cl.AskUserMessage(
+        #     content="Ask away"
+        # ).send()
+
+    # run
     msg = cl.Message(content="")
-
-    # Call OpenAI
-    async for stream_resp in await client.chat.completions.create(
-        messages=[m.to_openai() for m in prompt.messages], stream=True, **settings
-    ):
-        token = stream_resp.choices[0].delta.content
-        if not token:
-            token = ""
-        await msg.stream_token(token)
-
-    # Update the prompt object with the completion
-    prompt.completion = msg.content
-    msg.prompt = prompt
-
-    # Send and close the message stream
+    for chunk in retrieval_augmented_qa_chain.stream({"question": f"{message.content}"}):
+        pprint(chunk)
+        if res:= chunk.get('response'):
+            await msg.stream_token(res.content)
     await msg.send()
+    cl.user_session.set("first_run", True)
+    # first_run = True
+
+
+    # client = AsyncOpenAI()
+
+    # print(message.content)
+
+    # results_list = vector_db.search_by_text(query_text=message.content, k=3, return_as_text=True)
+    # if results_list:
+    # results_string = "\n\n".join(results_list)
+    # else:
+    # results_string = ""
+
+    # prompt = Prompt(
+    # provider=ChatOpenAI.id,
+    # messages=[
+    # PromptMessage(
+    # role="system",
+    # template=system_template,
+    # formatted=system_template,
+    # ),
+    # PromptMessage(
+    # role="user",
+    # template=user_template,
+    # formatted=user_template.format(input=message.content),
+    # ),
+    # PromptMessage(
+    # role="assistant",
+    # template=assistant_template,
+    # formatted=assistant_template.format(context=results_string)
+    # )
+    # ],
+    # inputs={
+    # "input": message.content,
+    # "context": results_string
+    # },
+    # settings=settings,
+    # )
+
+    # print([m.to_openai() for m in prompt.messages])
+
+    # msg = cl.Message(content="")
+
+    # # Call OpenAI
+    # async for stream_resp in await client.chat.completions.create(
+    # messages=[m.to_openai() for m in prompt.messages], stream=True, **settings
+    # ):
+    # token = stream_resp.choices[0].delta.content
+    # if not token:
+    # token = ""
+    # await msg.stream_token(token)
+
+    # # Update the prompt object with the completion
+    # prompt.completion = msg.content
+    # msg.prompt = prompt
+
+    # # Send and close the message stream
+    # await msg.send()
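
Note (not part of the diff above): the cache-backed embedder wiring in start_chat() can be exercised on its own. A minimal sketch, assuming OPENAI_API_KEY is set and the same ./cache/ directory as app.py; repeated calls for the same text should be served from the local byte store rather than the OpenAI API:

import os
from langchain.storage import LocalFileStore
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings

# same wiring as in start_chat(): OpenAI embeddings behind a local file cache
store = LocalFileStore("./cache/")
core = OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"])
embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings=core,
    document_embedding_cache=store,
    namespace=core.model,
)

embedder.embed_documents(["retrieval augmented generation"])  # first call hits the API and writes to ./cache/
embedder.embed_documents(["retrieval augmented generation"])  # second call is served from the cache
print(sum(1 for _ in store.yield_keys()))  # number of cached embedding entries on disk
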
requirements.txt CHANGED
@@ -4,8 +4,6 @@ openai==1.3.5
 tiktoken==0.5.1
 python-dotenv==1.0.0
 numpy==1.25.2
-pandas
-scikit-learn
-matplotlib
-plotly
-pdfminer.six
+langchain
+pinecone-client[grpc]
+pypdf
utils/__init__.py ADDED
File without changes
utils/chain.py ADDED
@@ -0,0 +1,71 @@
+from operator import itemgetter
+from langchain_core.vectorstores import VectorStoreRetriever
+from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts import PromptTemplate
+from langchain_core.documents import Document
+from langchain_core.messages.ai import AIMessage
+
+
+template = """
+You are a helpful assistant, your job is to answer the user's question using the relevant context.
+CONTEXT
+=========
+{context}
+=========
+
+User question: {question}
+"""
+prompt = PromptTemplate.from_template(template=template)
+
+
+def to_doc(input: AIMessage) -> list[Document]:
+    return [Document(page_content="LLM", metadata={'chunk': 1.0, 'page_number': 1.0, 'text':input.content})]
+
+def merge_docs(a: dict[str, list[Document]]) -> list[Document]:
+    merged_docs = []
+    for key,value in a.items():
+        merged_docs.extend(value)
+    return merged_docs
+
+
+
+def create_chain(**kwargs) -> RunnableSequence:
+    """
+    Requires retriever, llm and prompt
+    """
+
+    retriever: VectorStoreRetriever = kwargs["retriever"]
+    llm:ChatOpenAI = kwargs.get("llm", None)
+
+
+    if not isinstance(retriever, VectorStoreRetriever):
+        raise ValueError
+    if not isinstance(llm, ChatOpenAI):
+        raise ValueError
+
+    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
+    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(config={"run_name": "self knowledge"})
+    response_chain = (prompt | llm).with_config(config={"run_name": "response"})
+    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})
+    context_chain = (
+        RunnableParallel(
+            {
+                "docs": docs_chain,
+                "self_knowledge": self_knowledge_chain
+            }
+        ).with_config(config={"run_name": "parallel context"})
+        | merge_docs_link
+    )
+
+    retrieval_augmented_qa_chain = (
+        RunnableParallel({
+            "question": itemgetter("question"),
+            "context": context_chain
+        })
+        | RunnableParallel({
+            "response": response_chain,
+            "context": itemgetter("context"),
+        })
+    )
+    return retrieval_augmented_qa_chain
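
Note (not part of the diff above): a minimal sketch of how create_chain is consumed outside Chainlit, assuming the same environment variables as app.py and an 'arxiv-paper-index' Pinecone index already populated by index_documents. The chain returns a dict with a "response" message and the merged "context" documents:

import os
import pinecone
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Pinecone
from utils.chain import create_chain

# connect to the existing index, mirroring app.py
pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment=os.environ["PINECONE_ENV"])
index = pinecone.Index("arxiv-paper-index")

embeddings = OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"])
vectorstore = Pinecone(index=index, embedding=embeddings.embed_query, text_key="source_document")
retriever = vectorstore.as_retriever()

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=os.environ["OPENAI_API_KEY"])
chain = create_chain(retriever=retriever, llm=llm)

out = chain.invoke({"question": "What is retrieval augmented generation?"})
print(out["response"].content)  # answer from the "response" sub-chain
print(len(out["context"]))      # retrieved chunks plus the "self knowledge" document
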
utils/store.py ADDED
@@ -0,0 +1,44 @@
+from tqdm.auto import tqdm
+from langchain.embeddings import CacheBackedEmbeddings
+from uuid import uuid4
+from langchain_core.documents import Document
+from typing import List
+from langchain.text_splitter import TextSplitter
+from pinecone import GRPCIndex
+
+BATCH_LIMIT = 100
+
+def index_documents(
+        docs: List[Document],
+        text_splitter: TextSplitter,
+        embedder: CacheBackedEmbeddings,
+        index: GRPCIndex) -> None:
+
+    texts = []
+    metadatas = []
+
+    for i in tqdm(range(len(docs))):
+        for doc in docs[i]:
+            metadata = {
+                'source_document' : doc.metadata["source"],
+                'page_number' : doc.metadata["page"]
+            }
+
+            record_texts = text_splitter.split_text(doc.page_content)
+
+            record_metadatas = [{
+                "chunk": j, "text": text, **metadata
+            } for j, text in enumerate(record_texts)]
+            texts.extend(record_texts)
+            metadatas.extend(record_metadatas)
+            if len(texts) >= BATCH_LIMIT:
+                ids = [str(uuid4()) for _ in range(len(texts))]
+                embeds = embedder.embed_documents(texts)
+                index.upsert(vectors=zip(ids, embeds, metadatas))
+                texts = []
+                metadatas = []
+
+    if len(texts) > 0:
+        ids = [str(uuid4()) for _ in range(len(texts))]
+        embeds = embedder.embed_documents(texts)
+        index.upsert(vectors=zip(ids, embeds, metadatas))
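
Note (not part of the diff above): index_documents expects a list of lists of Documents, one inner list per loaded PDF, matching how app.py appends loader.load() results. A minimal usage sketch, assuming the same environment variables, an existing 'arxiv-paper-index' GRPC index, and a placeholder paper URL:

import os
import pinecone
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.storage import LocalFileStore
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from utils.store import index_documents

pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment=os.environ["PINECONE_ENV"])
index = pinecone.GRPCIndex("arxiv-paper-index")

# embedder built the same way as in app.py
core = OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"])
embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings=core,
    document_embedding_cache=LocalFileStore("./cache/"),
    namespace=core.model,
)

# one inner list per loaded PDF, matching app.py's docs.append(loader.load())
docs = [PyPDFLoader("https://arxiv.org/pdf/2005.11401").load()]  # placeholder URL
splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=30, length_function=len)

index_documents(docs, splitter, embedder, index)
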