ginipick commited on
Commit
fcc0582
·
verified ·
1 Parent(s): 3072175

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -85
app.py CHANGED
@@ -1,37 +1,46 @@
1
  import spaces
2
  import gradio as gr
3
  from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
4
- from textwrap import dedent
5
  import rapidjson
6
- import spaces
7
  from pyvis.network import Network
8
  import networkx as nx
9
  import spacy
10
  from spacy import displacy
11
  from spacy.tokens import Span
12
  import random
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- @spaces.GPU
15
- def extract(text, model):
16
- model = Phi3InstructGraph(model=model)
17
- result = model.extract(text)
18
- return rapidjson.loads(result)
19
-
20
  def handle_text(text):
21
  return " ".join(text.split())
22
 
23
- def get_random_color():
24
- return f"#{random.randint(0, 0xFFFFFF):06x}"
25
-
26
- def get_random_light_color():
27
- # Generate higher RGB values to ensure a lighter color
28
- r = random.randint(128, 255)
29
- g = random.randint(128, 255)
30
- b = random.randint(128, 255)
31
- return f"#{r:02x}{g:02x}{b:02x}"
32
-
33
- def get_random_color():
34
- return f"#{random.randint(0, 0xFFFFFF):06x}"
35
 
36
  def find_token_indices(doc, substring, text):
37
  result = []
@@ -48,9 +57,7 @@ def find_token_indices(doc, substring, text):
48
  if token.idx + len(token) == end_index:
49
  end_token = token.i + 1
50
 
51
- if start_token is None or end_token is None:
52
- print(f"Token boundaries not found for '{substring}' at index {start_index}")
53
- else:
54
  result.append({
55
  "start": start_token,
56
  "end": end_token
@@ -59,12 +66,8 @@ def find_token_indices(doc, substring, text):
59
  # Search for next occurrence
60
  start_index = text.find(substring, end_index)
61
 
62
- if not result:
63
- print(f"Token boundaries not found for '{substring}'")
64
-
65
  return result
66
 
67
-
68
  def create_custom_entity_viz(data, full_text):
69
  nlp = spacy.blank("xx")
70
  doc = nlp(full_text)
@@ -82,8 +85,6 @@ def create_custom_entity_viz(data, full_text):
82
  overlapping = any(s.start < end and start < s.end for s in spans)
83
  if not overlapping:
84
  span = Span(doc, start, end, label=node["type"])
85
-
86
- # print(span)
87
  spans.append(span)
88
  if node["type"] not in colors:
89
  colors[node["type"]] = get_random_light_color()
@@ -101,26 +102,28 @@ def create_custom_entity_viz(data, full_text):
101
  html = displacy.render(doc, style="span", options=options)
102
  return html
103
 
104
-
105
  def create_graph(json_data):
106
  G = nx.Graph()
107
 
 
108
  for node in json_data['nodes']:
109
  G.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")
110
 
 
111
  for edge in json_data['edges']:
112
  G.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])
113
 
 
114
  nt = Network(
115
  width="720px",
116
  height="600px",
117
  directed=True,
118
  notebook=False,
119
- bgcolor="#111827",
120
- font_color="white"
121
- # bgcolor="#FFFFFF",
122
- # font_color="#111827"
123
  )
 
 
124
  nt.from_nx(G)
125
  nt.barnes_hut(
126
  gravity=-3000,
@@ -130,71 +133,141 @@ def create_graph(json_data):
130
  damping=0.09,
131
  overlap=0,
132
  )
133
-
134
  # Customize edge appearance
135
- # for edge in nt.edges:
136
- # edge['font'] = {'size': 12, 'color': '#FFD700', 'face': 'Arial'} # Removed strokeWidth
137
- # edge['color'] = {'color': '#FF4500', 'highlight': '#FF4500'}
138
- # edge['width'] = 1
139
- # edge['arrows'] = {'to': {'enabled': True, 'type': 'arrow'}}
140
- # edge['smooth'] = {'type': 'curvedCW', 'roundness': 0.2}
141
-
 
 
 
 
 
 
 
142
  html = nt.generate_html()
143
  html = html.replace("'", '"')
144
 
145
- return f"""<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result"
146
- allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
147
  sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
148
  allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
149
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
150
-
151
 
152
- def process_and_visualize(text, model):
153
  if not text or not model:
154
- raise gr.Error("Text and model must be provided.")
 
 
155
  json_data = extract(text, model)
 
 
156
  entities_viz = create_custom_entity_viz(json_data, text)
157
 
 
158
  graph_html = create_graph(json_data)
159
- return graph_html, entities_viz, json_data
160
-
161
-
162
-
163
- with gr.Blocks(title="Phi-3 Instruct Graph (by Emergent Methods") as demo:
164
- gr.Markdown("# Phi-3 Instruct Graph (by Emergent Methods)")
165
- gr.Markdown("Extract a JSON graph from a text input and visualize it.")
166
- with gr.Row():
167
- with gr.Column(scale=1):
168
- input_model = gr.Dropdown(
169
- MODEL_LIST, label="Model",
170
- )
171
- input_text = gr.TextArea(label="Text", info="The text to be extracted")
172
-
173
- examples = gr.Examples(
174
- examples=[
175
- handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
176
- lead singer Steven Tyler's unrecoverable vocal cord injury.
177
- The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
178
- which he suffered in September 2023."""),
179
- handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
180
- court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
181
- in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
182
- pleaded not guilty to the charges."""),
183
- ],
184
- inputs=input_text
185
- )
186
-
187
- submit_button = gr.Button("Extract and Visualize")
188
 
189
- with gr.Column(scale=1):
190
- output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
191
- output_graph = gr.HTML(label="Graph Visualization", show_label=True)
192
-
193
- submit_button.click(
194
- fn=process_and_visualize,
195
- inputs=[input_text, input_model],
196
- outputs=[output_graph, output_entity_viz]
197
- )
 
 
 
 
 
 
 
 
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
 
200
  demo.launch(share=False)
 
1
  import spaces
2
  import gradio as gr
3
  from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
 
4
  import rapidjson
 
5
  from pyvis.network import Network
6
  import networkx as nx
7
  import spacy
8
  from spacy import displacy
9
  from spacy.tokens import Span
10
  import random
11
+ from tqdm import tqdm
12
+
13
+ # Constants
14
+ TITLE = "🌐 Phi-3 Instruct Graph Explorer"
15
+ SUBTITLE = "✨ Extract and visualize knowledge graphs from any text in multiple languages"
16
+ THEME = gr.themes.Base().set(
17
+ primary_hue="indigo",
18
+ secondary_hue="purple",
19
+ neutral_hue="slate",
20
+ radius_size=gr.themes.sizes.radius_sm,
21
+ shadow_size=gr.themes.sizes.shadow_lg,
22
+ )
23
+
24
+ # Color utilities
25
+ def get_random_light_color():
26
+ r = random.randint(140, 255)
27
+ g = random.randint(140, 255)
28
+ b = random.randint(140, 255)
29
+ return f"#{r:02x}{g:02x}{b:02x}"
30
 
31
+ # Text preprocessing
 
 
 
 
 
32
  def handle_text(text):
33
  return " ".join(text.split())
34
 
35
+ # Main processing functions
36
+ @spaces.GPU
37
+ def extract(text, model):
38
+ try:
39
+ model = Phi3InstructGraph(model=model)
40
+ result = model.extract(text)
41
+ return rapidjson.loads(result)
42
+ except Exception as e:
43
+ raise gr.Error(f"Extraction error: {str(e)}")
 
 
 
44
 
45
  def find_token_indices(doc, substring, text):
46
  result = []
 
57
  if token.idx + len(token) == end_index:
58
  end_token = token.i + 1
59
 
60
+ if start_token is not None and end_token is not None:
 
 
61
  result.append({
62
  "start": start_token,
63
  "end": end_token
 
66
  # Search for next occurrence
67
  start_index = text.find(substring, end_index)
68
 
 
 
 
69
  return result
70
 
 
71
  def create_custom_entity_viz(data, full_text):
72
  nlp = spacy.blank("xx")
73
  doc = nlp(full_text)
 
85
  overlapping = any(s.start < end and start < s.end for s in spans)
86
  if not overlapping:
87
  span = Span(doc, start, end, label=node["type"])
 
 
88
  spans.append(span)
89
  if node["type"] not in colors:
90
  colors[node["type"]] = get_random_light_color()
 
102
  html = displacy.render(doc, style="span", options=options)
103
  return html
104
 
 
105
  def create_graph(json_data):
106
  G = nx.Graph()
107
 
108
+ # Add nodes with tooltips
109
  for node in json_data['nodes']:
110
  G.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")
111
 
112
+ # Add edges with labels
113
  for edge in json_data['edges']:
114
  G.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])
115
 
116
+ # Create network visualization
117
  nt = Network(
118
  width="720px",
119
  height="600px",
120
  directed=True,
121
  notebook=False,
122
+ bgcolor="#f8fafc",
123
+ font_color="#1e293b"
 
 
124
  )
125
+
126
+ # Configure network display
127
  nt.from_nx(G)
128
  nt.barnes_hut(
129
  gravity=-3000,
 
133
  damping=0.09,
134
  overlap=0,
135
  )
136
+
137
  # Customize edge appearance
138
+ for edge in nt.edges:
139
+ edge['width'] = 2
140
+ edge['arrows'] = {'to': {'enabled': True, 'type': 'arrow'}}
141
+ edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
142
+ edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}
143
+
144
+ # Customize node appearance
145
+ for node in nt.nodes:
146
+ node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
147
+ node['font'] = {'size': 14, 'color': '#1e293b'}
148
+ node['shape'] = 'dot'
149
+ node['size'] = 25
150
+
151
+ # Generate HTML with iframe to isolate styles
152
  html = nt.generate_html()
153
  html = html.replace("'", '"')
154
 
155
+ return f"""<iframe style="width: 100%; height: 620px; margin: 0 auto; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);"
156
+ name="result" allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
157
  sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
158
  allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
159
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
 
160
 
161
+ def process_and_visualize(text, model, progress=gr.Progress()):
162
  if not text or not model:
163
+ raise gr.Error("⚠️ Both text and model must be provided.")
164
+
165
+ progress(0, desc="Starting extraction...")
166
  json_data = extract(text, model)
167
+
168
+ progress(0.5, desc="Creating entity visualization...")
169
  entities_viz = create_custom_entity_viz(json_data, text)
170
 
171
+ progress(0.8, desc="Building knowledge graph...")
172
  graph_html = create_graph(json_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ node_count = len(json_data["nodes"])
175
+ edge_count = len(json_data["edges"])
176
+ stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
177
+
178
+ progress(1.0, desc="Complete!")
179
+ return graph_html, entities_viz, json_data, stats
180
+
181
+ # Example texts in different languages
182
+ EXAMPLES = [
183
+ [handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
184
+ lead singer Steven Tyler's unrecoverable vocal cord injury.
185
+ The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
186
+ which he suffered in September 2023.""")],
187
+
188
+ [handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
189
+ court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
190
+ in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
191
+ pleaded not guilty to the charges.""")],
192
 
193
+ [handle_text("""세계적인 기술 기업 삼성전자는 새로운 인공지능 기반 스마트폰을 올해 하반기에 출시할 예정이라고 발표했다.
194
+ 이 스마트폰은 현재 개발 중인 갤럭시 시리즈의 최신작으로, 강력한 AI 기능과 혁신적인 카메라 시스템을 탑재할 것으로 알려졌다.
195
+ 삼성전자의 CEO는 이번 신제품이 스마트폰 시장에 새로운 혁신을 가져올 것이라고 전망했다.""")],
196
+
197
+ [handle_text("""한국 영화 '기생충'은 2020년 아카데미 시상식에서 작품상, 감독상, 각본상, 국제영화상 등 4개 부문을 수상하며 역사를 새로 썼다.
198
+ 봉준호 감독이 연출한 이 영화는 한국 영화 최초로 칸 영화제 황금종려상도 수상했으며, 전 세계적으로 엄청난 흥행과
199
+ 평단의 호평을 받았다.""")]
200
+ ]
201
+
202
+ def create_ui():
203
+ with gr.Blocks(theme=THEME, title=TITLE) as demo:
204
+ # Header
205
+ gr.Markdown(f"# {TITLE}")
206
+ gr.Markdown(f"{SUBTITLE}")
207
+
208
+ with gr.Row():
209
+ gr.Markdown("🌍 **Multilingual Support Available** 🔤")
210
+
211
+ # Main interface
212
+ with gr.Row():
213
+ # Input column
214
+ with gr.Column(scale=1):
215
+ input_model = gr.Dropdown(
216
+ MODEL_LIST,
217
+ label="🤖 Select Model",
218
+ info="Choose a model to process your text",
219
+ value=MODEL_LIST[0] if MODEL_LIST else None
220
+ )
221
+
222
+ input_text = gr.TextArea(
223
+ label="📝 Input Text",
224
+ info="Enter text in any language to extract a knowledge graph",
225
+ placeholder="Enter text here...",
226
+ lines=10
227
+ )
228
+
229
+ with gr.Row():
230
+ submit_button = gr.Button("🚀 Extract & Visualize", variant="primary", scale=2)
231
+ clear_button = gr.Button("🔄 Clear", variant="secondary", scale=1)
232
+
233
+ gr.Examples(
234
+ examples=EXAMPLES,
235
+ inputs=input_text,
236
+ label="📚 Example Texts (English & Korean)"
237
+ )
238
+
239
+ stats_output = gr.Markdown("", label="🔍 Analysis Results")
240
+
241
+ # Output column
242
+ with gr.Column(scale=1):
243
+ with gr.Tab("🧩 Knowledge Graph"):
244
+ output_graph = gr.HTML(label="")
245
+
246
+ with gr.Tab("🏷️ Entities"):
247
+ output_entity_viz = gr.HTML(label="")
248
+
249
+ with gr.Tab("📊 JSON Data"):
250
+ output_json = gr.JSON(label="")
251
+
252
+ # Functionality
253
+ submit_button.click(
254
+ fn=process_and_visualize,
255
+ inputs=[input_text, input_model],
256
+ outputs=[output_graph, output_entity_viz, output_json, stats_output]
257
+ )
258
+
259
+ clear_button.click(
260
+ fn=lambda: [None, None, None, ""],
261
+ inputs=[],
262
+ outputs=[output_graph, output_entity_viz, output_json, stats_output]
263
+ )
264
+
265
+ # Footer
266
+ gr.Markdown("---")
267
+ gr.Markdown("📋 **Instructions:** Enter text in any language, select a model, and click 'Extract & Visualize' to generate a knowledge graph.")
268
+ gr.Markdown("🛠️ Powered by Phi-3 Instruct Graph | Created by Emergent Methods")
269
+
270
+ return demo
271
 
272
+ demo = create_ui()
273
  demo.launch(share=False)