Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- .gitattributes +1 -0
- app.py +93 -0
- cima_faiss_index/index.faiss +3 -0
- cima_faiss_index/index.pkl +3 -0
- config.py +15 -0
- docs_data.pkl +3 -0
- image.jpg +0 -0
- markup.py +32 -0
- memory.py +4 -0
- query_data.py +125 -0
- requirements.txt +11 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
cima_faiss_index/index.faiss filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gdown
|
3 |
+
|
4 |
+
file_id = "1G_OK4BPWgqpgEK540VReSKhzOxKspjgy"
|
5 |
+
output_folder = "faiss_index" # Specify the folder name
|
6 |
+
output_file = os.path.join(output_folder, "index.faiss")
|
7 |
+
|
8 |
+
if not os.path.exists(output_file):
|
9 |
+
os.makedirs(output_folder, exist_ok=True) # Create the folder if it doesn't exist
|
10 |
+
url = f"https://drive.google.com/uc?id={file_id}"
|
11 |
+
gdown.download(url, output_file, quiet=False)
|
12 |
+
|
13 |
+
import streamlit as st
|
14 |
+
from streamlit_option_menu import option_menu
|
15 |
+
from markup import app_intro
|
16 |
+
import langchain
|
17 |
+
from langchain.cache import InMemoryCache
|
18 |
+
from query_data import chat_chain
|
19 |
+
|
20 |
+
langchain.llm_cache = InMemoryCache()
|
21 |
+
|
22 |
+
def tab1():
|
23 |
+
st.header("CIMA Chatbot")
|
24 |
+
col1, col2 = st.columns([1, 2])
|
25 |
+
with col1:
|
26 |
+
st.image("image.jpg", use_column_width=True)
|
27 |
+
with col2:
|
28 |
+
st.markdown(app_intro(), unsafe_allow_html=True)
|
29 |
+
|
30 |
+
|
31 |
+
metadata_list = []
|
32 |
+
unique_metadata_list = []
|
33 |
+
seen = set()
|
34 |
+
|
35 |
+
def tab4():
|
36 |
+
if "messages" not in st.session_state:
|
37 |
+
st.session_state.messages = []
|
38 |
+
|
39 |
+
st.header("🗣️ Chat with the AI about the ingested documents! 📚")
|
40 |
+
|
41 |
+
for message in st.session_state.messages:
|
42 |
+
with st.chat_message(message["role"]):
|
43 |
+
st.markdown(message["content"])
|
44 |
+
|
45 |
+
if user_input := st.chat_input("User Input"):
|
46 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
47 |
+
|
48 |
+
with st.chat_message("user"):
|
49 |
+
st.markdown(user_input)
|
50 |
+
|
51 |
+
with st.spinner("Generating Response..."):
|
52 |
+
|
53 |
+
with st.chat_message("assistant"):
|
54 |
+
response = chat_chain({"question": user_input})
|
55 |
+
|
56 |
+
answer = response['answer']
|
57 |
+
source_documents = response['source_documents']
|
58 |
+
|
59 |
+
for doc in source_documents:
|
60 |
+
if hasattr(doc, 'metadata'):
|
61 |
+
metadata = doc.metadata
|
62 |
+
metadata_list.append(metadata)
|
63 |
+
|
64 |
+
for metadata in metadata_list:
|
65 |
+
metadata_tuple = tuple(metadata.items())
|
66 |
+
if metadata_tuple not in seen:
|
67 |
+
unique_metadata_list.append(metadata)
|
68 |
+
seen.add(metadata_tuple)
|
69 |
+
|
70 |
+
st.write(answer)
|
71 |
+
st.write(unique_metadata_list)
|
72 |
+
|
73 |
+
st.session_state.messages.append({"role": "assistant", "content": answer})
|
74 |
+
|
75 |
+
|
76 |
+
def main():
|
77 |
+
st.set_page_config(page_title="CIMA Chat", page_icon=":memo:", layout="wide")
|
78 |
+
tabs = ["Intro", "Chat"]
|
79 |
+
|
80 |
+
with st.sidebar:
|
81 |
+
|
82 |
+
current_tab = option_menu("Select a Tab", tabs, menu_icon="cast")
|
83 |
+
|
84 |
+
tab_functions = {
|
85 |
+
"Intro": tab1,
|
86 |
+
"Chat": tab4,
|
87 |
+
}
|
88 |
+
|
89 |
+
if current_tab in tab_functions:
|
90 |
+
tab_functions[current_tab]()
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
main()
|
cima_faiss_index/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:635988cbd5841d7a301b35ce3230ac6f0b7b4444fc314e46d330a10144a1a90f
|
3 |
+
size 29380653
|
cima_faiss_index/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa2e5b877140b26b7c6eb0e03284184d441722d69dd5de0c43bdb8f01c778f25
|
3 |
+
size 8526672
|
config.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENAI_API_KEY = "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"
|
2 |
+
|
3 |
+
DEFAULT_QA_TEMPERATURE = 0
|
4 |
+
|
5 |
+
DEFAULT_CHAT_TEMPLATE = """<|system|> You are a helpful virtual assistant for CIMA textbooks related questions. </s>
|
6 |
+
|
7 |
+
your conversation history:
|
8 |
+
{chat_history} </s>
|
9 |
+
|
10 |
+
CIMA textbook content you can use for the answer if needed:
|
11 |
+
{context}</s>
|
12 |
+
|
13 |
+
<|user|>: {question}</s>
|
14 |
+
<|assistant|>:"""
|
15 |
+
|
docs_data.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:58e85d1679c567117937df49e7294968f0851ec52c3e6858f82cea45f0034a9f
|
3 |
+
size 7965757
|
image.jpg
ADDED
![]() |
markup.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def app_intro():
|
2 |
+
return """
|
3 |
+
<div style='text-align: left;'>
|
4 |
+
<h2 style='text-align: center;'>CIMA Textbooks Chatbot</h2>
|
5 |
+
<h3 style='text-align: center;'>Introduction</h3>
|
6 |
+
|
7 |
+
<p>Welcome to the CIMA Textbooks Chatbot! Our chatbot is designed to assist you with questions and information related to CIMA (Chartered Institute of Management Accountants) studies. Whether you need answers to standalone questions or want to evaluate sources, our chatbot is here to help.</p>
|
8 |
+
|
9 |
+
<h4>Chat Options:</h4>
|
10 |
+
<ul>
|
11 |
+
<li><b>QA Option:</b> Get answers to standalone questions and evaluate sources.</li>
|
12 |
+
<li><b>QA with Memory Option:</b> Has short-term memory to answer follow-up questions to a main question.</li>
|
13 |
+
<li><b>Chat Option:</b> Fully dynamic and conversational AI chatbot that can be easily customized with additional options.</li>
|
14 |
+
</ul>
|
15 |
+
|
16 |
+
<h4>Indexed PDFs for Context:</h4>
|
17 |
+
<ul>
|
18 |
+
<li>E1 Study Text 2019-20 NPE.pdf - Managing Finance in a digital World</li>
|
19 |
+
<li>E2 Study Text 2019-20 NPA.pdf - Managing Performance</li>
|
20 |
+
<li>E3 Study Text 2019-20.pdf - Strategic Management</li>
|
21 |
+
<li>F1 Study Text 2019-20 NPE_unlocked.pdf - Financial Reporting</li>
|
22 |
+
<li>F2 Study Text 2019-20 NPA.pdf - Advanced Financial Reporting</li>
|
23 |
+
<li>F3 Study Text 2019-20.pdf - Financial Strategy</li>
|
24 |
+
<li>P1 Study Text 2019-20 NPA-unlocked.pdf - Management Accounting</li>
|
25 |
+
<li>P2 Study Text 2019-20 NPE.pdf - Advanced Management Accounting</li>
|
26 |
+
<li>P3 Study Text 2019-20 NPE.pdf - Risk Management</li>
|
27 |
+
</ul>
|
28 |
+
|
29 |
+
<p>This application utilizes a hybrid search method, which retrieves relevant text from the documents, regardless of the large number of pages in each textbook. This method combines a state-of-the-art sparse retrieval algorithm, ideal for finding keywords, with a dense retrieval method, proficient at locating documents through semantic similarity.</p>
|
30 |
+
|
31 |
+
</div>
|
32 |
+
"""
|
memory.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.memory import ConversationBufferWindowMemory
|
2 |
+
|
3 |
+
memory3 = ConversationBufferWindowMemory(
|
4 |
+
k=2, memory_key='chat_history', return_messages=True, output_key='answer', human_prefix="<|user|>", ai_prefix="<|assistant|>")
|
query_data.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chat_models import ChatOpenAI
|
2 |
+
from langchain.chains import ConversationalRetrievalChain
|
3 |
+
from langchain.prompts import PromptTemplate
|
4 |
+
import pickle
|
5 |
+
import config
|
6 |
+
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
|
7 |
+
from memory import memory3
|
8 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
9 |
+
from langchain.vectorstores import FAISS
|
10 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
11 |
+
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
12 |
+
from langchain.document_transformers import EmbeddingsRedundantFilter
|
13 |
+
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
14 |
+
from langchain.text_splitter import CharacterTextSplitter
|
15 |
+
from pydantic import BaseModel, Field
|
16 |
+
from typing import Any, Optional, Dict, List
|
17 |
+
from huggingface_hub import InferenceClient
|
18 |
+
from langchain.llms.base import LLM
|
19 |
+
|
20 |
+
import os
|
21 |
+
os.environ["OPENAI_API_KEY"] = config.OPENAI_API_KEY
|
22 |
+
|
23 |
+
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
|
24 |
+
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
|
25 |
+
hf_token = "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"
|
26 |
+
kwargs = {"max_new_tokens":500, "temperature":0.9, "top_p":0.95, "repetition_penalty":1.0, "do_sample":True}
|
27 |
+
reform_kwargs = {"max_new_tokens":50, "temperature":0.5, "top_p":0.9, "repetition_penalty":1.0, "do_sample":True}
|
28 |
+
|
29 |
+
class KwArgsModel(BaseModel):
|
30 |
+
kwargs: Dict[str, Any] = Field(default_factory=dict)
|
31 |
+
|
32 |
+
class CustomInferenceClient(LLM, KwArgsModel):
|
33 |
+
model_name: str
|
34 |
+
inference_client: InferenceClient
|
35 |
+
|
36 |
+
def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
|
37 |
+
inference_client = InferenceClient(model=model_name, token=hf_token)
|
38 |
+
super().__init__(
|
39 |
+
model_name=model_name,
|
40 |
+
hf_token=hf_token,
|
41 |
+
kwargs=kwargs,
|
42 |
+
inference_client=inference_client
|
43 |
+
)
|
44 |
+
|
45 |
+
def _call(
|
46 |
+
self,
|
47 |
+
prompt: str,
|
48 |
+
stop: Optional[List[str]] = None
|
49 |
+
) -> str:
|
50 |
+
if stop is not None:
|
51 |
+
raise ValueError("stop kwargs are not permitted.")
|
52 |
+
response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True, return_full_text=False)
|
53 |
+
response = ''.join(response_gen)
|
54 |
+
return response
|
55 |
+
|
56 |
+
@property
|
57 |
+
def _llm_type(self) -> str:
|
58 |
+
return "custom"
|
59 |
+
|
60 |
+
@property
|
61 |
+
def _identifying_params(self) -> dict:
|
62 |
+
return {"model_name": self.model_name}
|
63 |
+
|
64 |
+
|
65 |
+
chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
|
66 |
+
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
prompt_template = config.DEFAULT_CHAT_TEMPLATE
|
71 |
+
|
72 |
+
PROMPT = PromptTemplate(
|
73 |
+
template=prompt_template, input_variables=["context", "question", "chat_history"]
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
chain_type_kwargs = {"prompt": PROMPT}
|
78 |
+
|
79 |
+
embeddings = OpenAIEmbeddings()
|
80 |
+
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
|
81 |
+
|
82 |
+
retriever=vectorstore.as_retriever(search_type="similarity", search_kwargs={"k":5})
|
83 |
+
|
84 |
+
|
85 |
+
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
|
86 |
+
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
|
87 |
+
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
|
88 |
+
pipeline_compressor = DocumentCompressorPipeline(
|
89 |
+
transformers=[splitter, redundant_filter, relevant_filter]
|
90 |
+
)
|
91 |
+
|
92 |
+
compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
|
93 |
+
|
94 |
+
with open("docs_data.pkl", "rb") as file:
|
95 |
+
docs = pickle.load(file)
|
96 |
+
|
97 |
+
bm25_retriever = BM25Retriever.from_texts(docs)
|
98 |
+
bm25_retriever.k = 2
|
99 |
+
|
100 |
+
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)
|
101 |
+
|
102 |
+
ensemble_retriever = EnsembleRetriever(retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
|
103 |
+
|
104 |
+
|
105 |
+
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged[/INST].
|
106 |
+
|
107 |
+
Chat History:
|
108 |
+
{chat_history}
|
109 |
+
|
110 |
+
Follow-up user message: {question}
|
111 |
+
Rewritten user message:"""
|
112 |
+
|
113 |
+
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)
|
114 |
+
|
115 |
+
|
116 |
+
chat_chain = ConversationalRetrievalChain.from_llm(llm=chat_llm,
|
117 |
+
chain_type="stuff",
|
118 |
+
retriever=ensemble_retriever,
|
119 |
+
combine_docs_chain_kwargs=chain_type_kwargs,
|
120 |
+
return_source_documents=True,
|
121 |
+
get_chat_history=lambda h : h,
|
122 |
+
condense_question_prompt=CUSTOM_QUESTION_PROMPT,
|
123 |
+
memory=memory3,
|
124 |
+
condense_question_llm = reform_llm
|
125 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
openai
|
3 |
+
streamlit_option_menu
|
4 |
+
pypdf
|
5 |
+
rank_bm25
|
6 |
+
faiss-cpu
|
7 |
+
tiktoken
|
8 |
+
scikit-learn
|
9 |
+
gdown
|
10 |
+
sentence_transformers
|
11 |
+
huggingface_hub
|