Syed Junaid Iqbal committed on
Commit
787200f
β€’
1 Parent(s): a4ad0a9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.llms import LlamaCpp
3
+ from langchain.memory import ConversationBufferMemory
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.embeddings import FastEmbedEmbeddings
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.callbacks.manager import CallbackManager
8
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
9
+ from langchain import hub
10
+
11
def init_retriever():
    """Build and return the RetrievalQA chain used to answer user questions.

    Wires together a local Llama 2 (GGUF) model, a persisted Chroma vector
    store with FastEmbed embeddings, and a community RAG prompt, then attaches
    a conversation buffer so the chain can be reset between sessions.
    """
    # Stream generated tokens to stdout as they are produced.
    manager = CallbackManager([StreamingStdOutCallbackHandler()])

    model = LlamaCpp(
        model_path="./models/llama-2-13b-chat.Q4_K_S.gguf",
        n_ctx=4000,
        max_tokens=4000,
        f16_kv=True,
        callback_manager=manager,
        verbose=True,
    )

    # Embedding model is cached locally; the vector DB is persisted on disk.
    embedder = FastEmbedEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        cache_dir="./embedding_model/",
    )
    vector_store = Chroma(persist_directory="./vectordb/", embedding_function=embedder)

    # Community-maintained RAG prompt formatted for Llama-style chat models.
    llama_rag_prompt = hub.pull("rlm/rag-prompt-llama")

    chain = RetrievalQA.from_chain_type(
        model,
        retriever=vector_store.as_retriever(),
        chain_type_kwargs={"prompt": llama_rag_prompt},
    )
    chain.callback_manager = manager
    chain.memory = ConversationBufferMemory()

    return chain
34
+
35
# Building the chain (model load, vector store open) is expensive, so do it
# once per browser session and cache the result in Streamlit's session state.
if "retriever" not in st.session_state:
    st.session_state["retriever"] = init_retriever()
38
+
39
# Function to apply rounded edges using CSS
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Display *image_path* with rounded corners.

    Parameters:
        image_path: path of the image file to render (defaults to the
            Randstad featured image shipped alongside the app).
        radius: corner radius in CSS pixels.
    """
    # Bug fix: the previous version declared a `.rounded-img` CSS class that
    # was never attached to any element (st.image provides no class hook), so
    # the rounding silently never applied. Style Streamlit's rendered image
    # element directly instead.
    # NOTE(review): the selector relies on Streamlit's `data-testid="stImage"`
    # wrapper — confirm it matches the Streamlit version pinned for this app.
    st.markdown(
        f'<style>[data-testid="stImage"] img {{border-radius: {radius}px;}}</style>',
        unsafe_allow_html=True,
    )
    # `use_column_width` is deprecated in newer Streamlit releases in favour of
    # `use_container_width`; kept as-is since the pinned version is unknown.
    st.image(image_path, use_column_width=True, output_format='auto')
46
+
47
# add side bar
with st.sidebar:
    # add Randstad logo
    add_rounded_edges()

# Main page header.
st.title("💬 HR Chatbot")
st.caption("🚀 A chatbot powered by Local LLM")
54
+
55
# Flags whether the user asked to wipe the conversation during this rerun;
# the chat-input handler below consults it to reset the chain's memory.
clear = False

# Button that empties the visible transcript and raises the clear flag.
if st.button("Clear Chat History"):
    st.session_state.messages = []
    clear = True

# Seed the transcript with a greeting the first time the page loads.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the stored transcript so the conversation survives reruns.
for message in st.session_state.messages:
    st.chat_message(message["role"]).write(message["content"])
67
+
68
# Handle a new user question: echo it, run the RAG chain, show the answer.
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    chain = st.session_state.retriever
    if clear:
        # Bug fix: RetrievalQA has no `clean()` method — the old call raised
        # AttributeError. The ConversationBufferMemory attached in
        # init_retriever is reset through its `clear()` method.
        chain.memory.clear()

    # Bug fix: RetrievalQA.run expects a single query string, not the full
    # list of role/content dicts (the retriever cannot embed a list of
    # dicts). Prior turns are retained by the chain's conversation memory.
    msg = chain.run(prompt)

    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)
77
+