Spaces:
Sleeping
Sleeping
Sandaruth
commited on
Commit
•
2ffda8f
1
Parent(s):
870ee5f
update model
Browse files- Retrieval.py +34 -0
- app.py +19 -15
- model.py +1 -11
Retrieval.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from model import llm, vectorstore, splitter, embedding, QA_PROMPT
|
3 |
+
|
4 |
+
|
5 |
+
# Chain for Web
|
6 |
+
from langchain.chains import RetrievalQA
|
7 |
+
|
8 |
+
bsic_chain = RetrievalQA.from_chain_type(
|
9 |
+
llm=llm,
|
10 |
+
chain_type="stuff",
|
11 |
+
retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
|
12 |
+
return_source_documents= True,
|
13 |
+
input_key="question",
|
14 |
+
chain_type_kwargs={"prompt": QA_PROMPT},
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
20 |
+
# from kk import MultiQueryRetriever
|
21 |
+
|
22 |
+
retriever_from_llm = MultiQueryRetriever.from_llm(
|
23 |
+
retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
|
24 |
+
llm=llm,
|
25 |
+
)
|
26 |
+
|
27 |
+
multiQuery_chain = RetrievalQA.from_chain_type(
|
28 |
+
llm=llm,
|
29 |
+
chain_type="stuff",
|
30 |
+
retriever = retriever_from_llm,
|
31 |
+
return_source_documents= True,
|
32 |
+
input_key="question",
|
33 |
+
chain_type_kwargs={"prompt": QA_PROMPT},
|
34 |
+
)
|
app.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
-
from
|
3 |
import time
|
4 |
|
5 |
from htmlTemplates import css, bot_template, user_template, source_template
|
6 |
|
7 |
-
st.set_page_config(page_title="Chat with ATrad",page_icon=":currency_exchange:")
|
8 |
st.write(css, unsafe_allow_html=True)
|
9 |
|
10 |
def main():
|
@@ -17,8 +17,11 @@ def main():
|
|
17 |
4. Source documents will be displayed in the sidebar.
|
18 |
""")
|
19 |
|
|
|
|
|
|
|
|
|
20 |
# Button to connect to Google link ------------------------------------------------
|
21 |
-
|
22 |
st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
|
23 |
'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
|
24 |
'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
|
@@ -27,8 +30,8 @@ def main():
|
|
27 |
st.title("ATrad Chat App")
|
28 |
|
29 |
# Chat area -----------------------------------------------------------------------
|
30 |
-
|
31 |
user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
|
|
|
32 |
# JavaScript code to submit the form on Enter key press
|
33 |
js_submit = f"""
|
34 |
document.addEventListener("keydown", function(event) {{
|
@@ -38,31 +41,32 @@ def main():
|
|
38 |
}});
|
39 |
"""
|
40 |
st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
|
|
|
41 |
if st.button("Send"):
|
42 |
if user_input:
|
43 |
-
|
44 |
with st.spinner('Waiting for response...'):
|
45 |
-
|
46 |
# Add bot response here (you can replace this with your bot logic)
|
47 |
-
response, metadata, source_documents = generate_bot_response(user_input)
|
48 |
-
st.write(user_template.replace(
|
49 |
-
"{{MSG}}",
|
50 |
-
st.write(bot_template.replace(
|
51 |
-
"{{MSG}}", response ), unsafe_allow_html=True)
|
52 |
|
53 |
# Source documents
|
54 |
-
print("metadata", metadata)
|
55 |
st.sidebar.title("Source Documents")
|
56 |
for i, doc in enumerate(source_documents, 1):
|
57 |
-
tit=metadata[i-1]["source"].split("\\")[-1]
|
58 |
with st.sidebar.expander(f"{tit}"):
|
59 |
st.write(doc) # Assuming the Document object can be directly written to display its content
|
60 |
|
61 |
-
def generate_bot_response(user_input):
|
62 |
# Simple bot logic (replace with your actual bot logic)
|
63 |
start_time = time.time()
|
64 |
print(f"User Input: {user_input}")
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
response = res['result']
|
67 |
metadata = [i.metadata for i in res.get("source_documents", [])]
|
68 |
end_time = time.time()
|
|
|
1 |
import streamlit as st
|
2 |
+
from Retrieval import bsic_chain, multiQuery_chain
|
3 |
import time
|
4 |
|
5 |
from htmlTemplates import css, bot_template, user_template, source_template
|
6 |
|
7 |
+
st.set_page_config(page_title="Chat with ATrad", page_icon=":currency_exchange:")
|
8 |
st.write(css, unsafe_allow_html=True)
|
9 |
|
10 |
def main():
|
|
|
17 |
4. Source documents will be displayed in the sidebar.
|
18 |
""")
|
19 |
|
20 |
+
# Dropdown to select model --------------------------------------------------------
|
21 |
+
model_selection = st.sidebar.selectbox("Select Model", ["Basic", "MultiQuery"])
|
22 |
+
print(model_selection)
|
23 |
+
|
24 |
# Button to connect to Google link ------------------------------------------------
|
|
|
25 |
st.sidebar.markdown('<a href="https://drive.google.com/drive/folders/13v6LsaYH9wEwvqVtlLG1U4OiUHgZ7hY4?usp=sharing" target="_blank" style="display: inline-block;'
|
26 |
'background-color: #475063; color: white; padding: 10px 20px; text-align: center;border: 1px solid white;'
|
27 |
'text-decoration: none; cursor: pointer; border-radius: 5px;">Sources</a>',
|
|
|
30 |
st.title("ATrad Chat App")
|
31 |
|
32 |
# Chat area -----------------------------------------------------------------------
|
|
|
33 |
user_input = st.text_input("", key="user_input",placeholder="Type your question here...")
|
34 |
+
|
35 |
# JavaScript code to submit the form on Enter key press
|
36 |
js_submit = f"""
|
37 |
document.addEventListener("keydown", function(event) {{
|
|
|
41 |
}});
|
42 |
"""
|
43 |
st.markdown(f'<script>{js_submit}</script>', unsafe_allow_html=True)
|
44 |
+
|
45 |
if st.button("Send"):
|
46 |
if user_input:
|
|
|
47 |
with st.spinner('Waiting for response...'):
|
|
|
48 |
# Add bot response here (you can replace this with your bot logic)
|
49 |
+
response, metadata, source_documents = generate_bot_response(user_input, model_selection)
|
50 |
+
st.write(user_template.replace("{{MSG}}", user_input), unsafe_allow_html=True)
|
51 |
+
st.write(bot_template.replace("{{MSG}}", response), unsafe_allow_html=True)
|
|
|
|
|
52 |
|
53 |
# Source documents
|
|
|
54 |
st.sidebar.title("Source Documents")
|
55 |
for i, doc in enumerate(source_documents, 1):
|
56 |
+
tit = metadata[i-1]["source"].split("\\")[-1]
|
57 |
with st.sidebar.expander(f"{tit}"):
|
58 |
st.write(doc) # Assuming the Document object can be directly written to display its content
|
59 |
|
60 |
+
def generate_bot_response(user_input, model):
|
61 |
# Simple bot logic (replace with your actual bot logic)
|
62 |
start_time = time.time()
|
63 |
print(f"User Input: {user_input}")
|
64 |
+
|
65 |
+
if model == "Basic":
|
66 |
+
res = bsic_chain(user_input)
|
67 |
+
elif model == "MultiQuery":
|
68 |
+
res = multiQuery_chain(user_input)
|
69 |
+
|
70 |
response = res['result']
|
71 |
metadata = [i.metadata for i in res.get("source_documents", [])]
|
72 |
end_time = time.time()
|
model.py
CHANGED
@@ -68,15 +68,5 @@ from langchain.prompts import PromptTemplate
|
|
68 |
QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
|
69 |
|
70 |
|
71 |
-
|
72 |
-
from langchain.chains import RetrievalQA
|
73 |
-
|
74 |
-
Web_qa = RetrievalQA.from_chain_type(
|
75 |
-
llm=llm,
|
76 |
-
chain_type="stuff",
|
77 |
-
retriever = vectorstore.as_retriever(search_kwargs={"k": 4}),
|
78 |
-
return_source_documents= True,
|
79 |
-
input_key="question",
|
80 |
-
chain_type_kwargs={"prompt": QA_PROMPT},
|
81 |
-
)
|
82 |
|
|
|
68 |
QA_PROMPT = PromptTemplate(input_variables=["context", "question"],template=qa_template_V2,)
|
69 |
|
70 |
|
71 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|