Spaces: Running
Commit 7b7f6c4 · Parent(s): 28ba981

rag app
- Dockerfile +11 -0
- README.md +4 -7
- app.py +83 -0
- chainlit.md +5 -0
- data/test.txt +1 -0
- requirements.txt +108 -0
- src/retrieval_lib.py +105 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
+FROM python:3.9
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME/app
+COPY --chown=user . $HOME/app
+COPY ./requirements.txt $HOME/app/requirements.txt
+RUN pip install -r requirements.txt
+COPY . .
+CMD ["chainlit", "run", "app.py", "--port", "7890"]
README.md
CHANGED
@@ -1,11 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: PDF RAG Demo
+emoji: π
+colorFrom: pink
+colorTo: yellow
 sdk: docker
 pinned: false
-license: openrail
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,83 @@
+# You can find this code for Chainlit python streaming here (https://docs.chainlit.io/concepts/streaming/python)
+
+# OpenAI Chat completion
+import os
+from openai import AsyncOpenAI  # importing openai for API usage
+import chainlit as cl  # importing chainlit for our app
+from chainlit.prompt import Prompt, PromptMessage  # importing prompt tools
+from chainlit.playground.providers import ChatOpenAI  # importing ChatOpenAI tools
+from dotenv import load_dotenv
+from src.retrieval_lib import initialize_index, load_pdf_to_text, split_text, load_text_to_index, query_index, create_answer_prompt, generate_answer
+
+load_dotenv()
+
+retriever = initialize_index()
+
+@cl.on_chat_start  # marks a function that will be executed at the start of a user session
+async def start_chat():
+    settings = {
+        "model": "gpt-3.5-turbo",
+        "temperature": 0,
+        "max_tokens": 500,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "presence_penalty": 0,
+    }
+    cl.user_session.set("settings", settings)
+
+
+@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
+async def main(message: cl.Message):
+    settings = cl.user_session.get("settings")
+
+    client = AsyncOpenAI()
+
+    print(message.content)
+
+    # prompt = Prompt(
+    #     provider=ChatOpenAI.id,
+    #     messages=[
+    #         PromptMessage(
+    #             role="system",
+    #             template=system_template,
+    #             formatted=system_template,
+    #         ),
+    #         PromptMessage(
+    #             role="user",
+    #             template=user_template,
+    #             formatted=user_template.format(input=message.content),
+    #         ),
+    #     ],
+    #     inputs={"input": message.content},
+    #     settings=settings,
+    # )
+
+    # print([m.to_openai() for m in prompt.messages])
+
+    query = message.content
+    # query = "what is the reason for the lawsuit"
+    retrieved_docs = query_index(retriever, query)
+    print("retrieved_docs: \n", len(retrieved_docs))
+    answer_prompt = create_answer_prompt()
+    print("answer_prompt: \n", answer_prompt)
+    result = generate_answer(retriever, answer_prompt, query)
+    print("result: \n", result["response"].content)
+
+    msg = cl.Message(content="")
+
+    # Call OpenAI
+    # async for stream_resp in await client.chat.completions.create(
+    #     messages=[m.to_openai() for m in prompt.messages], stream=True, **settings
+    # ):
+    #     token = stream_resp.choices[0].delta.content
+    #     if not token:
+    #         token = ""
+    #     await msg.stream_token(token)
+
+    # Update the prompt object with the completion
+    # prompt.completion = msg.content
+    # msg.prompt = prompt
+    msg.content = result["response"].content
+
+    # Send and close the message stream
+    await msg.send()
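
Note: the commented-out blocks in app.py are the Chainlit streaming template. A minimal streaming variant of the answer step might look like the sketch below; this is an assumption, not part of the commit, and it streams a raw OpenAI completion rather than the retrieval chain's response. The handler name main_streaming is hypothetical.

# Hypothetical sketch: stream tokens into the Chainlit message as they arrive,
# instead of sending one final blob. Assumes "model" is present in settings.
import chainlit as cl
from openai import AsyncOpenAI

@cl.on_message
async def main_streaming(message: cl.Message):
    client = AsyncOpenAI()
    settings = cl.user_session.get("settings")
    msg = cl.Message(content="")

    stream = await client.chat.completions.create(
        messages=[{"role": "user", "content": message.content}],
        stream=True,
        **settings,
    )
    async for chunk in stream:
        token = chunk.choices[0].delta.content or ""  # delta may be None
        await msg.stream_token(token)  # push each token to the UI

    await msg.send()  # close the message stream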
chainlit.md
ADDED
@@ -0,0 +1,5 @@
+# PDF RAG
+
+RAG over a PDF document: Musk vs Altman/OpenAI (https://www.courthousenews.com/wp-content/uploads/2024/02/musk-v-altman-openai-complaint-sf.pdf)
+
+Disclaimer: this runs the query over the PDF document and generates answers using an LLM. LLMs can hallucinate and can generate wrong answers.
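
The flow chainlit.md describes mirrors main() in src/retrieval_lib.py; a minimal sketch of exercising it directly, outside Chainlit, using only helpers this commit adds:

# Sketch (not part of the commit): run the described query flow directly.
from src.retrieval_lib import initialize_index, create_answer_prompt, generate_answer

retriever = initialize_index()     # build the FAISS index over the complaint PDF
prompt = create_answer_prompt()    # context + question answer template
result = generate_answer(retriever, prompt, "what is the reason for the lawsuit")
print(result["response"].content)  # grounded answer, or "I don't know"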
data/test.txt
ADDED
@@ -0,0 +1 @@
+i
requirements.txt
ADDED
@@ -0,0 +1,108 @@
+aiofiles==23.2.1
+aiohttp==3.9.3
+aiosignal==1.3.1
+annotated-types==0.6.0
+anyio==3.7.1
+appdirs==1.4.4
+async-timeout==4.0.3
+asyncer==0.0.2
+attrs==23.2.0
+bidict==0.23.1
+certifi==2024.2.2
+chainlit==0.7.700
+charset-normalizer==3.3.2
+click==8.1.7
+dataclasses-json==0.5.14
+datasets==2.18.0
+Deprecated==1.2.14
+dill==0.3.8
+distro==1.9.0
+exceptiongroup==1.2.0
+faiss-cpu==1.8.0
+fastapi==0.100.1
+fastapi-socketio==0.0.10
+filelock==3.13.1
+filetype==1.2.0
+frozenlist==1.4.1
+fsspec==2024.2.0
+googleapis-common-protos==1.62.0
+grpcio==1.62.1
+h11==0.14.0
+httpcore==0.17.3
+httpx==0.24.1
+huggingface-hub==0.21.4
+idna==3.6
+importlib-metadata==6.11.0
+jsonpatch==1.33
+jsonpointer==2.4
+langchain==0.1.11
+langchain-community==0.0.27
+langchain-core==0.1.30
+langchain-openai==0.0.8
+langchain-text-splitters==0.0.1
+langchainhub==0.1.15
+langsmith==0.1.23
+Lazify==0.4.0
+marshmallow==3.21.1
+multidict==6.0.5
+multiprocess==0.70.16
+mypy-extensions==1.0.0
+nest-asyncio==1.6.0
+numpy==1.26.4
+openai==1.13.3
+opentelemetry-api==1.23.0
+opentelemetry-exporter-otlp==1.23.0
+opentelemetry-exporter-otlp-proto-common==1.23.0
+opentelemetry-exporter-otlp-proto-grpc==1.23.0
+opentelemetry-exporter-otlp-proto-http==1.23.0
+opentelemetry-instrumentation==0.44b0
+opentelemetry-proto==1.23.0
+opentelemetry-sdk==1.23.0
+opentelemetry-semantic-conventions==0.44b0
+orjson==3.9.15
+packaging==23.2
+pandas==2.2.1
+protobuf==4.25.3
+pyarrow==15.0.1
+pyarrow-hotfix==0.6
+pydantic==2.6.3
+pydantic_core==2.16.3
+PyJWT==2.8.0
+PyMuPDF==1.23.26
+PyMuPDFb==1.23.22
+pysbd==0.3.4
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-engineio==4.9.0
+python-graphql-client==0.4.3
+python-multipart==0.0.6
+python-socketio==5.11.1
+pytz==2024.1
+PyYAML==6.0.1
+ragas==0.1.3
+regex==2023.12.25
+requests==2.31.0
+simple-websocket==1.0.0
+six==1.16.0
+sniffio==1.3.1
+SQLAlchemy==2.0.28
+starlette==0.27.0
+syncer==2.0.3
+tenacity==8.2.3
+tiktoken==0.6.0
+tomli==2.0.1
+tqdm==4.66.2
+types-requests==2.31.0.20240311
+typing-inspect==0.9.0
+typing_extensions==4.10.0
+tzdata==2024.1
+uptrace==1.22.0
+urllib3==2.2.1
+uvicorn==0.23.2
+watchfiles==0.20.0
+websockets==12.0
+wrapt==1.16.0
+wsproto==1.2.0
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.17.0
src/retrieval_lib.py
ADDED
@@ -0,0 +1,105 @@
+
+# import libraries
+import os
+import openai
+from langchain_community.document_loaders import PyMuPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain.prompts import ChatPromptTemplate
+from operator import itemgetter
+from langchain_openai import ChatOpenAI
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+
+
+LLM_MODEL_NAME = "gpt-3.5-turbo"
+
+
+# load PDF doc and convert to text
+def load_pdf_to_text(pdf_path):
+    # create a document loader
+    loader = PyMuPDFLoader(pdf_path)
+    # load the document
+    doc = loader.load()
+    return doc
+
+def split_text(text):
+    # create a text splitter
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=700,
+        chunk_overlap=100,
+    )
+    # split the text
+    split_text = splitter.split_documents(text)
+    return split_text
+
+# load text into FAISS index
+def load_text_to_index(doc_splits):
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-small"
+    )
+    vector_store = FAISS.from_documents(doc_splits, embeddings)
+    retriever = vector_store.as_retriever()
+    return retriever
+
+# query FAISS index
+def query_index(retriever, query):
+    retrieved_docs = retriever.invoke(query)
+    return retrieved_docs
+
+# create answer prompt
+def create_answer_prompt():
+    template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':
+
+Context:
+{context}
+
+Question:
+{question}
+"""
+    print("template: ", len(template))
+    prompt = ChatPromptTemplate.from_template(template)
+    return prompt
+
+# generate answer
+def generate_answer(retriever, answer_prompt, query):
+    print("generate_answer()")
+    QnA_LLM = ChatOpenAI(model_name=LLM_MODEL_NAME, temperature=0.0)
+
+    retrieval_qna_chain = (
+        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
+        | RunnablePassthrough.assign(context=itemgetter("context"))
+        | {"response": answer_prompt | QnA_LLM, "context": itemgetter("context")}
+    )
+    result = retrieval_qna_chain.invoke({"question": query})
+    return result
+
+def initialize_index():
+    # load pdf
+    cwd = os.path.abspath(os.getcwd())
+    data_dir = "data"
+    #pdf_file = "nvidia_earnings_report.pdf"
+    pdf_file = "musk-v-altman-openai-complaint-sf.pdf"
+    pdf_path = os.path.join(cwd, data_dir, pdf_file)
+    print("path: ", pdf_path)
+    doc = load_pdf_to_text(pdf_path)
+    print("doc: \n", len(doc))
+    doc_splits = split_text(doc)
+    print("doc_splits length: \n", len(doc_splits))
+    retriever = load_text_to_index(doc_splits)
+    return retriever
+
+def main():
+    retriever = initialize_index()
+    # query = "Who is the E-VP, Operations"
+    query = "what is the reason for the lawsuit"
+    retrieved_docs = query_index(retriever, query)
+    print("retrieved_docs: \n", len(retrieved_docs))
+    answer_prompt = create_answer_prompt()
+    print("answer_prompt: \n", answer_prompt)
+    result = generate_answer(retriever, answer_prompt, query)
+    print("result: \n", result["response"].content)
+
+if __name__ == "__main__":
+    main()
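
A note on the chain in generate_answer(): the leading dict fans the input out in parallel, and itemgetter("question") | retriever coerces the plain callable into a runnable that routes the question to the retriever. A minimal sketch isolating that pattern, where fake_retriever is a hypothetical stand-in so the example runs without a FAISS index or API key:

# Minimal sketch of the LCEL fan-out pattern used in generate_answer().
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnableParallel

# hypothetical stand-in for the FAISS retriever
fake_retriever = RunnableLambda(lambda q: f"<docs retrieved for: {q!r}>")

chain = RunnableParallel(
    context=itemgetter("question") | fake_retriever,  # route the question to retrieval
    question=itemgetter("question"),                  # pass the question through unchanged
)

print(chain.invoke({"question": "what is the reason for the lawsuit"}))
# -> {'context': "<docs retrieved for: 'what is the reason for the lawsuit'>",
#     'question': 'what is the reason for the lawsuit'}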