Spaces:
Running
Running
Ilyas KHIAT
commited on
Commit
·
5e3fa8e
1
Parent(s):
6402624
test
Browse files- .streamlit/config.toml +1 -1
- app.py +7 -10
- audit_page/dialogue_doc.py +46 -21
- audit_page/knowledge_graph.py +338 -280
- test.ipynb +0 -0
- utils/assets/kg_ia_signature.pkl +3 -0
- utils/assets/scenes.pkl +3 -0
.streamlit/config.toml
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
maxUploadSize = 20
|
3 |
|
4 |
[theme]
|
5 |
-
base="
|
6 |
primaryColor="#63abdf"
|
7 |
secondaryBackgroundColor="#fbf7f1"
|
8 |
textColor="#011166"
|
|
|
2 |
maxUploadSize = 20
|
3 |
|
4 |
[theme]
|
5 |
+
base="dark"
|
6 |
primaryColor="#63abdf"
|
7 |
secondaryBackgroundColor="#fbf7f1"
|
8 |
textColor="#011166"
|
app.py
CHANGED
@@ -8,21 +8,18 @@ def main():
|
|
8 |
|
9 |
st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
|
10 |
|
11 |
-
audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
|
12 |
dialog_page = st.Page("audit_page/dialogue_doc.py", title="Dialoguer avec le document", icon="💬")
|
13 |
kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
|
14 |
-
agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
|
15 |
-
compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
|
16 |
-
recommended_agents = st.Page("agents_page/recommended_agent.py", title="Agents recommandés", icon="⭐")
|
17 |
-
chatbot = st.Page("chatbot_page/chatbot.py", title="Chatbot", icon="💬")
|
18 |
-
documentation = st.Page("doc_page/documentation.py", title="Documentation", icon="📚")
|
19 |
|
20 |
pg = st.navigation(
|
21 |
{
|
22 |
-
"
|
23 |
-
"Equipe d'agents IA": [recommended_agents],
|
24 |
-
"Chatbot": [chatbot],
|
25 |
-
"Documentation": [documentation]
|
26 |
}
|
27 |
)
|
28 |
|
|
|
8 |
|
9 |
st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
|
10 |
|
11 |
+
# audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
|
12 |
dialog_page = st.Page("audit_page/dialogue_doc.py", title="Dialoguer avec le document", icon="💬")
|
13 |
kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
|
14 |
+
# agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
|
15 |
+
# compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
|
16 |
+
# recommended_agents = st.Page("agents_page/recommended_agent.py", title="Agents recommandés", icon="⭐")
|
17 |
+
# chatbot = st.Page("chatbot_page/chatbot.py", title="Chatbot", icon="💬")
|
18 |
+
# documentation = st.Page("doc_page/documentation.py", title="Documentation", icon="📚")
|
19 |
|
20 |
pg = st.navigation(
|
21 |
{
|
22 |
+
"Graphe de connaissance": [kg_page],
|
|
|
|
|
|
|
23 |
}
|
24 |
)
|
25 |
|
audit_page/dialogue_doc.py
CHANGED
@@ -8,6 +8,7 @@ from utils.kg.construct_kg import get_graph,get_advanced_graph
|
|
8 |
from audit_page.knowledge_graph import *
|
9 |
import json
|
10 |
from time import sleep
|
|
|
11 |
|
12 |
def graph_doc_to_json(graph):
|
13 |
nodes = []
|
@@ -75,13 +76,17 @@ def format_cr(cr:report):
|
|
75 |
formatted_cr = f"### Résumé :\n{cr.summary}\n\n### Notes :\n{cr.Notes}\n\n### Actions :\n{cr.Actions}"
|
76 |
return formatted_cr
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
def doc_dialog_main():
|
80 |
st.title("Dialogue avec le document")
|
81 |
-
|
82 |
-
if "audit" not in st.session_state or st.session_state.audit == {}:
|
83 |
-
st.error("Veuillez d'abord effectuer un audit pour générer le compte rendu ou le graphe de connaissance.")
|
84 |
-
return
|
85 |
|
86 |
#init cr and chat history cr
|
87 |
if "cr" not in st.session_state:
|
@@ -96,7 +101,7 @@ def doc_dialog_main():
|
|
96 |
st.session_state.current_chunk_index = 0
|
97 |
st.session_state.number_of_entities = 0
|
98 |
st.session_state.number_of_relationships = 0
|
99 |
-
|
100 |
if "filter_views" not in st.session_state:
|
101 |
st.session_state.filter_views = {}
|
102 |
if "current_view" not in st.session_state:
|
@@ -108,6 +113,32 @@ def doc_dialog_main():
|
|
108 |
if "chat_graph_history" not in st.session_state:
|
109 |
st.session_state.chat_graph_history = []
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
#init a radio button for the choice
|
112 |
if "radio_choice" not in st.session_state:
|
113 |
st.session_state.radio_choice = None
|
@@ -122,15 +153,13 @@ def doc_dialog_main():
|
|
122 |
st.session_state.radio_choice = options.index(choice)
|
123 |
|
124 |
|
125 |
-
audit =
|
126 |
-
content =
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
elif audit["type de fichier"] == "audio":
|
131 |
-
text = get_text_from_content_for_audio(content)
|
132 |
-
elif audit["type de fichier"] == "text":
|
133 |
-
text = content
|
134 |
|
135 |
prompt_cr = dedent(f'''
|
136 |
|
@@ -255,24 +284,20 @@ def doc_dialog_main():
|
|
255 |
st_copy_to_clipboard(chat_formatted,key="cp_but_cr_chat",show_text=False)
|
256 |
# col_success_c.success("Historique copié !")
|
257 |
|
|
|
|
|
258 |
elif choice == "graphe de connaissance":
|
259 |
# st.write(st.session_state.graph)
|
260 |
if "graph" not in st.session_state or st.session_state.graph == None:
|
261 |
keywords_list = [keyword.strip() for keyword in audit["Mots clés"].strip().split(",")]
|
262 |
-
|
263 |
-
number_tokens = audit["Nombre de tokens"]
|
264 |
-
with st.spinner("Synthétisation des informations..."):
|
265 |
-
if st.session_state.cr == "":
|
266 |
-
st.session_state.cr = generate_structured_response(prompt_cr2)
|
267 |
-
entete = f"**RESUME:** {st.session_state.cr.summary} \n **MOTS CLES:** {keywords_list}"
|
268 |
-
|
269 |
|
270 |
with st.spinner("Construction du graphe de connaissance..."):
|
271 |
|
272 |
#graph = get_graph(text,allowed_nodes=allowed_nodes_types)
|
273 |
# chunk = st.session_state.chunks[st.session_state.current_chunk_index]
|
274 |
# print(chunk)
|
275 |
-
graph =
|
276 |
st.session_state.graph = graph
|
277 |
st.session_state.current_chunk_index = 0
|
278 |
st.session_state.filter_views = {}
|
|
|
8 |
from audit_page.knowledge_graph import *
|
9 |
import json
|
10 |
from time import sleep
|
11 |
+
import pickle
|
12 |
|
13 |
def graph_doc_to_json(graph):
|
14 |
nodes = []
|
|
|
76 |
formatted_cr = f"### Résumé :\n{cr.summary}\n\n### Notes :\n{cr.Notes}\n\n### Actions :\n{cr.Actions}"
|
77 |
return formatted_cr
|
78 |
|
79 |
+
def load_text_from_pkl(file_path:str):
|
80 |
+
with open(file_path,"rb") as f:
|
81 |
+
return pickle.load(f)
|
82 |
+
|
83 |
+
def load_graph_from_pkl(file_path:str):
|
84 |
+
with open(file_path,"rb") as f:
|
85 |
+
return pickle.load(f)
|
86 |
+
|
87 |
|
88 |
def doc_dialog_main():
|
89 |
st.title("Dialogue avec le document")
|
|
|
|
|
|
|
|
|
90 |
|
91 |
#init cr and chat history cr
|
92 |
if "cr" not in st.session_state:
|
|
|
101 |
st.session_state.current_chunk_index = 0
|
102 |
st.session_state.number_of_entities = 0
|
103 |
st.session_state.number_of_relationships = 0
|
104 |
+
|
105 |
if "filter_views" not in st.session_state:
|
106 |
st.session_state.filter_views = {}
|
107 |
if "current_view" not in st.session_state:
|
|
|
113 |
if "chat_graph_history" not in st.session_state:
|
114 |
st.session_state.chat_graph_history = []
|
115 |
|
116 |
+
global_graph = load_graph_from_pkl("./utils/assets/kg_ia_signature.pkl")
|
117 |
+
st.write("graphe global chargé")
|
118 |
+
st.session_state.graph = global_graph
|
119 |
+
st.write("graphe global assigné")
|
120 |
+
# st.session_state.current_chunk_index = 0
|
121 |
+
# st.session_state.filter_views = {}
|
122 |
+
# st.session_state.current_view = None
|
123 |
+
# st.session_state.node_types = None
|
124 |
+
# st.session_state.chat_graph_history = []
|
125 |
+
st.write("searching for node types")
|
126 |
+
node_types = get_node_types_advanced(st.session_state.graph)
|
127 |
+
st.write("types de noeuds obtenus")
|
128 |
+
list_node_types = list(node_types)
|
129 |
+
sorted_node_types = sorted(list_node_types,key=lambda x: x.lower())
|
130 |
+
print(sorted_node_types)
|
131 |
+
st.write("tri des types de noeuds effectué")
|
132 |
+
nodes_type_dict = list_to_dict_colors(sorted_node_types)
|
133 |
+
st.write("dictionnaire de types de noeuds créé")
|
134 |
+
st.session_state.node_types = nodes_type_dict
|
135 |
+
st.session_state.filter_views["Vue par défaut"] = list(node_types)
|
136 |
+
st.session_state.current_view = "Vue par défaut"
|
137 |
+
st.write("finished init")
|
138 |
+
#######################################################################
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
#init a radio button for the choice
|
143 |
if "radio_choice" not in st.session_state:
|
144 |
st.session_state.radio_choice = None
|
|
|
153 |
st.session_state.radio_choice = options.index(choice)
|
154 |
|
155 |
|
156 |
+
audit = {"Mots clés": ""}
|
157 |
+
content = {}
|
158 |
+
|
159 |
+
text = load_text_from_pkl("./utils/assets/scenes.pkl")
|
160 |
|
161 |
+
st.write(text)
|
162 |
+
|
|
|
|
|
|
|
|
|
163 |
|
164 |
prompt_cr = dedent(f'''
|
165 |
|
|
|
284 |
st_copy_to_clipboard(chat_formatted,key="cp_but_cr_chat",show_text=False)
|
285 |
# col_success_c.success("Historique copié !")
|
286 |
|
287 |
+
|
288 |
+
|
289 |
elif choice == "graphe de connaissance":
|
290 |
# st.write(st.session_state.graph)
|
291 |
if "graph" not in st.session_state or st.session_state.graph == None:
|
292 |
keywords_list = [keyword.strip() for keyword in audit["Mots clés"].strip().split(",")]
|
293 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
with st.spinner("Construction du graphe de connaissance..."):
|
296 |
|
297 |
#graph = get_graph(text,allowed_nodes=allowed_nodes_types)
|
298 |
# chunk = st.session_state.chunks[st.session_state.current_chunk_index]
|
299 |
# print(chunk)
|
300 |
+
graph = global_graph
|
301 |
st.session_state.graph = graph
|
302 |
st.session_state.current_chunk_index = 0
|
303 |
st.session_state.filter_views = {}
|
audit_page/knowledge_graph.py
CHANGED
@@ -1,410 +1,468 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
3 |
-
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
|
4 |
-
from streamlit_agraph import agraph, Node, Edge, Config
|
5 |
import random
|
6 |
import math
|
|
|
|
|
|
|
|
|
7 |
from utils.audit.response_llm import generate_response_via_langchain
|
|
|
8 |
from langchain_core.messages import AIMessage, HumanMessage
|
9 |
from langchain_core.prompts import PromptTemplate
|
|
|
10 |
from itext2kg.models import KnowledgeGraph
|
11 |
|
12 |
-
def if_node_exists(nodes, node_id):
|
13 |
-
"""
|
14 |
-
Check if a node exists in the graph.
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
"""
|
23 |
for node in nodes:
|
24 |
if node.id == node_id:
|
25 |
return True
|
26 |
return False
|
27 |
|
28 |
def generate_random_color():
|
|
|
29 |
r = random.randint(180, 255)
|
30 |
g = random.randint(180, 255)
|
31 |
b = random.randint(180, 255)
|
32 |
return (r, g, b)
|
33 |
|
34 |
def rgb_to_hex(rgb):
|
|
|
35 |
return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
|
36 |
|
37 |
-
def get_node_types(graph):
|
38 |
-
node_types = set()
|
39 |
-
for node in graph.nodes:
|
40 |
-
node_types.add(node.type)
|
41 |
-
for relationship in graph.relationships:
|
42 |
-
source = relationship.source
|
43 |
-
target = relationship.target
|
44 |
-
node_types.add(source.type)
|
45 |
-
node_types.add(target.type)
|
46 |
-
return node_types
|
47 |
-
|
48 |
-
def get_node_types_advanced(graph:KnowledgeGraph):
|
49 |
-
node_types = set()
|
50 |
-
for node in graph.entities:
|
51 |
-
node_types.add(node.label)
|
52 |
-
for relationship in graph.relationships:
|
53 |
-
source = relationship.startEntity
|
54 |
-
target = relationship.endEntity
|
55 |
-
node_types.add(source.label)
|
56 |
-
node_types.add(target.label)
|
57 |
-
return node_types
|
58 |
-
|
59 |
-
|
60 |
def color_distance(color1, color2):
|
61 |
-
|
62 |
-
return math.sqrt(
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def generate_distinct_colors(num_colors, min_distance=30):
|
|
|
|
|
|
|
|
|
65 |
colors = []
|
66 |
while len(colors) < num_colors:
|
67 |
new_color = generate_random_color()
|
68 |
-
if all(color_distance(new_color, existing_color) >= min_distance
|
|
|
69 |
colors.append(new_color)
|
70 |
return [rgb_to_hex(color) for color in colors]
|
71 |
|
72 |
-
def list_to_dict_colors(node_types
|
73 |
-
|
|
|
|
|
74 |
number_of_colors = len(node_types)
|
75 |
-
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
-
|
94 |
-
|
|
|
95 |
"""
|
96 |
nodes = []
|
97 |
edges = []
|
98 |
|
99 |
-
#
|
100 |
for node in neo4j_graph.nodes:
|
101 |
-
|
102 |
-
node_id = node.id.replace(" ", "_") # Replace spaces with underscores for ids
|
103 |
label = node.id
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
111 |
nodes.append(new_node)
|
112 |
|
113 |
-
#
|
114 |
-
for
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
edges.append(Edge(
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
return edges, nodes, config
|
145 |
|
146 |
-
def convert_advanced_neo4j_to_agraph(neo4j_graph:KnowledgeGraph, node_colors):
|
147 |
"""
|
148 |
-
|
149 |
-
|
150 |
-
Args:
|
151 |
-
neo4j_graph (dict): A dictionary representing the Neo4j graph with keys 'nodes' and 'relationships'.
|
152 |
-
'nodes' is a list of dicts with each dict having 'id' and 'type' keys.
|
153 |
-
'relationships' is a list of dicts with 'source', 'target', and 'type' keys.
|
154 |
-
|
155 |
-
Returns:
|
156 |
-
return_value: The Agraph visualization object.
|
157 |
"""
|
158 |
nodes = []
|
159 |
edges = []
|
160 |
|
161 |
-
#
|
162 |
for node in neo4j_graph.entities:
|
163 |
-
|
164 |
-
node_id = node.name.replace(" ", "_") # Replace spaces with underscores for ids
|
165 |
label = node.name
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
175 |
|
176 |
-
#
|
177 |
for relationship in neo4j_graph.relationships:
|
178 |
-
size = 25 # Default size, can be customized
|
179 |
-
shape = "circle" # Default shape, can be customized
|
180 |
-
|
181 |
source = relationship.startEntity
|
182 |
-
source_type = source.label
|
183 |
-
source_id = source.name.replace(" ", "_")
|
184 |
-
label_source = source.name
|
185 |
-
|
186 |
-
source_node = Node(id=source_id,title=source_type, label=label_source, size=size, shape=shape,color=node_colors[source_type])
|
187 |
-
# if not if_node_exists(nodes, source_node.id):
|
188 |
-
# nodes.append(source_node)
|
189 |
-
|
190 |
target = relationship.endEntity
|
191 |
-
target_type = target.label
|
192 |
-
target_id = target.name.replace(" ", "_")
|
193 |
-
label_target = target.name
|
194 |
-
|
195 |
-
target_node = Node(id=target_id,title=target_type, label=label_target, size=size, shape=shape,color=node_colors[target_type])
|
196 |
-
# if not if_node_exists(nodes, target_node.id):
|
197 |
-
# nodes.append(target_node)
|
198 |
-
|
199 |
-
label = relationship.name
|
200 |
-
|
201 |
-
edges.append(Edge(source=source_id, label=label, target=target_id))
|
202 |
|
|
|
|
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
return edges, nodes, config
|
209 |
|
210 |
def display_graph(edges, nodes, config):
|
211 |
-
|
212 |
return agraph(edges=edges, nodes=nodes, config=config)
|
213 |
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
|
217 |
-
filtered_nodes = []
|
218 |
-
for node in nodes:
|
219 |
-
if node.title in node_types_filter: #the title represents the type of the node
|
220 |
-
filtered_nodes.append(node)
|
221 |
-
return filtered_nodes
|
222 |
|
|
|
|
|
|
|
223 |
@st.dialog(title="Changer la vue")
|
224 |
def change_view_dialog():
|
|
|
|
|
|
|
|
|
225 |
st.write("Changer la vue")
|
226 |
-
|
227 |
for index, item in enumerate(st.session_state.filter_views.keys()):
|
228 |
emp = st.empty()
|
229 |
col1, col2, col3 = emp.columns([8, 1, 1])
|
230 |
|
|
|
231 |
if index > 0 and col2.button("🗑️", key=f"del{index}"):
|
232 |
del st.session_state.filter_views[item]
|
233 |
st.session_state.current_view = "Vue par défaut"
|
234 |
st.rerun()
|
|
|
|
|
235 |
but_content = "🔍" if st.session_state.current_view != item else "✅"
|
236 |
if col3.button(but_content, key=f"valid{index}"):
|
237 |
st.session_state.current_view = item
|
238 |
st.rerun()
|
|
|
|
|
239 |
if len(st.session_state.filter_views.keys()) > index:
|
240 |
with col1.expander(item):
|
|
|
241 |
if index > 0:
|
242 |
-
change_name = st.text_input(
|
243 |
-
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
245 |
st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
|
246 |
st.session_state.current_view = change_name
|
247 |
st.rerun()
|
248 |
-
st.markdown(
|
|
|
|
|
|
|
249 |
else:
|
250 |
emp.empty()
|
251 |
|
|
|
252 |
@st.dialog(title="Ajouter une vue")
|
253 |
def add_view_dialog(filters):
|
|
|
|
|
|
|
254 |
st.write("Ajouter une vue")
|
255 |
view_name = st.text_input("Nom de la vue")
|
256 |
-
st.markdown("
|
257 |
st.write(filters)
|
258 |
if st.button("Ajouter la vue"):
|
259 |
-
|
260 |
-
|
|
|
261 |
st.rerun()
|
262 |
|
|
|
263 |
@st.dialog(title="Changer la couleur")
|
264 |
def change_color_dialog():
|
|
|
265 |
st.write("Changer la couleur")
|
266 |
-
for node_type,color in st.session_state.node_types.items():
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
269 |
|
270 |
if st.button("Valider"):
|
271 |
st.rerun()
|
272 |
|
273 |
|
|
|
|
|
|
|
274 |
|
275 |
def kg_main():
|
276 |
-
#
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
|
282 |
-
return
|
283 |
-
|
284 |
-
if "cr" not in st.session_state:
|
285 |
-
st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
|
286 |
-
return
|
287 |
|
288 |
if "graph" not in st.session_state:
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
291 |
if "filter_views" not in st.session_state:
|
292 |
st.session_state.filter_views = {}
|
293 |
if "current_view" not in st.session_state:
|
294 |
st.session_state.current_view = None
|
295 |
-
|
296 |
-
st.title("Graphe de connaissance")
|
297 |
-
|
298 |
if "node_types" not in st.session_state:
|
299 |
st.session_state.node_types = None
|
300 |
-
|
301 |
-
if "summary" not in st.session_state:
|
302 |
-
st.session_state.summary = None
|
303 |
-
|
304 |
if "chat_graph_history" not in st.session_state:
|
305 |
st.session_state.chat_graph_history = []
|
306 |
-
|
307 |
-
audit = st.session_state.audit_simplified
|
308 |
-
# content = st.session_state.audit["content"]
|
309 |
|
310 |
-
|
311 |
-
# text = get_text_from_content_for_doc(content)
|
312 |
-
# elif audit["type de fichier"] == "audio":
|
313 |
-
# text = get_text_from_content_for_audio(content)
|
314 |
|
315 |
-
|
316 |
-
|
317 |
-
#
|
318 |
-
|
319 |
-
|
320 |
-
#
|
321 |
-
|
322 |
-
#
|
323 |
-
|
324 |
-
|
325 |
-
keywords_list = audit["Mots clés"].strip().split(",")
|
326 |
-
allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
|
327 |
-
graph = get_graph(text,allowed_nodes=allowed_nodes_types)
|
328 |
-
st.session_state.graph = graph
|
329 |
-
|
330 |
-
node_types = get_node_types(graph[0])
|
331 |
-
nodes_type_dict = list_to_dict_colors(node_types)
|
332 |
-
st.session_state.node_types = nodes_type_dict
|
333 |
st.session_state.filter_views["Vue par défaut"] = list(node_types)
|
334 |
st.session_state.current_view = "Vue par défaut"
|
335 |
|
336 |
-
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
-
|
|
|
|
|
|
|
|
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
for i,prompt in enumerate(prompts):
|
404 |
-
button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
node_types = st.session_state.node_types
|
410 |
-
|
|
|
1 |
import streamlit as st
|
2 |
+
import pickle
|
|
|
|
|
3 |
import random
|
4 |
import math
|
5 |
+
from streamlit_agraph import agraph, Node, Edge, Config
|
6 |
+
|
7 |
+
from utils.kg.construct_kg import get_graph # if still needed for something else
|
8 |
+
from utils.audit.rag import get_text_from_content_for_doc, get_text_from_content_for_audio
|
9 |
from utils.audit.response_llm import generate_response_via_langchain
|
10 |
+
from utils.audit.rag import get_vectorstore
|
11 |
from langchain_core.messages import AIMessage, HumanMessage
|
12 |
from langchain_core.prompts import PromptTemplate
|
13 |
+
|
14 |
from itext2kg.models import KnowledgeGraph
|
15 |
|
|
|
|
|
|
|
16 |
|
17 |
+
################################################################################
|
18 |
+
# Utility Functions
|
19 |
+
################################################################################
|
20 |
|
21 |
+
def if_node_exists(nodes, node_id):
|
22 |
+
"""Check if a node with the given id already exists in a list of Node objects."""
|
|
|
23 |
for node in nodes:
|
24 |
if node.id == node_id:
|
25 |
return True
|
26 |
return False
|
27 |
|
28 |
def generate_random_color():
|
29 |
+
"""Generate a random pastel-ish RGB color."""
|
30 |
r = random.randint(180, 255)
|
31 |
g = random.randint(180, 255)
|
32 |
b = random.randint(180, 255)
|
33 |
return (r, g, b)
|
34 |
|
35 |
def rgb_to_hex(rgb):
|
36 |
+
"""Convert an (R, G, B) tuple to a hex string like '#aabbcc'."""
|
37 |
return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def color_distance(color1, color2):
|
40 |
+
"""Calculate Euclidean distance between two RGB colors."""
|
41 |
+
return math.sqrt(
|
42 |
+
(color1[0] - color2[0])**2 +
|
43 |
+
(color1[1] - color2[1])**2 +
|
44 |
+
(color1[2] - color2[2])**2
|
45 |
+
)
|
46 |
|
47 |
def generate_distinct_colors(num_colors, min_distance=30):
|
48 |
+
"""
|
49 |
+
Generate a list of distinct pastel-ish colors (in hex), ensuring each is
|
50 |
+
at least `min_distance` away from the others in RGB space.
|
51 |
+
"""
|
52 |
colors = []
|
53 |
while len(colors) < num_colors:
|
54 |
new_color = generate_random_color()
|
55 |
+
if all(color_distance(new_color, existing_color) >= min_distance
|
56 |
+
for existing_color in colors):
|
57 |
colors.append(new_color)
|
58 |
return [rgb_to_hex(color) for color in colors]
|
59 |
|
60 |
+
def list_to_dict_colors(node_types):
|
61 |
+
"""
|
62 |
+
Create a dict mapping each node type to a random (distinct) hex color.
|
63 |
+
"""
|
64 |
number_of_colors = len(node_types)
|
65 |
+
color_hexes = generate_distinct_colors(number_of_colors)
|
66 |
+
return {typ: color_hexes[i] for i, typ in enumerate(node_types)}
|
67 |
|
68 |
+
def get_node_types_advanced(graph: KnowledgeGraph):
|
69 |
+
"""
|
70 |
+
Extract the set of node labels from an itext2kg KnowledgeGraph.
|
71 |
+
(graph.entities have .label, relationships have .startEntity, .endEntity)
|
72 |
+
"""
|
73 |
+
node_types = set()
|
74 |
+
dict_node_colors = {}
|
75 |
+
for node in graph.entities:
|
76 |
+
node_types.add(node.label)
|
77 |
+
for relationship in graph.relationships:
|
78 |
+
node_types.add(relationship.startEntity.label)
|
79 |
+
node_types.add(relationship.endEntity.label)
|
80 |
|
81 |
+
dict_node_colors = {node:rgb_to_hex(generate_random_color()) for node in node_types}
|
82 |
+
return node_types, dict_node_colors
|
83 |
|
84 |
+
################################################################################
|
85 |
+
# Graph Conversion
|
86 |
+
################################################################################
|
87 |
|
88 |
+
def get_node_types(graph):
|
89 |
+
"""
|
90 |
+
Extract the set of node types from a graph that has:
|
91 |
+
graph.nodes -> [ Node(id, type) ... ]
|
92 |
+
graph.relationships -> [ Relationship(source, target, type) ... ]
|
93 |
+
"""
|
94 |
+
node_types = set()
|
95 |
+
for node in graph.nodes:
|
96 |
+
node_types.add(node.type)
|
97 |
+
for rel in graph.relationships:
|
98 |
+
node_types.add(rel.source.type)
|
99 |
+
node_types.add(rel.target.type)
|
100 |
+
return node_types
|
101 |
|
102 |
+
def convert_neo4j_to_agraph(neo4j_graph, node_colors):
|
103 |
+
"""
|
104 |
+
Convert a “Neo4j-like” object into Agraph Nodes & Edges.
|
105 |
"""
|
106 |
nodes = []
|
107 |
edges = []
|
108 |
|
109 |
+
# Create nodes
|
110 |
for node in neo4j_graph.nodes:
|
111 |
+
node_id = node.id.replace(" ", "_")
|
|
|
112 |
label = node.id
|
113 |
+
type_ = node.type
|
114 |
+
|
115 |
+
new_node = Node(
|
116 |
+
id=node_id,
|
117 |
+
title=type_, # 'title' effectively becomes "type"
|
118 |
+
label=label,
|
119 |
+
size=25,
|
120 |
+
shape="circle",
|
121 |
+
color=node_colors.get(type_, "#cccccc")
|
122 |
+
)
|
123 |
+
if not if_node_exists(nodes, node_id):
|
124 |
nodes.append(new_node)
|
125 |
|
126 |
+
# Create edges
|
127 |
+
for rel in neo4j_graph.relationships:
|
128 |
+
source_id = rel.source.id.replace(" ", "_")
|
129 |
+
target_id = rel.target.id.replace(" ", "_")
|
130 |
+
|
131 |
+
# Ensure nodes exist (if not from the loop above):
|
132 |
+
if not if_node_exists(nodes, source_id):
|
133 |
+
nodes.append(Node(
|
134 |
+
id=source_id,
|
135 |
+
title=rel.source.type,
|
136 |
+
label=rel.source.id,
|
137 |
+
size=25,
|
138 |
+
shape="circle",
|
139 |
+
color=node_colors.get(rel.source.type, "#cccccc")
|
140 |
+
))
|
141 |
+
if not if_node_exists(nodes, target_id):
|
142 |
+
nodes.append(Node(
|
143 |
+
id=target_id,
|
144 |
+
title=rel.target.type,
|
145 |
+
label=rel.target.id,
|
146 |
+
size=25,
|
147 |
+
shape="circle",
|
148 |
+
color=node_colors.get(rel.target.type, "#cccccc")
|
149 |
+
))
|
150 |
+
|
151 |
+
edges.append(Edge(
|
152 |
+
source=source_id,
|
153 |
+
label=rel.type,
|
154 |
+
target=target_id
|
155 |
+
))
|
156 |
+
|
157 |
+
config = Config(
|
158 |
+
width=1200,
|
159 |
+
height=800,
|
160 |
+
directed=True,
|
161 |
+
physics=True,
|
162 |
+
hierarchical=True,
|
163 |
+
from_json="config.json"
|
164 |
+
)
|
165 |
return edges, nodes, config
|
166 |
|
167 |
+
def convert_advanced_neo4j_to_agraph(neo4j_graph: KnowledgeGraph, node_colors):
|
168 |
"""
|
169 |
+
Same logic as above, but adapted to an itext2kg.models.KnowledgeGraph object
|
170 |
+
(graph.entities, graph.relationships).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
"""
|
172 |
nodes = []
|
173 |
edges = []
|
174 |
|
175 |
+
# Create nodes
|
176 |
for node in neo4j_graph.entities:
|
177 |
+
node_id = node.name.replace(" ", "_")
|
|
|
178 |
label = node.name
|
179 |
+
type_ = node.label
|
180 |
+
new_node = Node(
|
181 |
+
id=node_id,
|
182 |
+
title=type_,
|
183 |
+
label=label,
|
184 |
+
size=25,
|
185 |
+
shape="circle",
|
186 |
+
color=node_colors[type_]
|
187 |
+
)
|
188 |
+
if not if_node_exists(nodes, new_node.id):
|
189 |
+
nodes.append(new_node)
|
190 |
|
191 |
+
# Create edges
|
192 |
for relationship in neo4j_graph.relationships:
|
|
|
|
|
|
|
193 |
source = relationship.startEntity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
target = relationship.endEntity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
+
source_id = source.name.replace(" ", "_")
|
197 |
+
target_id = target.name.replace(" ", "_")
|
198 |
|
199 |
+
# Ensure existence of the source node
|
200 |
+
if not if_node_exists(nodes, source_id):
|
201 |
+
nodes.append(Node(
|
202 |
+
id=source_id,
|
203 |
+
title=source.label,
|
204 |
+
label=source.name,
|
205 |
+
size=25,
|
206 |
+
shape="circle",
|
207 |
+
color=node_colors.get(source.label, "#CCCCCC")
|
208 |
+
))
|
209 |
+
|
210 |
+
# Ensure existence of the target node
|
211 |
+
if not if_node_exists(nodes, target_id):
|
212 |
+
nodes.append(Node(
|
213 |
+
id=target_id,
|
214 |
+
title=target.label,
|
215 |
+
label=target.name,
|
216 |
+
size=25,
|
217 |
+
shape="circle",
|
218 |
+
color=node_colors.get(target.label, "#CCCCCC")
|
219 |
+
))
|
220 |
+
|
221 |
+
edges.append(Edge(
|
222 |
+
source=source_id,
|
223 |
+
label=relationship.name,
|
224 |
+
target=target_id
|
225 |
+
))
|
226 |
+
|
227 |
+
config = Config(
|
228 |
+
width=1200,
|
229 |
+
height=800,
|
230 |
+
directed=True,
|
231 |
+
physics=True,
|
232 |
+
hierarchical=True,
|
233 |
+
from_json="config.json"
|
234 |
+
)
|
235 |
return edges, nodes, config
|
236 |
|
237 |
def display_graph(edges, nodes, config):
|
238 |
+
"""Render Agraph."""
|
239 |
return agraph(edges=edges, nodes=nodes, config=config)
|
240 |
|
241 |
|
242 |
+
def filter_nodes_by_types(nodes, node_types_filter):
|
243 |
+
"""
|
244 |
+
Filter out Agraph nodes by the node’s 'title' field (which is used as 'type' here).
|
245 |
+
"""
|
246 |
+
if not node_types_filter:
|
247 |
+
return nodes
|
248 |
+
return [node for node in nodes if node.title in node_types_filter]
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
+
################################################################################
|
252 |
+
# Dialog Components (same as your original code)
|
253 |
+
################################################################################
|
254 |
@st.dialog(title="Changer la vue")
|
255 |
def change_view_dialog():
|
256 |
+
"""
|
257 |
+
Dialog to rename or delete existing views from st.session_state.filter_views
|
258 |
+
and choose the active one (st.session_state.current_view).
|
259 |
+
"""
|
260 |
st.write("Changer la vue")
|
|
|
261 |
for index, item in enumerate(st.session_state.filter_views.keys()):
|
262 |
emp = st.empty()
|
263 |
col1, col2, col3 = emp.columns([8, 1, 1])
|
264 |
|
265 |
+
# Delete the view (except for the default if you want)
|
266 |
if index > 0 and col2.button("🗑️", key=f"del{index}"):
|
267 |
del st.session_state.filter_views[item]
|
268 |
st.session_state.current_view = "Vue par défaut"
|
269 |
st.rerun()
|
270 |
+
|
271 |
+
# Choose the view
|
272 |
but_content = "🔍" if st.session_state.current_view != item else "✅"
|
273 |
if col3.button(but_content, key=f"valid{index}"):
|
274 |
st.session_state.current_view = item
|
275 |
st.rerun()
|
276 |
+
|
277 |
+
# Show details / rename
|
278 |
if len(st.session_state.filter_views.keys()) > index:
|
279 |
with col1.expander(item):
|
280 |
+
# Don’t allow renaming the default view (index=0) if you want
|
281 |
if index > 0:
|
282 |
+
change_name = st.text_input(
|
283 |
+
"Nom de la vue",
|
284 |
+
label_visibility="collapsed",
|
285 |
+
placeholder="Changez le nom de la vue",
|
286 |
+
key=f"change_name{index}"
|
287 |
+
)
|
288 |
+
if st.button("Renommer", key=f"rename{index}"):
|
289 |
+
if change_name.strip():
|
290 |
st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
|
291 |
st.session_state.current_view = change_name
|
292 |
st.rerun()
|
293 |
+
st.markdown(
|
294 |
+
"\n".join(f"- {label.strip()}"
|
295 |
+
for label in st.session_state.filter_views[item])
|
296 |
+
)
|
297 |
else:
|
298 |
emp.empty()
|
299 |
|
300 |
+
|
301 |
@st.dialog(title="Ajouter une vue")
|
302 |
def add_view_dialog(filters):
|
303 |
+
"""
|
304 |
+
Dialog to add a new “view” to st.session_state.filter_views, specifying which types to filter by.
|
305 |
+
"""
|
306 |
st.write("Ajouter une vue")
|
307 |
view_name = st.text_input("Nom de la vue")
|
308 |
+
st.markdown("Les filtres actuels :")
|
309 |
st.write(filters)
|
310 |
if st.button("Ajouter la vue"):
|
311 |
+
if view_name.strip():
|
312 |
+
st.session_state.filter_views[view_name] = filters
|
313 |
+
st.session_state.current_view = view_name
|
314 |
st.rerun()
|
315 |
|
316 |
+
|
317 |
@st.dialog(title="Changer la couleur")
|
318 |
def change_color_dialog():
|
319 |
+
"""Dialog to interactively change colors of each node type via color pickers."""
|
320 |
st.write("Changer la couleur")
|
321 |
+
for node_type, color in st.session_state.node_types.items():
|
322 |
+
new_color = st.color_picker(
|
323 |
+
f"La couleur de l'entité **{node_type.strip()}**",
|
324 |
+
color
|
325 |
+
)
|
326 |
+
st.session_state.node_types[node_type] = new_color
|
327 |
|
328 |
if st.button("Valider"):
|
329 |
st.rerun()
|
330 |
|
331 |
|
332 |
+
################################################################################
|
333 |
+
# Main KG Function
|
334 |
+
################################################################################
|
335 |
|
336 |
def kg_main():
|
337 |
+
# 1. Load your pickles (if not already loaded in session state)
|
338 |
+
if "scenes" not in st.session_state:
|
339 |
+
with open("./utils/assets/scenes.pkl", "rb") as f:
|
340 |
+
st.session_state.scenes = pickle.load(f)
|
341 |
+
st.session_state.vectorstore = get_vectorstore(st.session_state.scenes)
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
if "graph" not in st.session_state:
|
344 |
+
with open("./utils/assets/kg_ia_signature.pkl", "rb") as f:
|
345 |
+
# Depending on how you stored it, it might be a tuple (graph, extra_info)
|
346 |
+
# or directly a single object. Adjust as needed.
|
347 |
+
st.session_state.graph = pickle.load(f)
|
348 |
+
print("Graph loaded.")
|
349 |
+
|
350 |
+
# 2. Initialize other session keys if they don’t exist
|
351 |
if "filter_views" not in st.session_state:
|
352 |
st.session_state.filter_views = {}
|
353 |
if "current_view" not in st.session_state:
|
354 |
st.session_state.current_view = None
|
|
|
|
|
|
|
355 |
if "node_types" not in st.session_state:
|
356 |
st.session_state.node_types = None
|
|
|
|
|
|
|
|
|
357 |
if "chat_graph_history" not in st.session_state:
|
358 |
st.session_state.chat_graph_history = []
|
|
|
|
|
|
|
359 |
|
360 |
+
st.title("Graphe de connaissance")
|
|
|
|
|
|
|
361 |
|
362 |
+
edges,nodes,config = None, None, None
|
363 |
+
|
364 |
+
# If we haven’t set up node types yet, do it now
|
365 |
+
if st.session_state.node_types is None:
|
366 |
+
# st.session_state.graph is presumably a list/tuple => st.session_state.graph[0]
|
367 |
+
# Or just st.session_state.graph if you stored it directly as a single obj
|
368 |
+
node_types, st.session_state.node_types = get_node_types_advanced(st.session_state.graph)
|
369 |
+
# st.write(f"Types d'entités trouvés : {node_types}")
|
370 |
+
print("Couleurs attribuées")
|
371 |
+
# Initialize a default filter view
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
st.session_state.filter_views["Vue par défaut"] = list(node_types)
|
373 |
st.session_state.current_view = "Vue par défaut"
|
374 |
|
375 |
+
# 3. Convert the graph to agraph format
|
376 |
+
edges, nodes, config = convert_advanced_neo4j_to_agraph(
|
377 |
+
st.session_state.graph, # or st.session_state.graph[0] if needed
|
378 |
+
st.session_state.node_types
|
379 |
+
)
|
380 |
+
print("Graph converti en Agraph")
|
381 |
+
|
382 |
+
# 4. UI layout: (left) the graph itself, (right) the chat
|
383 |
+
col1, col2 = st.columns([3, 1])
|
384 |
|
385 |
+
with col1.container(border=True,height=800):
|
386 |
+
st.write(f"#### Visualisation du graphe (**{st.session_state.current_view}**)")
|
387 |
+
|
388 |
+
filter_col, add_view_col, change_view_col, color_col = st.columns([9, 1, 1, 1])
|
389 |
+
|
390 |
+
if color_col.button("🎨", help="Changer la couleur"):
|
391 |
+
change_color_dialog()
|
392 |
+
|
393 |
+
if change_view_col.button("🔍", help="Changer de vue"):
|
394 |
+
change_view_dialog()
|
395 |
+
|
396 |
+
# Currently selected filter for the chosen view
|
397 |
+
current_filters = st.session_state.filter_views.get(st.session_state.current_view, [])
|
398 |
+
filter_selection = filter_col.multiselect(
|
399 |
+
"Filtrer selon l'étiquette",
|
400 |
+
st.session_state.node_types.keys(),
|
401 |
+
default=current_filters,
|
402 |
+
label_visibility="collapsed"
|
403 |
+
)
|
404 |
|
405 |
+
if add_view_col.button("➕", help="Ajouter une vue"):
|
406 |
+
add_view_dialog(filter_selection)
|
407 |
+
|
408 |
+
# Filter out nodes that don’t match the chosen types
|
409 |
+
filtered_nodes = filter_nodes_by_types(nodes, filter_selection)
|
410 |
|
411 |
+
# Render the graph
|
412 |
+
selected_node_id = display_graph(edges, filtered_nodes, config)
|
413 |
+
|
414 |
+
# 5. Chat UI
|
415 |
+
with col2.container(border=True,height=800):
|
416 |
+
st.markdown("#### Dialoguer avec le graphe")
|
417 |
+
user_query = st.chat_input("Votre question ...")
|
418 |
+
if user_query:
|
419 |
+
st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
|
420 |
+
|
421 |
+
with st.container():
|
422 |
+
# Display the existing chat
|
423 |
+
for message in st.session_state.chat_graph_history:
|
424 |
+
if isinstance(message, AIMessage):
|
425 |
+
with st.chat_message("AI"):
|
426 |
+
st.markdown(message.content)
|
427 |
+
elif isinstance(message, HumanMessage):
|
428 |
+
with st.chat_message("Human"):
|
429 |
+
st.write(message.content)
|
430 |
+
|
431 |
+
# If the last message is from the user, we try to generate a response
|
432 |
+
if (len(st.session_state.chat_graph_history) > 0 and
|
433 |
+
isinstance(st.session_state.chat_graph_history[-1], HumanMessage)):
|
434 |
+
last_message = st.session_state.chat_graph_history[-1]
|
435 |
+
with st.chat_message("AI"):
|
436 |
+
# Example retrieval (if you have a vectorstore in session state)
|
437 |
+
# and want to incorporate scenes or graph data:
|
438 |
+
if "vectorstore" in st.session_state:
|
439 |
+
retriever = st.session_state.vectorstore.as_retriever()
|
440 |
+
context = retriever.invoke(last_message.content)
|
441 |
+
prompt = (
|
442 |
+
f"Contexte depuis les 'scenes': {st.session_state.scenes}\n"
|
443 |
+
f"Contexte vectorstore: {context}\n"
|
444 |
+
f"Question: {last_message.content}\n"
|
445 |
+
f"Graph: {st.session_state.graph}\n" # If you want to embed your entire graph
|
446 |
+
)
|
447 |
+
response = st.write_stream(
|
448 |
+
generate_response_via_langchain(prompt, stream=True)
|
449 |
+
)
|
450 |
+
st.session_state.chat_graph_history.append(AIMessage(content=response))
|
451 |
+
else:
|
452 |
+
# Fallback if no vectorstore
|
453 |
+
st.write("Aucune base de vecteurs disponible.")
|
454 |
+
st.session_state.chat_graph_history.append(AIMessage(content="(Pas de vectorstore)"))
|
455 |
+
|
456 |
+
# If the user clicked on a node in the graph, we can propose quick prompts
|
457 |
+
if selected_node_id:
|
458 |
+
with st.chat_message("AI"):
|
459 |
+
st.markdown(f"**Vous avez sélectionné**: `{selected_node_id}`")
|
460 |
+
quick_prompts = [
|
461 |
+
f"Donne-moi plus d'informations sur le noeud '{selected_node_id}'",
|
462 |
+
f"Montre-moi les relations de '{selected_node_id}' dans ce graphe"
|
463 |
+
]
|
464 |
+
for i, qprompt in enumerate(quick_prompts):
|
465 |
+
if st.button(qprompt, key=f"qp_{i}"):
|
466 |
+
st.session_state.chat_graph_history.append(HumanMessage(content=qprompt))
|
467 |
+
|
468 |
+
kg_main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
utils/assets/kg_ia_signature.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:55b49436038a45405798f6d05591464b1a35360409d83dbead163921707ac592
|
3 |
+
size 7354091
|
utils/assets/scenes.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17fc4636b752c5b8f1434d0c97c95ea3b12605b083689e6d79daacd060f6c110
|
3 |
+
size 142917
|