NEXAS commited on
Commit
8214bc3
·
verified ·
1 Parent(s): 56d7a00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -6,11 +6,16 @@ from src.utils.ingest_image import extract_and_store_images
6
  from src.utils.text_qa import qa_bot
7
  from src.utils.image_qa import query_and_print_results
8
  import nest_asyncio
 
 
 
9
  nest_asyncio.apply()
10
 
11
- from dotenv import load_dotenv
12
  load_dotenv()
13
 
 
 
 
14
  def get_answer(query, chain):
15
  try:
16
  response = chain.invoke(query)
@@ -23,17 +28,13 @@ st.title("MULTIMODAL DOC QA")
23
 
24
  uploaded_file = st.file_uploader("File upload", type="pdf")
25
  if uploaded_file is not None:
26
- # Save the uploaded file to a temporary location
27
  temp_file_path = os.path.join("temp", uploaded_file.name)
28
- os.makedirs("temp", exist_ok=True) # Ensure the temp directory exists
29
  with open(temp_file_path, "wb") as f:
30
  f.write(uploaded_file.getbuffer())
31
 
32
- # Get the absolute path of the saved file
33
- #temp_dir = tempfile.mkdtemp()
34
  path = os.path.abspath(temp_file_path)
35
  st.write(f"File saved to: {path}")
36
- print(path)
37
 
38
  st.write("Document uploaded successfully!")
39
 
@@ -44,8 +45,8 @@ if st.button("Start Processing"):
44
  client = create_vector_database(path)
45
  image_vdb = extract_and_store_images(path)
46
  chain = qa_bot(client)
47
- st.session_state['chain'] = chain # Store chain in session state
48
- st.session_state['image_vdb'] = image_vdb # Store image_vdb in session state
49
  st.success("Processing complete.")
50
  except Exception as e:
51
  st.error(f"Error during processing: {e}")
@@ -59,11 +60,15 @@ if user_input := st.chat_input("User Input"):
59
 
60
  with st.chat_message("user"):
61
  st.markdown(user_input)
 
62
 
63
  with st.spinner("Generating Response..."):
64
  response = get_answer(user_input, chain)
65
  if response:
66
  st.markdown(response)
 
 
 
67
  try:
68
  query_and_print_results(image_vdb, user_input)
69
  except Exception as e:
@@ -72,3 +77,14 @@ if user_input := st.chat_input("User Input"):
72
  st.error("Failed to generate response.")
73
  else:
74
  st.error("Please start processing before entering user input.")
 
 
 
 
 
 
 
 
 
 
 
 
6
  from src.utils.text_qa import qa_bot
7
  from src.utils.image_qa import query_and_print_results
8
  import nest_asyncio
9
+ from langchain.memory import ConversationBufferWindowMemory
10
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
11
+ from dotenv import load_dotenv
12
  nest_asyncio.apply()
13
 
 
14
  load_dotenv()
15
 
16
+ memory_storage = StreamlitChatMessageHistory(key="chat_messages")
17
+ memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
18
+
19
  def get_answer(query, chain):
20
  try:
21
  response = chain.invoke(query)
 
28
 
29
  uploaded_file = st.file_uploader("File upload", type="pdf")
30
  if uploaded_file is not None:
 
31
  temp_file_path = os.path.join("temp", uploaded_file.name)
32
+ os.makedirs("temp", exist_ok=True)
33
  with open(temp_file_path, "wb") as f:
34
  f.write(uploaded_file.getbuffer())
35
 
 
 
36
  path = os.path.abspath(temp_file_path)
37
  st.write(f"File saved to: {path}")
 
38
 
39
  st.write("Document uploaded successfully!")
40
 
 
45
  client = create_vector_database(path)
46
  image_vdb = extract_and_store_images(path)
47
  chain = qa_bot(client)
48
+ st.session_state['chain'] = chain
49
+ st.session_state['image_vdb'] = image_vdb
50
  st.success("Processing complete.")
51
  except Exception as e:
52
  st.error(f"Error during processing: {e}")
 
60
 
61
  with st.chat_message("user"):
62
  st.markdown(user_input)
63
+ memory.save_context({"role": "user", "content": user_input})
64
 
65
  with st.spinner("Generating Response..."):
66
  response = get_answer(user_input, chain)
67
  if response:
68
  st.markdown(response)
69
+ with st.chat_message("assistant"):
70
+ st.markdown(response)
71
+ memory.save_context({"role": "assistant", "content": response})
72
  try:
73
  query_and_print_results(image_vdb, user_input)
74
  except Exception as e:
 
77
  st.error("Failed to generate response.")
78
  else:
79
  st.error("Please start processing before entering user input.")
80
+
81
+ if "messages" not in st.session_state:
82
+ st.session_state.messages = []
83
+
84
+ for message in st.session_state.messages:
85
+ with st.chat_message(message["role"]):
86
+ st.write(message["content"])
87
+
88
+ for i, msg in enumerate(memory_storage.messages):
89
+ name = "user" if i % 2 == 0 else "assistant"
90
+ st.chat_message(name).markdown(msg.content)