wagnercosta committed on
Commit
94bf1e0
1 Parent(s): b6cf9eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -4
app.py CHANGED
@@ -1,7 +1,210 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
  import gradio as gr
3
+ from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
4
+ from textwrap import dedent
5
+ import rapidjson
6
+ import spaces
7
+ from pyvis.network import Network
8
+ import networkx as nx
9
+ import spacy
10
+ from spacy import displacy
11
+ from spacy.tokens import Span
12
+ import random
13
 
14
# Sample graph payload in the shape the rest of this file consumes: a "nodes"
# list (id / type / detailed_type) and an "edges" list (from / to / label).
# Handy as a stand-in for a real model call while working on the visualizations.
json_example = {
    "nodes": [
        {"id": "Aerosmith", "type": "organization", "detailed_type": "rock band"},
        {"id": "Steven Tyler", "type": "person", "detailed_type": "lead singer"},
        {"id": "vocal cord injury", "type": "medical condition", "detailed_type": "fractured larynx"},
        {"id": "retirement", "type": "event", "detailed_type": "announcement"},
        {"id": "touring", "type": "activity", "detailed_type": "musical performance"},
        {"id": "September 2023", "type": "date", "detailed_type": "specific time"},
    ],
    "edges": [
        {"from": "Aerosmith", "to": "Steven Tyler", "label": "led by"},
        {"from": "Steven Tyler", "to": "vocal cord injury", "label": "suffered"},
        {"from": "vocal cord injury", "to": "retirement", "label": "caused"},
        {"from": "retirement", "to": "touring", "label": "ended"},
        {"from": "vocal cord injury", "to": "September 2023", "label": "occurred in"},
    ],
}
 
15
 
16
@spaces.GPU
def extract(text, model):
    """Run graph extraction on ``text`` and return the parsed result.

    Args:
        text: Input text handed to ``Phi3InstructGraph.extract``.
        model: Model identifier passed to the ``Phi3InstructGraph`` constructor.

    Returns:
        The value produced by ``rapidjson.loads`` on the extractor's output
        (the rest of this file treats it as a dict with "nodes" and "edges").
    """
    graph_model = Phi3InstructGraph(model=model)
    raw_output = graph_model.extract(text)
    return rapidjson.loads(raw_output)
21
+
22
def handle_text(text):
    """Collapse every run of whitespace in ``text`` to one space and trim the ends."""
    words = text.split()
    return " ".join(words)
24
+
25
def get_random_color():
    """Return a random colour as a lowercase ``#rrggbb`` hex string."""
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    """Return a random light colour as a lowercase ``#rrggbb`` hex string.

    Each channel is drawn from 128..255 so the colour stays bright enough to
    serve as a highlight background.
    """
    # Channels are drawn in r, g, b order, one randint call per channel.
    channels = (random.randint(128, 255) for _ in range(3))
    return "#" + "".join(f"{channel:02x}" for channel in channels)
37
+
38
def find_token_indices(doc, substring, text):
    """Map every occurrence of ``substring`` in ``text`` to a token range in ``doc``.

    An occurrence is kept only when it starts exactly at the first character
    of a token and ends exactly at the last character of a token; otherwise a
    diagnostic is printed and the occurrence is skipped. Occurrences are
    searched left to right without overlap.

    Args:
        doc: Iterable of tokens exposing ``idx`` (character offset in
            ``text``), ``i`` (token index) and ``len()`` (character length).
        substring: The text to locate.
        text: The full text ``doc`` was built from.

    Returns:
        A list of ``{"start": first_token_index, "end": last_token_index + 1}``
        dicts, one per aligned occurrence (empty if there are none).
    """
    if not substring:
        # "".find() matches at every offset without advancing, which would
        # spin forever in the search loop below.
        print(f"Token boundaries not found for '{substring}'")
        return []

    # Build the boundary lookups once instead of rescanning all tokens for
    # every occurrence.
    token_by_start = {token.idx: token.i for token in doc}
    token_by_end = {token.idx + len(token): token.i + 1 for token in doc}

    result = []
    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = token_by_start.get(start_index)
        end_token = token_by_end.get(end_index)

        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({"start": start_token, "end": end_token})

        # Search for next occurrence
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")

    return result
68
+
69
+
70
def create_custom_entity_viz(data, full_text):
    """Render ``full_text`` as HTML with the extracted graph nodes highlighted.

    Every node id is located in the text with ``find_token_indices``; each
    token-aligned occurrence becomes a span labelled with the node's ``type``.
    Each distinct type is assigned its own random light colour.

    Args:
        data: Graph dict with a "nodes" list; each node needs "id" and "type".
        full_text: The text the graph was extracted from.

    Returns:
        The HTML string produced by ``displacy.render`` in "span" style.
    """
    nlp = spacy.blank("xx")
    doc = nlp(full_text)

    spans = []
    colors = {}
    for node in data["nodes"]:
        label = node["type"]
        for match in find_token_indices(doc, node["id"], full_text):
            start, end = match["start"], match["end"]
            if start < len(doc) and end <= len(doc):
                spans.append(Span(doc, start, end, label=label))
                if label not in colors:
                    colors[label] = get_random_light_color()

    # doc.ents cannot hold overlapping or duplicate spans (set_ents raises),
    # and node ids frequently nest ("Timberlake" / "Justin Timberlake").
    # Give doc.ents the de-overlapped subset; the span group below keeps all.
    doc.set_ents(spacy.util.filter_spans(spans), default="unmodified")
    doc.spans["sc"] = spans

    options = {
        "colors": colors,
        "spans_key": "sc",
    }
    return displacy.render(doc, style="span", options=options)
107
+
108
+
109
def create_graph(json_data):
    """Build an interactive network view of the graph, wrapped in an iframe.

    Args:
        json_data: Graph dict with a "nodes" list (each node has "id", "type"
            and optionally "detailed_type") and an "edges" list (each edge has
            "from", "to" and "label").

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the page generated
        by pyvis.
    """
    # Local import: the name ``html`` is avoided for locals so it cannot
    # shadow the stdlib module.
    from html import escape

    # Edges have a direction (from -> to) and the view is drawn with arrows,
    # so the backing graph must be directed; an undirected graph can flip
    # arrows and merges A->B with B->A.
    graph = nx.DiGraph()

    for node in json_data['nodes']:
        detailed_type = node.get('detailed_type')
        title = f"{node['type']}: {detailed_type}" if detailed_type else f"{node['type']}"
        graph.add_node(node['id'], title=title)

    for edge in json_data['edges']:
        graph.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827",
    )
    nt.from_nx(graph)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    # Escape the generated page for use inside an attribute value. Blindly
    # swapping ' for " would corrupt any label that contains an apostrophe.
    page = escape(nt.generate_html(), quote=True)

    return (
        '<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" '
        f'allowpaymentrequest="" frameborder="0" srcdoc="{page}"></iframe>'
    )
156
+
157
+
158
def process_and_visualize(text, model):
    """Extract a graph from ``text`` and build both visualizations.

    Args:
        text: The text to analyse.
        model: The model identifier forwarded to ``extract``.

    Returns:
        A ``(graph_html, entities_html, graph_json)`` tuple.

    Raises:
        gr.Error: If ``text`` or ``model`` is empty or missing.
    """
    if not (text and model):
        raise gr.Error("Text and model must be provided.")

    graph_json = extract(text, model)
    print(graph_json)

    entities_html = create_custom_entity_viz(graph_json, text)
    network_html = create_graph(graph_json)
    return network_html, entities_html, graph_json
168
+
169
+
170
+
171
# Gradio UI: model picker + text input on the left, visualizations on the right.
with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")

    with gr.Row():
        with gr.Column(scale=1):
            # No default value: process_and_visualize rejects an empty
            # selection with a gr.Error, so the user must pick a model.
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            # handle_text collapses the line breaks and indentation of the
            # triple-quoted samples into single spaces.
            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
                    lead singer Steven Tyler's unrecoverable vocal cord injury.
                    The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
                    which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
                    court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
                    in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
                    pleaded not guilty to the charges."""),
                ],
                inputs=input_text,
            )

            submit_button = gr.Button("Extract and Visualize")

        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)
            output_json = gr.JSON(label="JSON Graph")

    # process_and_visualize returns (graph_html, entities_html, json_data);
    # the outputs list must line up with that tuple one-to-one.
    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz, output_json],
    )

demo.launch(share=False)