cd@bziiit.com commited on
Commit
0e34878
·
1 Parent(s): 8c27c79

feat : add stream, copy response, model display

Browse files
Files changed (3) hide show
  1. model/selector.py +1 -1
  2. pages/chatbot.py +43 -24
  3. rag.py +7 -2
model/selector.py CHANGED
@@ -38,5 +38,5 @@ def ModelSelector():
38
 
39
  if(st.session_state["assistant"]):
40
  splitter = model_mapping[selected_model_option].split(".")
41
- st.session_state["assistant"].setModel(ModelManager().get_model(splitter[0], splitter[1]))
42
 
 
38
 
39
  if(st.session_state["assistant"]):
40
  splitter = model_mapping[selected_model_option].split(".")
41
+ st.session_state["assistant"].setModel(ModelManager().get_model(splitter[0], splitter[1]), splitter[1])
42
 
pages/chatbot.py CHANGED
@@ -1,27 +1,39 @@
1
  import streamlit as st
2
- from streamlit_chat import message
3
  from model import selector
4
  from util import getYamlConfig
5
-
6
 
7
  def display_messages():
8
- for i, (msg, is_user) in enumerate(st.session_state["messages"]):
9
- message(msg, is_user=is_user, key=str(i))
10
- st.session_state["thinking_spinner"] = st.empty()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- def process_input():
14
- if "user_input" in st.session_state and st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0:
15
- user_text = st.session_state["user_input"].strip()
16
 
17
- prompt_sys = st.session_state.prompt_system if 'prompt_system' in st.session_state and st.session_state.prompt_system != '' else ""
18
-
19
- with st.session_state["thinking_spinner"], st.spinner(f"Je réfléchis"):
20
- agent_text = st.session_state["assistant"].ask(user_text, prompt_system=prompt_sys, messages=st.session_state["messages"] if "messages" in st.session_state else [], variables=st.session_state["data_dict"])
21
-
22
- st.session_state["messages"].append((user_text, True))
23
- st.session_state["messages"].append((agent_text, False))
24
- st.session_state["user_input"] = ""
 
 
 
 
 
 
25
 
26
 
27
  def show_prompts():
@@ -34,30 +46,37 @@ def show_prompts():
34
 
35
  for item in yaml_data[categroy]:
36
  if expander.button(item, key=f"button_{item}"):
37
- st.session_state["user_input"] = item
38
- process_input()
39
 
40
 
41
  def page():
42
  st.subheader("Posez vos questions")
43
 
44
- if "user_input" in st.session_state:
45
- process_input()
46
-
47
  if "assistant" not in st.session_state:
48
  st.text("Assistant non initialisé")
49
 
 
 
 
 
 
50
  # Collpase for default prompts
51
  show_prompts()
52
 
53
  # Models selector
54
  selector.ModelSelector()
55
-
56
  # Displaying messages
57
  display_messages()
58
 
59
- # Input user query
60
- st.text_input("Message", key="user_input", on_change=process_input)
 
 
 
 
 
 
61
 
62
 
63
  page()
 
1
  import streamlit as st
2
+ from langchain_core.messages import AIMessage, HumanMessage
3
  from model import selector
4
  from util import getYamlConfig
5
+ from st_copy_to_clipboard import st_copy_to_clipboard
6
 
7
  def display_messages():
 
 
 
8
 
9
+ for i, message in enumerate(st.session_state.chat_history):
10
+ if isinstance(message, AIMessage):
11
+ with st.chat_message("AI"):
12
+ # Display the model from the kwargs
13
+ model = message.kwargs.get("model", "Unknown Model") # Get the model, default to "Unknown Model"
14
+ st.write(f"**Model :** {model}")
15
+ st.markdown(message.content)
16
+ st_copy_to_clipboard(message.content,key=f"message_{i}")
17
+
18
+ elif isinstance(message, HumanMessage):
19
+ with st.chat_message("Moi"):
20
+ st.write(message.content)
21
 
 
 
 
22
 
23
+ def launchQuery(query: str = None):
24
+
25
+ # Initialize the assistant's response
26
+ full_response = st.write_stream(
27
+ st.session_state["assistant"].ask(
28
+ query,
29
+ prompt_system=st.session_state.prompt_system,
30
+ messages=st.session_state["chat_history"] if "chat_history" in st.session_state else [],
31
+ variables=st.session_state["data_dict"]
32
+ ))
33
+
34
+ # Temporary placeholder AI message in chat history
35
+ st.session_state["chat_history"].append(AIMessage(content=full_response, kwargs={"model": st.session_state["assistant"].getReadableModel()}))
36
+ st.rerun()
37
 
38
 
39
  def show_prompts():
 
46
 
47
  for item in yaml_data[categroy]:
48
  if expander.button(item, key=f"button_{item}"):
49
+ launchQuery(item)
 
50
 
51
 
52
  def page():
53
  st.subheader("Posez vos questions")
54
 
 
 
 
55
  if "assistant" not in st.session_state:
56
  st.text("Assistant non initialisé")
57
 
58
+ if "chat_history" not in st.session_state:
59
+ st.session_state["chat_history"] = []
60
+
61
+ st.markdown("<style>iframe{height:50px;}</style>", unsafe_allow_html=True)
62
+
63
  # Collpase for default prompts
64
  show_prompts()
65
 
66
  # Models selector
67
  selector.ModelSelector()
68
+
69
  # Displaying messages
70
  display_messages()
71
 
72
+
73
+ user_query = st.chat_input("")
74
+ if user_query is not None and user_query != "":
75
+
76
+ st.session_state["chat_history"].append(HumanMessage(content=user_query))
77
+
78
+ # Stream and display response
79
+ launchQuery(user_query)
80
 
81
 
82
  page()
rag.py CHANGED
@@ -23,6 +23,7 @@ class Rag:
23
  document_vector_store = None
24
  retriever = None
25
  chain = None
 
26
 
27
  def __init__(self, vectore_store=None):
28
 
@@ -36,9 +37,13 @@ class Rag:
36
 
37
  self.vector_store = vectore_store
38
 
39
- def setModel(self, model):
40
  self.model = model
 
41
 
 
 
 
42
  def ingestToDb(self, file_path: str, filename: str):
43
 
44
  docs = PyPDFLoader(file_path=file_path).load()
@@ -105,7 +110,7 @@ class Rag:
105
  chain_input.update(extra_vars)
106
 
107
 
108
- return self.chain.invoke(chain_input)
109
 
110
  def clear(self):
111
  self.document_vector_store = None
 
23
  document_vector_store = None
24
  retriever = None
25
  chain = None
26
+ readableModelName = ""
27
 
28
  def __init__(self, vectore_store=None):
29
 
 
37
 
38
  self.vector_store = vectore_store
39
 
40
+ def setModel(self, model, readableModelName = ""):
41
  self.model = model
42
+ self.readableModelName = readableModelName
43
 
44
+ def getReadableModel(self):
45
+ return self.readableModelName
46
+
47
  def ingestToDb(self, file_path: str, filename: str):
48
 
49
  docs = PyPDFLoader(file_path=file_path).load()
 
110
  chain_input.update(extra_vars)
111
 
112
 
113
+ return self.chain.stream(chain_input)
114
 
115
  def clear(self):
116
  self.document_vector_store = None