Ilyas KHIAT commited on
Commit
5e3fa8e
·
1 Parent(s): 6402624
.streamlit/config.toml CHANGED
@@ -2,7 +2,7 @@
2
  maxUploadSize = 20
3
 
4
  [theme]
5
- base="light"
6
  primaryColor="#63abdf"
7
  secondaryBackgroundColor="#fbf7f1"
8
  textColor="#011166"
 
2
  maxUploadSize = 20
3
 
4
  [theme]
5
+ base="dark"
6
  primaryColor="#63abdf"
7
  secondaryBackgroundColor="#fbf7f1"
8
  textColor="#011166"
app.py CHANGED
@@ -8,21 +8,18 @@ def main():
8
 
9
  st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
10
 
11
- audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
12
  dialog_page = st.Page("audit_page/dialogue_doc.py", title="Dialoguer avec le document", icon="💬")
13
  kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
14
- agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
15
- compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
16
- recommended_agents = st.Page("agents_page/recommended_agent.py", title="Agents recommandés", icon="⭐")
17
- chatbot = st.Page("chatbot_page/chatbot.py", title="Chatbot", icon="💬")
18
- documentation = st.Page("doc_page/documentation.py", title="Documentation", icon="📚")
19
 
20
  pg = st.navigation(
21
  {
22
- "Audit de contenus": [audit_page,dialog_page],
23
- "Equipe d'agents IA": [recommended_agents],
24
- "Chatbot": [chatbot],
25
- "Documentation": [documentation]
26
  }
27
  )
28
 
 
8
 
9
  st.set_page_config(page_title="RAG Agent", page_icon="🤖", layout="wide")
10
 
11
+ # audit_page = st.Page("audit_page/audit.py", title="Audit", icon="📋", default=True)
12
  dialog_page = st.Page("audit_page/dialogue_doc.py", title="Dialoguer avec le document", icon="💬")
13
  kg_page = st.Page("audit_page/knowledge_graph.py", title="Graphe de connaissance", icon="🧠")
14
+ # agents_page = st.Page("agents_page/catalogue.py", title="Catalogue des agents", icon="📇")
15
+ # compte_rendu = st.Page("audit_page/compte_rendu.py", title="Compte rendu", icon="📝")
16
+ # recommended_agents = st.Page("agents_page/recommended_agent.py", title="Agents recommandés", icon="⭐")
17
+ # chatbot = st.Page("chatbot_page/chatbot.py", title="Chatbot", icon="💬")
18
+ # documentation = st.Page("doc_page/documentation.py", title="Documentation", icon="📚")
19
 
20
  pg = st.navigation(
21
  {
22
+ "Graphe de connaissance": [kg_page],
 
 
 
23
  }
24
  )
25
 
audit_page/dialogue_doc.py CHANGED
@@ -8,6 +8,7 @@ from utils.kg.construct_kg import get_graph,get_advanced_graph
8
  from audit_page.knowledge_graph import *
9
  import json
10
  from time import sleep
 
11
 
12
  def graph_doc_to_json(graph):
13
  nodes = []
@@ -75,13 +76,17 @@ def format_cr(cr:report):
75
  formatted_cr = f"### Résumé :\n{cr.summary}\n\n### Notes :\n{cr.Notes}\n\n### Actions :\n{cr.Actions}"
76
  return formatted_cr
77
 
 
 
 
 
 
 
 
 
78
 
79
  def doc_dialog_main():
80
  st.title("Dialogue avec le document")
81
-
82
- if "audit" not in st.session_state or st.session_state.audit == {}:
83
- st.error("Veuillez d'abord effectuer un audit pour générer le compte rendu ou le graphe de connaissance.")
84
- return
85
 
86
  #init cr and chat history cr
87
  if "cr" not in st.session_state:
@@ -96,7 +101,7 @@ def doc_dialog_main():
96
  st.session_state.current_chunk_index = 0
97
  st.session_state.number_of_entities = 0
98
  st.session_state.number_of_relationships = 0
99
-
100
  if "filter_views" not in st.session_state:
101
  st.session_state.filter_views = {}
102
  if "current_view" not in st.session_state:
@@ -108,6 +113,32 @@ def doc_dialog_main():
108
  if "chat_graph_history" not in st.session_state:
109
  st.session_state.chat_graph_history = []
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  #init a radio button for the choice
112
  if "radio_choice" not in st.session_state:
113
  st.session_state.radio_choice = None
@@ -122,15 +153,13 @@ def doc_dialog_main():
122
  st.session_state.radio_choice = options.index(choice)
123
 
124
 
125
- audit = st.session_state.audit_simplified
126
- content = st.session_state.audit["content"]
 
 
127
 
128
- if audit["type de fichier"] == "pdf":
129
- text = get_text_from_content_for_doc(content)
130
- elif audit["type de fichier"] == "audio":
131
- text = get_text_from_content_for_audio(content)
132
- elif audit["type de fichier"] == "text":
133
- text = content
134
 
135
  prompt_cr = dedent(f'''
136
 
@@ -255,24 +284,20 @@ def doc_dialog_main():
255
  st_copy_to_clipboard(chat_formatted,key="cp_but_cr_chat",show_text=False)
256
  # col_success_c.success("Historique copié !")
257
 
 
 
258
  elif choice == "graphe de connaissance":
259
  # st.write(st.session_state.graph)
260
  if "graph" not in st.session_state or st.session_state.graph == None:
261
  keywords_list = [keyword.strip() for keyword in audit["Mots clés"].strip().split(",")]
262
- allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
263
- number_tokens = audit["Nombre de tokens"]
264
- with st.spinner("Synthétisation des informations..."):
265
- if st.session_state.cr == "":
266
- st.session_state.cr = generate_structured_response(prompt_cr2)
267
- entete = f"**RESUME:** {st.session_state.cr.summary} \n **MOTS CLES:** {keywords_list}"
268
-
269
 
270
  with st.spinner("Construction du graphe de connaissance..."):
271
 
272
  #graph = get_graph(text,allowed_nodes=allowed_nodes_types)
273
  # chunk = st.session_state.chunks[st.session_state.current_chunk_index]
274
  # print(chunk)
275
- graph = get_advanced_graph(format_cr(st.session_state.cr),knowledge_graph=None)
276
  st.session_state.graph = graph
277
  st.session_state.current_chunk_index = 0
278
  st.session_state.filter_views = {}
 
8
  from audit_page.knowledge_graph import *
9
  import json
10
  from time import sleep
11
+ import pickle
12
 
13
  def graph_doc_to_json(graph):
14
  nodes = []
 
76
  formatted_cr = f"### Résumé :\n{cr.summary}\n\n### Notes :\n{cr.Notes}\n\n### Actions :\n{cr.Actions}"
77
  return formatted_cr
78
 
79
+ def load_text_from_pkl(file_path:str):
80
+ with open(file_path,"rb") as f:
81
+ return pickle.load(f)
82
+
83
+ def load_graph_from_pkl(file_path:str):
84
+ with open(file_path,"rb") as f:
85
+ return pickle.load(f)
86
+
87
 
88
  def doc_dialog_main():
89
  st.title("Dialogue avec le document")
 
 
 
 
90
 
91
  #init cr and chat history cr
92
  if "cr" not in st.session_state:
 
101
  st.session_state.current_chunk_index = 0
102
  st.session_state.number_of_entities = 0
103
  st.session_state.number_of_relationships = 0
104
+
105
  if "filter_views" not in st.session_state:
106
  st.session_state.filter_views = {}
107
  if "current_view" not in st.session_state:
 
113
  if "chat_graph_history" not in st.session_state:
114
  st.session_state.chat_graph_history = []
115
 
116
+ global_graph = load_graph_from_pkl("./utils/assets/kg_ia_signature.pkl")
117
+ st.write("graphe global chargé")
118
+ st.session_state.graph = global_graph
119
+ st.write("graphe global assigné")
120
+ # st.session_state.current_chunk_index = 0
121
+ # st.session_state.filter_views = {}
122
+ # st.session_state.current_view = None
123
+ # st.session_state.node_types = None
124
+ # st.session_state.chat_graph_history = []
125
+ st.write("searching for node types")
126
+ node_types = get_node_types_advanced(st.session_state.graph)
127
+ st.write("types de noeuds obtenus")
128
+ list_node_types = list(node_types)
129
+ sorted_node_types = sorted(list_node_types,key=lambda x: x.lower())
130
+ print(sorted_node_types)
131
+ st.write("tri des types de noeuds effectué")
132
+ nodes_type_dict = list_to_dict_colors(sorted_node_types)
133
+ st.write("dictionnaire de types de noeuds créé")
134
+ st.session_state.node_types = nodes_type_dict
135
+ st.session_state.filter_views["Vue par défaut"] = list(node_types)
136
+ st.session_state.current_view = "Vue par défaut"
137
+ st.write("finished init")
138
+ #######################################################################
139
+
140
+
141
+
142
  #init a radio button for the choice
143
  if "radio_choice" not in st.session_state:
144
  st.session_state.radio_choice = None
 
153
  st.session_state.radio_choice = options.index(choice)
154
 
155
 
156
+ audit = {"Mots clés": ""}
157
+ content = {}
158
+
159
+ text = load_text_from_pkl("./utils/assets/scenes.pkl")
160
 
161
+ st.write(text)
162
+
 
 
 
 
163
 
164
  prompt_cr = dedent(f'''
165
 
 
284
  st_copy_to_clipboard(chat_formatted,key="cp_but_cr_chat",show_text=False)
285
  # col_success_c.success("Historique copié !")
286
 
287
+
288
+
289
  elif choice == "graphe de connaissance":
290
  # st.write(st.session_state.graph)
291
  if "graph" not in st.session_state or st.session_state.graph == None:
292
  keywords_list = [keyword.strip() for keyword in audit["Mots clés"].strip().split(",")]
293
+
 
 
 
 
 
 
294
 
295
  with st.spinner("Construction du graphe de connaissance..."):
296
 
297
  #graph = get_graph(text,allowed_nodes=allowed_nodes_types)
298
  # chunk = st.session_state.chunks[st.session_state.current_chunk_index]
299
  # print(chunk)
300
+ graph = global_graph
301
  st.session_state.graph = graph
302
  st.session_state.current_chunk_index = 0
303
  st.session_state.filter_views = {}
audit_page/knowledge_graph.py CHANGED
@@ -1,410 +1,468 @@
1
  import streamlit as st
2
- from utils.kg.construct_kg import get_graph
3
- from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
4
- from streamlit_agraph import agraph, Node, Edge, Config
5
  import random
6
  import math
 
 
 
 
7
  from utils.audit.response_llm import generate_response_via_langchain
 
8
  from langchain_core.messages import AIMessage, HumanMessage
9
  from langchain_core.prompts import PromptTemplate
 
10
  from itext2kg.models import KnowledgeGraph
11
 
12
- def if_node_exists(nodes, node_id):
13
- """
14
- Check if a node exists in the graph.
15
 
16
- Args:
17
- graph (dict): A dictionary representing the graph with keys 'nodes' and 'relationships'.
18
- node_id (str): The id of the node to check.
19
 
20
- Returns:
21
- return_value: True if the node exists, False otherwise.
22
- """
23
  for node in nodes:
24
  if node.id == node_id:
25
  return True
26
  return False
27
 
28
  def generate_random_color():
 
29
  r = random.randint(180, 255)
30
  g = random.randint(180, 255)
31
  b = random.randint(180, 255)
32
  return (r, g, b)
33
 
34
  def rgb_to_hex(rgb):
 
35
  return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
36
 
37
- def get_node_types(graph):
38
- node_types = set()
39
- for node in graph.nodes:
40
- node_types.add(node.type)
41
- for relationship in graph.relationships:
42
- source = relationship.source
43
- target = relationship.target
44
- node_types.add(source.type)
45
- node_types.add(target.type)
46
- return node_types
47
-
48
- def get_node_types_advanced(graph:KnowledgeGraph):
49
- node_types = set()
50
- for node in graph.entities:
51
- node_types.add(node.label)
52
- for relationship in graph.relationships:
53
- source = relationship.startEntity
54
- target = relationship.endEntity
55
- node_types.add(source.label)
56
- node_types.add(target.label)
57
- return node_types
58
-
59
-
60
  def color_distance(color1, color2):
61
- # Calculate Euclidean distance between two RGB colors
62
- return math.sqrt((color1[0] - color2[0]) ** 2 + (color1[1] - color2[1]) ** 2 + (color1[2] - color2[2]) ** 2)
 
 
 
 
63
 
64
  def generate_distinct_colors(num_colors, min_distance=30):
 
 
 
 
65
  colors = []
66
  while len(colors) < num_colors:
67
  new_color = generate_random_color()
68
- if all(color_distance(new_color, existing_color) >= min_distance for existing_color in colors):
 
69
  colors.append(new_color)
70
  return [rgb_to_hex(color) for color in colors]
71
 
72
- def list_to_dict_colors(node_types:set):
73
-
 
 
74
  number_of_colors = len(node_types)
75
- colors = generate_distinct_colors(number_of_colors)
 
76
 
77
- node_colors = {}
78
- for i, node_type in enumerate(node_types):
79
- node_colors[node_type] = colors[i]
80
-
81
- return node_colors
 
 
 
 
 
 
 
82
 
 
 
83
 
84
- def convert_neo4j_to_agraph(neo4j_graph, node_colors):
85
- """
86
- Converts a Neo4j graph into an Agraph format.
87
 
88
- Args:
89
- neo4j_graph (dict): A dictionary representing the Neo4j graph with keys 'nodes' and 'relationships'.
90
- 'nodes' is a list of dicts with each dict having 'id' and 'type' keys.
91
- 'relationships' is a list of dicts with 'source', 'target', and 'type' keys.
 
 
 
 
 
 
 
 
 
92
 
93
- Returns:
94
- return_value: The Agraph visualization object.
 
95
  """
96
  nodes = []
97
  edges = []
98
 
99
- # Creating Agraph nodes
100
  for node in neo4j_graph.nodes:
101
- # Use the node id as the Agraph node id
102
- node_id = node.id.replace(" ", "_") # Replace spaces with underscores for ids
103
  label = node.id
104
- type = node.type
105
- size = 25 # Default size, can be customized
106
- shape = "circle" # Default shape, can be customized
107
-
108
- # For example purposes, no images are added, but you can set 'image' if needed.
109
- new_node = Node(id=node_id,title=type, label=label, size=size, shape=shape,color=node_colors[type])
110
- if not if_node_exists(nodes, new_node.id):
 
 
 
 
111
  nodes.append(new_node)
112
 
113
- # Creating Agraph edges
114
- for relationship in neo4j_graph.relationships:
115
- size = 25 # Default size, can be customized
116
- shape = "circle" # Default shape, can be customized
117
-
118
- source = relationship.source
119
- source_type = source.type
120
- source_id = source.id.replace(" ", "_")
121
- label_source = source.id
122
-
123
- source_node = Node(id=source_id,title=source_type, label=label_source, size=size, shape=shape,color=node_colors[source_type])
124
- if not if_node_exists(nodes, source_node.id):
125
- nodes.append(source_node)
126
-
127
- target = relationship.target
128
- target_type = target.type
129
- target_id = target.id.replace(" ", "_")
130
- label_target = target.id
131
-
132
- target_node = Node(id=target_id,title=target_type, label=label_target, size=size, shape=shape,color=node_colors[target_type])
133
- if not if_node_exists(nodes, target_node.id):
134
- nodes.append(target_node)
135
-
136
- label = relationship.type
137
-
138
- edges.append(Edge(source=source_id, label=label, target=target_id))
139
-
140
- # Define the configuration for Agraph
141
- config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True,from_json="config.json")
142
- # Create the Agraph visualization
143
-
 
 
 
 
 
 
 
 
144
  return edges, nodes, config
145
 
146
- def convert_advanced_neo4j_to_agraph(neo4j_graph:KnowledgeGraph, node_colors):
147
  """
148
- Converts a Neo4j graph into an Agraph format.
149
-
150
- Args:
151
- neo4j_graph (dict): A dictionary representing the Neo4j graph with keys 'nodes' and 'relationships'.
152
- 'nodes' is a list of dicts with each dict having 'id' and 'type' keys.
153
- 'relationships' is a list of dicts with 'source', 'target', and 'type' keys.
154
-
155
- Returns:
156
- return_value: The Agraph visualization object.
157
  """
158
  nodes = []
159
  edges = []
160
 
161
- # Creating Agraph nodes
162
  for node in neo4j_graph.entities:
163
- # Use the node id as the Agraph node id
164
- node_id = node.name.replace(" ", "_") # Replace spaces with underscores for ids
165
  label = node.name
166
- type = node.label
167
- size = 25 # Default size, can be customized
168
- shape = "circle" # Default shape, can be customized
169
-
170
- # For example purposes, no images are added, but you can set 'image' if needed.
171
- new_node = Node(id=node_id,title=type, label=label, size=size, shape=shape,color=node_colors[type])
172
- # if not if_node_exists(nodes, new_node.id):
173
- # nodes.append(new_node)
174
- nodes.append(new_node)
 
 
175
 
176
- # Creating Agraph edges
177
  for relationship in neo4j_graph.relationships:
178
- size = 25 # Default size, can be customized
179
- shape = "circle" # Default shape, can be customized
180
-
181
  source = relationship.startEntity
182
- source_type = source.label
183
- source_id = source.name.replace(" ", "_")
184
- label_source = source.name
185
-
186
- source_node = Node(id=source_id,title=source_type, label=label_source, size=size, shape=shape,color=node_colors[source_type])
187
- # if not if_node_exists(nodes, source_node.id):
188
- # nodes.append(source_node)
189
-
190
  target = relationship.endEntity
191
- target_type = target.label
192
- target_id = target.name.replace(" ", "_")
193
- label_target = target.name
194
-
195
- target_node = Node(id=target_id,title=target_type, label=label_target, size=size, shape=shape,color=node_colors[target_type])
196
- # if not if_node_exists(nodes, target_node.id):
197
- # nodes.append(target_node)
198
-
199
- label = relationship.name
200
-
201
- edges.append(Edge(source=source_id, label=label, target=target_id))
202
 
 
 
203
 
204
- # Define the configuration
205
- config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True,from_json="config.json")
206
- # Create the Agraph visualization
207
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  return edges, nodes, config
209
 
210
  def display_graph(edges, nodes, config):
211
- # Display the Agraph visualization
212
  return agraph(edges=edges, nodes=nodes, config=config)
213
 
214
 
 
 
 
 
 
 
 
215
 
216
- def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
217
- filtered_nodes = []
218
- for node in nodes:
219
- if node.title in node_types_filter: #the title represents the type of the node
220
- filtered_nodes.append(node)
221
- return filtered_nodes
222
 
 
 
 
223
  @st.dialog(title="Changer la vue")
224
  def change_view_dialog():
 
 
 
 
225
  st.write("Changer la vue")
226
-
227
  for index, item in enumerate(st.session_state.filter_views.keys()):
228
  emp = st.empty()
229
  col1, col2, col3 = emp.columns([8, 1, 1])
230
 
 
231
  if index > 0 and col2.button("🗑️", key=f"del{index}"):
232
  del st.session_state.filter_views[item]
233
  st.session_state.current_view = "Vue par défaut"
234
  st.rerun()
 
 
235
  but_content = "🔍" if st.session_state.current_view != item else "✅"
236
  if col3.button(but_content, key=f"valid{index}"):
237
  st.session_state.current_view = item
238
  st.rerun()
 
 
239
  if len(st.session_state.filter_views.keys()) > index:
240
  with col1.expander(item):
 
241
  if index > 0:
242
- change_name = st.text_input("Nom de la vue", label_visibility="collapsed", placeholder="Changez le nom de la vue",key=f"change_name{index}")
243
- if st.button("Renommer",key=f"rename{index}"):
244
- if change_name != "":
 
 
 
 
 
245
  st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
246
  st.session_state.current_view = change_name
247
  st.rerun()
248
- st.markdown("\n".join(f"- {label.strip()}" for label in st.session_state.filter_views[item]))
 
 
 
249
  else:
250
  emp.empty()
251
 
 
252
  @st.dialog(title="Ajouter une vue")
253
  def add_view_dialog(filters):
 
 
 
254
  st.write("Ajouter une vue")
255
  view_name = st.text_input("Nom de la vue")
256
- st.markdown("les filtres actuels:")
257
  st.write(filters)
258
  if st.button("Ajouter la vue"):
259
- st.session_state.filter_views[view_name] = filters
260
- st.session_state.current_view = view_name
 
261
  st.rerun()
262
 
 
263
  @st.dialog(title="Changer la couleur")
264
  def change_color_dialog():
 
265
  st.write("Changer la couleur")
266
- for node_type,color in st.session_state.node_types.items():
267
- color = st.color_picker(f"La couleur de l'entité **{node_type.strip()}**",color)
268
- st.session_state.node_types[node_type] = color
 
 
 
269
 
270
  if st.button("Valider"):
271
  st.rerun()
272
 
273
 
 
 
 
274
 
275
  def kg_main():
276
- #st.set_page_config(page_title="Graphe de connaissance", page_icon="", layout="wide")
277
-
278
-
279
-
280
- if "audit" not in st.session_state or st.session_state.audit == {}:
281
- st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
282
- return
283
-
284
- if "cr" not in st.session_state:
285
- st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
286
- return
287
 
288
  if "graph" not in st.session_state:
289
- st.session_state.graph = None
290
-
 
 
 
 
 
291
  if "filter_views" not in st.session_state:
292
  st.session_state.filter_views = {}
293
  if "current_view" not in st.session_state:
294
  st.session_state.current_view = None
295
-
296
- st.title("Graphe de connaissance")
297
-
298
  if "node_types" not in st.session_state:
299
  st.session_state.node_types = None
300
-
301
- if "summary" not in st.session_state:
302
- st.session_state.summary = None
303
-
304
  if "chat_graph_history" not in st.session_state:
305
  st.session_state.chat_graph_history = []
306
-
307
- audit = st.session_state.audit_simplified
308
- # content = st.session_state.audit["content"]
309
 
310
- # if audit["type de fichier"] == "pdf":
311
- # text = get_text_from_content_for_doc(content)
312
- # elif audit["type de fichier"] == "audio":
313
- # text = get_text_from_content_for_audio(content)
314
 
315
- text = st.session_state.cr + "mots clés" + audit["Mots clés"]
316
-
317
- #summary_prompt = f"Voici un ensemble de documents : {text}. À partir de ces documents, veuillez fournir des résumés concis en vous concentrant sur l'extraction des relations essentielles et des événements. Il est crucial d'inclure les dates des actions ou des événements, car elles seront utilisées pour l'analyse chronologique. Par exemple : 'Sam a été licencié par le conseil d'administration d'OpenAI le 17 novembre 2023 (17 novembre, vendredi)', ce qui illustre la relation entre Sam et OpenAI ainsi que la date de l'événement."
318
-
319
- if st.button("Générer le graphe"):
320
- # with st.spinner("Extractions des relations..."):
321
- # sum = generate_response_openai(summary_prompt,model="gpt-4o")
322
- # st.session_state.summary = sum
323
-
324
- with st.spinner("Génération du graphe..."):
325
- keywords_list = audit["Mots clés"].strip().split(",")
326
- allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
327
- graph = get_graph(text,allowed_nodes=allowed_nodes_types)
328
- st.session_state.graph = graph
329
-
330
- node_types = get_node_types(graph[0])
331
- nodes_type_dict = list_to_dict_colors(node_types)
332
- st.session_state.node_types = nodes_type_dict
333
  st.session_state.filter_views["Vue par défaut"] = list(node_types)
334
  st.session_state.current_view = "Vue par défaut"
335
 
336
- else:
337
- graph = st.session_state.graph
 
 
 
 
 
 
 
338
 
339
- if graph is not None:
340
- #st.write(graph)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- edges,nodes,config = convert_neo4j_to_agraph(graph[0],st.session_state.node_types)
 
 
 
 
343
 
344
- col1, col2 = st.columns([2.5, 1.5])
345
-
346
- with col1.container(border=True,height=800):
347
- st.write("##### Visualisation du graphe (**"+st.session_state.current_view+"**)")
348
- filter_col,add_view_col,change_view_col,color_col = st.columns([9,1,1,1])
349
-
350
- if color_col.button("🎨",help="Changer la couleur"):
351
- change_color_dialog()
352
-
353
- if change_view_col.button("🔍",help="Changer de vue"):
354
- change_view_dialog()
355
-
356
-
357
- #add mots cles to evry label in audit["Mots clés"]
358
- #filter_labels = [ label + " (mot clé)" if label.strip().lower() in audit["Mots clés"].strip().lower().split(",") else label for label in st.session_state.filter_views[st.session_state.current_view] ]
359
- filter = filter_col.multiselect("Filtrer selon l'étiquette",st.session_state.node_types.keys(),placeholder="Sélectionner une ou plusieurs étiquettes",default=st.session_state.filter_views[st.session_state.current_view],label_visibility="collapsed")
360
-
361
- if add_view_col.button("",help="Ajouter une vue"):
362
- add_view_dialog(filter)
363
-
364
- if filter:
365
- nodes = filter_nodes_by_types(nodes,filter)
366
-
367
- selected = display_graph(edges,nodes,config)
368
-
369
- with col2.container(border=True,height=800):
370
- st.markdown("##### Dialoguer avec le graphe")
371
-
372
- user_query = st.chat_input("Par ici ...")
373
- if user_query is not None and user_query != "":
374
- st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
375
-
376
- with st.container(height=650, border=False):
377
- for message in st.session_state.chat_graph_history:
378
- if isinstance(message, AIMessage):
379
- with st.chat_message("AI"):
380
- st.markdown(message.content)
381
- elif isinstance(message, HumanMessage):
382
- with st.chat_message("Moi"):
383
- st.write(message.content)
384
-
385
- #check if last message is human message
386
- if len(st.session_state.chat_graph_history) > 0:
387
- last_message = st.session_state.chat_graph_history[-1]
388
- if isinstance(last_message, HumanMessage):
389
- with st.chat_message("AI"):
390
- retreive = st.session_state.vectorstore.as_retriever()
391
- context = retreive.invoke(last_message.content)
392
- wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
393
- response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
394
- st.session_state.chat_graph_history.append(AIMessage(content=response))
395
-
396
- if selected is not None:
397
- with st.chat_message("AI"):
398
- st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
399
-
400
- prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
401
- f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
402
-
403
- for i,prompt in enumerate(prompts):
404
- button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
405
-
406
-
407
-
408
-
409
- node_types = st.session_state.node_types
410
-
 
1
  import streamlit as st
2
+ import pickle
 
 
3
  import random
4
  import math
5
+ from streamlit_agraph import agraph, Node, Edge, Config
6
+
7
+ from utils.kg.construct_kg import get_graph # if still needed for something else
8
+ from utils.audit.rag import get_text_from_content_for_doc, get_text_from_content_for_audio
9
  from utils.audit.response_llm import generate_response_via_langchain
10
+ from utils.audit.rag import get_vectorstore
11
  from langchain_core.messages import AIMessage, HumanMessage
12
  from langchain_core.prompts import PromptTemplate
13
+
14
  from itext2kg.models import KnowledgeGraph
15
 
 
 
 
16
 
17
+ ################################################################################
18
+ # Utility Functions
19
+ ################################################################################
20
 
21
+ def if_node_exists(nodes, node_id):
22
+ """Check if a node with the given id already exists in a list of Node objects."""
 
23
  for node in nodes:
24
  if node.id == node_id:
25
  return True
26
  return False
27
 
28
  def generate_random_color():
29
+ """Generate a random pastel-ish RGB color."""
30
  r = random.randint(180, 255)
31
  g = random.randint(180, 255)
32
  b = random.randint(180, 255)
33
  return (r, g, b)
34
 
35
  def rgb_to_hex(rgb):
36
+ """Convert an (R, G, B) tuple to a hex string like '#aabbcc'."""
37
  return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def color_distance(color1, color2):
40
+ """Calculate Euclidean distance between two RGB colors."""
41
+ return math.sqrt(
42
+ (color1[0] - color2[0])**2 +
43
+ (color1[1] - color2[1])**2 +
44
+ (color1[2] - color2[2])**2
45
+ )
46
 
47
  def generate_distinct_colors(num_colors, min_distance=30):
48
+ """
49
+ Generate a list of distinct pastel-ish colors (in hex), ensuring each is
50
+ at least `min_distance` away from the others in RGB space.
51
+ """
52
  colors = []
53
  while len(colors) < num_colors:
54
  new_color = generate_random_color()
55
+ if all(color_distance(new_color, existing_color) >= min_distance
56
+ for existing_color in colors):
57
  colors.append(new_color)
58
  return [rgb_to_hex(color) for color in colors]
59
 
60
+ def list_to_dict_colors(node_types):
61
+ """
62
+ Create a dict mapping each node type to a random (distinct) hex color.
63
+ """
64
  number_of_colors = len(node_types)
65
+ color_hexes = generate_distinct_colors(number_of_colors)
66
+ return {typ: color_hexes[i] for i, typ in enumerate(node_types)}
67
 
68
+ def get_node_types_advanced(graph: KnowledgeGraph):
69
+ """
70
+ Extract the set of node labels from an itext2kg KnowledgeGraph.
71
+ (graph.entities have .label, relationships have .startEntity, .endEntity)
72
+ """
73
+ node_types = set()
74
+ dict_node_colors = {}
75
+ for node in graph.entities:
76
+ node_types.add(node.label)
77
+ for relationship in graph.relationships:
78
+ node_types.add(relationship.startEntity.label)
79
+ node_types.add(relationship.endEntity.label)
80
 
81
+ dict_node_colors = {node:rgb_to_hex(generate_random_color()) for node in node_types}
82
+ return node_types, dict_node_colors
83
 
84
+ ################################################################################
85
+ # Graph Conversion
86
+ ################################################################################
87
 
88
+ def get_node_types(graph):
89
+ """
90
+ Extract the set of node types from a graph that has:
91
+ graph.nodes -> [ Node(id, type) ... ]
92
+ graph.relationships -> [ Relationship(source, target, type) ... ]
93
+ """
94
+ node_types = set()
95
+ for node in graph.nodes:
96
+ node_types.add(node.type)
97
+ for rel in graph.relationships:
98
+ node_types.add(rel.source.type)
99
+ node_types.add(rel.target.type)
100
+ return node_types
101
 
102
+ def convert_neo4j_to_agraph(neo4j_graph, node_colors):
103
+ """
104
+ Convert a “Neo4j-like” object into Agraph Nodes & Edges.
105
  """
106
  nodes = []
107
  edges = []
108
 
109
+ # Create nodes
110
  for node in neo4j_graph.nodes:
111
+ node_id = node.id.replace(" ", "_")
 
112
  label = node.id
113
+ type_ = node.type
114
+
115
+ new_node = Node(
116
+ id=node_id,
117
+ title=type_, # 'title' effectively becomes "type"
118
+ label=label,
119
+ size=25,
120
+ shape="circle",
121
+ color=node_colors.get(type_, "#cccccc")
122
+ )
123
+ if not if_node_exists(nodes, node_id):
124
  nodes.append(new_node)
125
 
126
+ # Create edges
127
+ for rel in neo4j_graph.relationships:
128
+ source_id = rel.source.id.replace(" ", "_")
129
+ target_id = rel.target.id.replace(" ", "_")
130
+
131
+ # Ensure nodes exist (if not from the loop above):
132
+ if not if_node_exists(nodes, source_id):
133
+ nodes.append(Node(
134
+ id=source_id,
135
+ title=rel.source.type,
136
+ label=rel.source.id,
137
+ size=25,
138
+ shape="circle",
139
+ color=node_colors.get(rel.source.type, "#cccccc")
140
+ ))
141
+ if not if_node_exists(nodes, target_id):
142
+ nodes.append(Node(
143
+ id=target_id,
144
+ title=rel.target.type,
145
+ label=rel.target.id,
146
+ size=25,
147
+ shape="circle",
148
+ color=node_colors.get(rel.target.type, "#cccccc")
149
+ ))
150
+
151
+ edges.append(Edge(
152
+ source=source_id,
153
+ label=rel.type,
154
+ target=target_id
155
+ ))
156
+
157
+ config = Config(
158
+ width=1200,
159
+ height=800,
160
+ directed=True,
161
+ physics=True,
162
+ hierarchical=True,
163
+ from_json="config.json"
164
+ )
165
  return edges, nodes, config
166
 
167
+ def convert_advanced_neo4j_to_agraph(neo4j_graph: KnowledgeGraph, node_colors):
168
  """
169
+ Same logic as above, but adapted to an itext2kg.models.KnowledgeGraph object
170
+ (graph.entities, graph.relationships).
 
 
 
 
 
 
 
171
  """
172
  nodes = []
173
  edges = []
174
 
175
+ # Create nodes
176
  for node in neo4j_graph.entities:
177
+ node_id = node.name.replace(" ", "_")
 
178
  label = node.name
179
+ type_ = node.label
180
+ new_node = Node(
181
+ id=node_id,
182
+ title=type_,
183
+ label=label,
184
+ size=25,
185
+ shape="circle",
186
+ color=node_colors[type_]
187
+ )
188
+ if not if_node_exists(nodes, new_node.id):
189
+ nodes.append(new_node)
190
 
191
+ # Create edges
192
  for relationship in neo4j_graph.relationships:
 
 
 
193
  source = relationship.startEntity
 
 
 
 
 
 
 
 
194
  target = relationship.endEntity
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ source_id = source.name.replace(" ", "_")
197
+ target_id = target.name.replace(" ", "_")
198
 
199
+ # Ensure existence of the source node
200
+ if not if_node_exists(nodes, source_id):
201
+ nodes.append(Node(
202
+ id=source_id,
203
+ title=source.label,
204
+ label=source.name,
205
+ size=25,
206
+ shape="circle",
207
+ color=node_colors.get(source.label, "#CCCCCC")
208
+ ))
209
+
210
+ # Ensure existence of the target node
211
+ if not if_node_exists(nodes, target_id):
212
+ nodes.append(Node(
213
+ id=target_id,
214
+ title=target.label,
215
+ label=target.name,
216
+ size=25,
217
+ shape="circle",
218
+ color=node_colors.get(target.label, "#CCCCCC")
219
+ ))
220
+
221
+ edges.append(Edge(
222
+ source=source_id,
223
+ label=relationship.name,
224
+ target=target_id
225
+ ))
226
+
227
+ config = Config(
228
+ width=1200,
229
+ height=800,
230
+ directed=True,
231
+ physics=True,
232
+ hierarchical=True,
233
+ from_json="config.json"
234
+ )
235
  return edges, nodes, config
236
 
237
  def display_graph(edges, nodes, config):
238
+ """Render Agraph."""
239
  return agraph(edges=edges, nodes=nodes, config=config)
240
 
241
 
242
+ def filter_nodes_by_types(nodes, node_types_filter):
243
+ """
244
+ Filter out Agraph nodes by the node’s 'title' field (which is used as 'type' here).
245
+ """
246
+ if not node_types_filter:
247
+ return nodes
248
+ return [node for node in nodes if node.title in node_types_filter]
249
 
 
 
 
 
 
 
250
 
251
+ ################################################################################
252
+ # Dialog Components (same as your original code)
253
+ ################################################################################
254
  @st.dialog(title="Changer la vue")
255
  def change_view_dialog():
256
+ """
257
+ Dialog to rename or delete existing views from st.session_state.filter_views
258
+ and choose the active one (st.session_state.current_view).
259
+ """
260
  st.write("Changer la vue")
 
261
  for index, item in enumerate(st.session_state.filter_views.keys()):
262
  emp = st.empty()
263
  col1, col2, col3 = emp.columns([8, 1, 1])
264
 
265
+ # Delete the view (except for the default if you want)
266
  if index > 0 and col2.button("🗑️", key=f"del{index}"):
267
  del st.session_state.filter_views[item]
268
  st.session_state.current_view = "Vue par défaut"
269
  st.rerun()
270
+
271
+ # Choose the view
272
  but_content = "🔍" if st.session_state.current_view != item else "✅"
273
  if col3.button(but_content, key=f"valid{index}"):
274
  st.session_state.current_view = item
275
  st.rerun()
276
+
277
+ # Show details / rename
278
  if len(st.session_state.filter_views.keys()) > index:
279
  with col1.expander(item):
280
+ # Don’t allow renaming the default view (index=0) if you want
281
  if index > 0:
282
+ change_name = st.text_input(
283
+ "Nom de la vue",
284
+ label_visibility="collapsed",
285
+ placeholder="Changez le nom de la vue",
286
+ key=f"change_name{index}"
287
+ )
288
+ if st.button("Renommer", key=f"rename{index}"):
289
+ if change_name.strip():
290
  st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
291
  st.session_state.current_view = change_name
292
  st.rerun()
293
+ st.markdown(
294
+ "\n".join(f"- {label.strip()}"
295
+ for label in st.session_state.filter_views[item])
296
+ )
297
  else:
298
  emp.empty()
299
 
300
+
301
  @st.dialog(title="Ajouter une vue")
302
  def add_view_dialog(filters):
303
+ """
304
+ Dialog to add a new “view” to st.session_state.filter_views, specifying which types to filter by.
305
+ """
306
  st.write("Ajouter une vue")
307
  view_name = st.text_input("Nom de la vue")
308
+ st.markdown("Les filtres actuels :")
309
  st.write(filters)
310
  if st.button("Ajouter la vue"):
311
+ if view_name.strip():
312
+ st.session_state.filter_views[view_name] = filters
313
+ st.session_state.current_view = view_name
314
  st.rerun()
315
 
316
+
317
  @st.dialog(title="Changer la couleur")
318
  def change_color_dialog():
319
+ """Dialog to interactively change colors of each node type via color pickers."""
320
  st.write("Changer la couleur")
321
+ for node_type, color in st.session_state.node_types.items():
322
+ new_color = st.color_picker(
323
+ f"La couleur de l'entité **{node_type.strip()}**",
324
+ color
325
+ )
326
+ st.session_state.node_types[node_type] = new_color
327
 
328
  if st.button("Valider"):
329
  st.rerun()
330
 
331
 
332
+ ################################################################################
333
+ # Main KG Function
334
+ ################################################################################
335
 
336
  def kg_main():
337
+ # 1. Load your pickles (if not already loaded in session state)
338
+ if "scenes" not in st.session_state:
339
+ with open("./utils/assets/scenes.pkl", "rb") as f:
340
+ st.session_state.scenes = pickle.load(f)
341
+ st.session_state.vectorstore = get_vectorstore(st.session_state.scenes)
 
 
 
 
 
 
342
 
343
  if "graph" not in st.session_state:
344
+ with open("./utils/assets/kg_ia_signature.pkl", "rb") as f:
345
+ # Depending on how you stored it, it might be a tuple (graph, extra_info)
346
+ # or directly a single object. Adjust as needed.
347
+ st.session_state.graph = pickle.load(f)
348
+ print("Graph loaded.")
349
+
350
+ # 2. Initialize other session keys if they don’t exist
351
  if "filter_views" not in st.session_state:
352
  st.session_state.filter_views = {}
353
  if "current_view" not in st.session_state:
354
  st.session_state.current_view = None
 
 
 
355
  if "node_types" not in st.session_state:
356
  st.session_state.node_types = None
 
 
 
 
357
  if "chat_graph_history" not in st.session_state:
358
  st.session_state.chat_graph_history = []
 
 
 
359
 
360
+ st.title("Graphe de connaissance")
 
 
 
361
 
362
+ edges,nodes,config = None, None, None
363
+
364
+ # If we haven’t set up node types yet, do it now
365
+ if st.session_state.node_types is None:
366
+ # st.session_state.graph is presumably a list/tuple => st.session_state.graph[0]
367
+ # Or just st.session_state.graph if you stored it directly as a single obj
368
+ node_types, st.session_state.node_types = get_node_types_advanced(st.session_state.graph)
369
+ # st.write(f"Types d'entités trouvés : {node_types}")
370
+ print("Couleurs attribuées")
371
+ # Initialize a default filter view
 
 
 
 
 
 
 
 
372
  st.session_state.filter_views["Vue par défaut"] = list(node_types)
373
  st.session_state.current_view = "Vue par défaut"
374
 
375
+ # 3. Convert the graph to agraph format
376
+ edges, nodes, config = convert_advanced_neo4j_to_agraph(
377
+ st.session_state.graph, # or st.session_state.graph[0] if needed
378
+ st.session_state.node_types
379
+ )
380
+ print("Graph converti en Agraph")
381
+
382
+ # 4. UI layout: (left) the graph itself, (right) the chat
383
+ col1, col2 = st.columns([3, 1])
384
 
385
+ with col1.container(border=True,height=800):
386
+ st.write(f"#### Visualisation du graphe (**{st.session_state.current_view}**)")
387
+
388
+ filter_col, add_view_col, change_view_col, color_col = st.columns([9, 1, 1, 1])
389
+
390
+ if color_col.button("🎨", help="Changer la couleur"):
391
+ change_color_dialog()
392
+
393
+ if change_view_col.button("🔍", help="Changer de vue"):
394
+ change_view_dialog()
395
+
396
+ # Currently selected filter for the chosen view
397
+ current_filters = st.session_state.filter_views.get(st.session_state.current_view, [])
398
+ filter_selection = filter_col.multiselect(
399
+ "Filtrer selon l'étiquette",
400
+ st.session_state.node_types.keys(),
401
+ default=current_filters,
402
+ label_visibility="collapsed"
403
+ )
404
 
405
+ if add_view_col.button("➕", help="Ajouter une vue"):
406
+ add_view_dialog(filter_selection)
407
+
408
+ # Filter out nodes that don’t match the chosen types
409
+ filtered_nodes = filter_nodes_by_types(nodes, filter_selection)
410
 
411
+ # Render the graph
412
+ selected_node_id = display_graph(edges, filtered_nodes, config)
413
+
414
+ # 5. Chat UI
415
+ with col2.container(border=True,height=800):
416
+ st.markdown("#### Dialoguer avec le graphe")
417
+ user_query = st.chat_input("Votre question ...")
418
+ if user_query:
419
+ st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
420
+
421
+ with st.container():
422
+ # Display the existing chat
423
+ for message in st.session_state.chat_graph_history:
424
+ if isinstance(message, AIMessage):
425
+ with st.chat_message("AI"):
426
+ st.markdown(message.content)
427
+ elif isinstance(message, HumanMessage):
428
+ with st.chat_message("Human"):
429
+ st.write(message.content)
430
+
431
+ # If the last message is from the user, we try to generate a response
432
+ if (len(st.session_state.chat_graph_history) > 0 and
433
+ isinstance(st.session_state.chat_graph_history[-1], HumanMessage)):
434
+ last_message = st.session_state.chat_graph_history[-1]
435
+ with st.chat_message("AI"):
436
+ # Example retrieval (if you have a vectorstore in session state)
437
+ # and want to incorporate scenes or graph data:
438
+ if "vectorstore" in st.session_state:
439
+ retriever = st.session_state.vectorstore.as_retriever()
440
+ context = retriever.invoke(last_message.content)
441
+ prompt = (
442
+ f"Contexte depuis les 'scenes': {st.session_state.scenes}\n"
443
+ f"Contexte vectorstore: {context}\n"
444
+ f"Question: {last_message.content}\n"
445
+ f"Graph: {st.session_state.graph}\n" # If you want to embed your entire graph
446
+ )
447
+ response = st.write_stream(
448
+ generate_response_via_langchain(prompt, stream=True)
449
+ )
450
+ st.session_state.chat_graph_history.append(AIMessage(content=response))
451
+ else:
452
+ # Fallback if no vectorstore
453
+ st.write("Aucune base de vecteurs disponible.")
454
+ st.session_state.chat_graph_history.append(AIMessage(content="(Pas de vectorstore)"))
455
+
456
+ # If the user clicked on a node in the graph, we can propose quick prompts
457
+ if selected_node_id:
458
+ with st.chat_message("AI"):
459
+ st.markdown(f"**Vous avez sélectionné**: `{selected_node_id}`")
460
+ quick_prompts = [
461
+ f"Donne-moi plus d'informations sur le noeud '{selected_node_id}'",
462
+ f"Montre-moi les relations de '{selected_node_id}' dans ce graphe"
463
+ ]
464
+ for i, qprompt in enumerate(quick_prompts):
465
+ if st.button(qprompt, key=f"qp_{i}"):
466
+ st.session_state.chat_graph_history.append(HumanMessage(content=qprompt))
467
+
468
+ kg_main()
 
 
 
 
 
 
 
 
 
test.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
utils/assets/kg_ia_signature.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55b49436038a45405798f6d05591464b1a35360409d83dbead163921707ac592
3
+ size 7354091
utils/assets/scenes.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17fc4636b752c5b8f1434d0c97c95ea3b12605b083689e6d79daacd060f6c110
3
+ size 142917