srinidhidevaraj committed
Commit 99fde54
1 Parent(s): 0760b60

Update app.py

Files changed (1)
  1. app.py +147 -158
app.py CHANGED
@@ -1,158 +1,147 @@
import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain.vectorstores.cassandra import Cassandra
from langchain_community.vectorstores import Cassandra
from langchain_community.llms import Ollama
from cassandra.auth import PlainTextAuthProvider
import tempfile
import cassio
from PyPDF2 import PdfReader
from cassandra.cluster import Cluster
import warnings
warnings.filterwarnings("ignore")

from dotenv import load_dotenv
import time
load_dotenv()

- ASTRA_DB_SECURE_BUNDLE_PATH ='G:/GENAI/groq_astra/secure-connect-pdf-query-db.zip'
- groq_api_key=os.environ['GROQ_API_KEY']
- os.environ["LANGCHAIN_TRACING_V2"]="true"
- os.environ["LANGCHAIN_API_KEY"]="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"
- os.environ["LANGCHAIN_PROJECT"]="pt-only-pupil-70"
- os.environ["LANGCHAIN_ENDPOINT"]="https://api.smith.langchain.com"
- ASTRA_DB_APPLICATION_TOKEN="AstraCS:SuHeqXWZDTGfvwliFFyCnCvM:29d8b2ec4888d271b8aa32b3675a20c050280680f2a95873fa33d265c889ae0d"
- ASTRA_DB_ID=os.getenv("ASTRA_DB_ID")
- ASTRA_DB_KEYSPACE="pdf_query_db"
- ASTRA_DB_API_ENDPOINT="https://68dfd628-1ad7-4951-ae84-45402a193c81-us-east1.apps.astra.datastax.com"
- ASTRA_DB_CLIENT_ID="SuHeqXWZDTGfvwliFFyCnCvM"
- ASTRA_DB_CLIENT_SECRET="JNZsN-R156.BfMJ+B4M4XvFMWNtQvxW2QZiR4kjTnPHdy9bcszr3UA-ZK7X_c_P20cKajX1_CeodPuQwJZvfWDfRfY_sEFCGdrYc2pobxoOX7UQ4p5.kIf1.oraLa-p"
- ASTRA_DB_TABLE='qa_mini_demo'
+ ASTRA_DB_SECURE_BUNDLE_PATH ='secure-connect-pdf-query-db.zip'
+
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN,database_id=ASTRA_DB_ID,secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH)

cloud_config = {
    'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH
}

def doc_loader(pdf_reader):

    encode_kwargs = {'normalize_embeddings': True}
    huggigface_embeddings=HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        # model_name='sentence-transformers/all-MiniLM-16-v2',
        model_kwargs={'device':'cpu'},
        encode_kwargs=encode_kwargs)


    loader=PyPDFLoader(pdf_reader)
    documents=loader.load_and_split()


    text_splitter=RecursiveCharacterTextSplitter(chunk_size=1000,chunk_overlap=200)
    final_documents=text_splitter.split_documents(documents)

    astrasession = Cluster(
        cloud={"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH},
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    ).connect()


    # Truncate the existing table
    astrasession.execute(f'TRUNCATE {ASTRA_DB_KEYSPACE}.{ASTRA_DB_TABLE}')

    astra_vector_store=Cassandra(
        embedding=huggigface_embeddings,
        table_name="qa_mini_demo",
        session=astrasession,
        keyspace=ASTRA_DB_KEYSPACE
    )


    astra_vector_store.add_documents(final_documents)

    return astra_vector_store

def prompt_temp():
    prompt=ChatPromptTemplate.from_template(
        """
        Answer the question based on provided context only.
        Your context retrieval mechanism works correclty but your are not providing answer from context.
        Please provide the most accurate response based on question.
        {context},
        Questions:{input}
        """
    )

    return prompt

def generate_response(llm,prompt,user_input,vectorstore):


    document_chain=create_stuff_documents_chain(llm,prompt)
    retriever=vectorstore.as_retriever(search_type="similarity",search_kwargs={"k":5})
    retrieval_chain=create_retrieval_chain(retriever,document_chain)
    response=retrieval_chain.invoke({"input":user_input})

    return response
    # ['answer']



def main():
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')
    user_input=st.text_input('Enter the Prompt here')
    file=st.file_uploader('Choose Invoice File',type='pdf')


    submit = st.button("Submit")
    st.session_state.submit_clicked = False
    if submit :
        st.session_state.submit_clicked = True
        if user_input and file:
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(file.getbuffer())
                file_path = temp_file.name
            # with open(file.name, mode='wb') as w:
            # # w.write(file.getvalue())
            # w.write(file.getbuffer())
            llm=ChatGroq(groq_api_key=groq_api_key,model_name="gemma-7b-it")
            prompt=prompt_temp()

            vectorstore=doc_loader(file_path)


            response=generate_response(llm,prompt,user_input,vectorstore)
            st.write(response['answer'])

            with st.expander("Document Similarity Search"):
                for i,doc in enumerate(response['context']):
                    st.write(doc.page_content)
                    st.write('---------------------------------')



if __name__=="__main__":
    main()
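Note on configuration after this commit: the updated app.py no longer hardcodes the Groq key, LangSmith settings, or the Astra DB token, ID, keyspace, and client secret, yet the remaining code still references groq_api_key, ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_ID, ASTRA_DB_KEYSPACE, and ASTRA_DB_TABLE. Since the script already calls load_dotenv(), these values are presumably meant to come from a local .env file or the Space's secrets rather than from source. The sketch below shows that assumed setup only; the variable names are taken from the old file, and the .env entries and defaults are placeholders, not actual values.

# .env (placeholders, not real credentials)
# GROQ_API_KEY=<your_groq_api_key>
# ASTRA_DB_APPLICATION_TOKEN=<your_astra_token>
# ASTRA_DB_ID=<your_astra_database_id>

import os
from dotenv import load_dotenv

load_dotenv()  # read the .env entries into the process environment

ASTRA_DB_SECURE_BUNDLE_PATH = 'secure-connect-pdf-query-db.zip'
groq_api_key = os.environ['GROQ_API_KEY']                            # raises KeyError if unset
ASTRA_DB_APPLICATION_TOKEN = os.getenv('ASTRA_DB_APPLICATION_TOKEN')
ASTRA_DB_ID = os.getenv('ASTRA_DB_ID')
ASTRA_DB_KEYSPACE = os.getenv('ASTRA_DB_KEYSPACE', 'pdf_query_db')   # default taken from the old file
ASTRA_DB_TABLE = os.getenv('ASTRA_DB_TABLE', 'qa_mini_demo')         # default taken from the old file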