nicpopovic committed
Commit 843e384 · verified · 1 Parent(s): 7271c04

Create app.py

Files changed (1): app.py (added, +679 lines)

app.py:
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer  # STOKEStreamer is not part of stock transformers; it appears to come from the project's custom build
from threading import Thread
import json
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import itertools
import transformers
import time
transformers.logging.set_verbosity_error()


# Variable to define number of instances
n_instances = 1

gpu_name = "CPU"

for i in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_properties(i).name

# Reusing the original MLP class and other functions (unchanged) except those specific to Streamlit
class MLP(torch.nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
        self.layer_id = layer_id
        if cuda:
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.to(self.device)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc3(x)
        return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()

def map_value_to_color(value, colormap_name='tab20c'):
    value = np.clip(value, 0.0, 1.0)
    colormap = plt.get_cmap(colormap_name)
    rgba_color = colormap(value)
    css_color = to_hex(rgba_color)
    return css_color
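# Illustrative usage (not executed here; tensor names are hypothetical): the MLP
# classifiers return hard predictions plus per-class probabilities, and
# map_value_to_color turns a normalised class index into a CSS hex colour:
#   preds, probs = classifier_token(features)      # features: (batch, emb_size)
#   colour = map_value_to_color((preds[0].item() - 1) / (len(label_map) - 1))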


# Caching functions for model and classifier
model_cache = {}

def get_multiple_model_and_tokenizer(name, n_instances):
    model_instances = []
    for _ in range(n_instances):
        tok = AutoTokenizer.from_pretrained(name, token=os.getenv('HF_TOKEN'), pad_token_id=128001)
        model = AutoModelForCausalLM.from_pretrained(name, token=os.getenv('HF_TOKEN'), torch_dtype="bfloat16", pad_token_id=128001, device_map="auto")
        if torch.cuda.is_available():
            model.cuda()
        model_instances.append((model, tok))
    return model_instances

def get_classifiers_for_model(att_size, emb_size, device, config_paths):
    config = {
        "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
        "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
    }
    layer_id = config["classifier_token"]["layer"]

    classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
    classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device, weights_only=True))

    classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
    classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device, weights_only=True))

    return classifier_span, classifier_token, config["classifier_token"]["label_map"]
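# Expected contents of each classifier directory in config_paths (inferred from
# the loading code above; keys are as read by get_classifiers_for_model):
#   <classifier_span>/config.json    with at least "classifier_dim"
#   <classifier_span>/checkpoint.pt
#   <classifier_token>/config.json   with "layer", "classifier_dim" and "label_map"
#   <classifier_token>/checkpoint.pt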

def find_datasets_and_model_ids(root_dir):
    datasets = {}
    for root, dirs, files in os.walk(root_dir):
        if 'config.json' in files and 'stoke_config.json' in files:
            config_path = os.path.join(root, 'config.json')
            stoke_config_path = os.path.join(root, 'stoke_config.json')

            with open(config_path, 'r') as f:
                config_data = json.load(f)
                model_id = config_data.get('model_id')
                if model_id:
                    dataset_name = os.path.basename(os.path.dirname(config_path))

            with open(stoke_config_path, 'r') as f:
                stoke_config_data = json.load(f)
                if model_id:
                    dataset_name = os.path.basename(os.path.dirname(stoke_config_path))
                    datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data
    return datasets
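# Rough on-disk layout this walk expects (inferred from how the result is later
# indexed as datasets[model_id][data_id][config_id]; directory names are placeholders):
#   data/<...>/<dataset>/config.json        -> {"model_id": "..."}
#   data/<...>/<dataset>/stoke_config.json  -> {"<config>": {"classifier_token": "<path>", "classifier_span": "<path>"}, ...}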

def filter_spans(spans_and_values):
    if spans_and_values == []:
        return [], []
    # Create a dictionary to store spans based on their second index values
    span_dict = {}

    spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]

    # Iterate through the spans and update the dictionary with the highest value
    for span, value in zip(spans, values):
        start, end = span
        if start > end or end - start > 15 or start == 0:
            continue
        current_value = span_dict.get(end, None)

        if current_value is None or current_value[1] < value:
            span_dict[end] = (span, value)

    if span_dict == {}:
        return [], []
    # Extract the filtered spans and values
    filtered_spans, filtered_values = zip(*span_dict.values())

    return list(filtered_spans), list(filtered_values)
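# Example of the rule above: spans sharing an end index keep only the highest-scoring
# candidate, and spans that start at 0, run backwards, or exceed 15 tokens are dropped:
#   filter_spans([((1, 3), 0.2), ((2, 3), 0.9)])  ->  ([(2, 3)], [0.9])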

def remove_overlapping_spans(spans):
    # Sort the spans based on their end points
    sorted_spans = sorted(spans, key=lambda x: x[0][1])

    non_overlapping_spans = []
    last_end = float('-inf')

    # Iterate through the sorted spans
    for span in sorted_spans:
        start, end = span[0]
        value = span[1]

        # If the current span does not overlap with the previous one
        if start >= last_end:
            non_overlapping_spans.append(span)
            last_end = end
        else:
            # If it overlaps, choose the one with the highest value
            existing_span_index = -1
            for i, existing_span in enumerate(non_overlapping_spans):
                if existing_span[0][1] <= start:
                    existing_span_index = i
                    break
            if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
                non_overlapping_spans[existing_span_index] = span

    return non_overlapping_spans

def generate_html_no_overlap(tokenized_text, spans):
    current_index = 0
    html_content = ""

    for (span_start, span_end), value in spans:
        # Add text before the span
        html_content += "".join(tokenized_text[current_index:span_start])

        # Add the span with underlining
        html_content += "<b><u>"
        html_content += "".join(tokenized_text[span_start:span_end])
        html_content += "</u></b> "

        current_index = span_end

    # Add any remaining text after the last span
    html_content += "".join(tokenized_text[current_index:])

    return html_content
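# Example: with tokenized_text = ["The", " New", " York"] and spans = [((1, 3), 0.9)],
# generate_html_no_overlap returns "The<b><u> New York</u></b> ".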


css = """
<style>
.prose {
    line-height: 200%;
}
.highlight {
    display: inline;
}
.highlight::after {
    background-color: var(data-color);
}
.spanhighlight {
    padding: 2px 5px;
    border-radius: 5px;
}
.tooltip {
    position: relative;
    display: inline-block;
}
.generated-content {
    overflow: scroll;
    height: 100%;
}
.tooltip::after {
    content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
    display: none;
    position: absolute;
    background-color: #333;
    color: #fff;
    padding: 5px;
    border-radius: 5px;
    bottom: 100%; /* Position it above the element */
    left: 50%;
    transform: translateX(-50%);
    width: auto;
    min-width: 120px;
    margin: 0 auto;
    text-align: center;
}

.tooltip:hover::after {
    display: block; /* Show the tooltip on hover */
}

.small-text {
    padding: 2px 5px;
    background-color: white;
    border-radius: 5px;
    font-size: xx-small;
    margin-left: 0.5em;
    vertical-align: 0.2em;
    font-weight: bold;
    color: grey!important;
}

.square {
    width: 20px; /* Width of the square */
    height: 20px; /* Height of the square */
    border: 1px solid black; /* Black outline */
    margin: auto;
    background-color: white; /* Optional: set the background color */
    position: relative;
    z-index: 1; /* Higher stacking order for the square */
}

.circle {
    width: 16px; /* Width of the circle */
    height: 16px; /* Height of the circle */
    border: 1px solid red; /* Red outline */
    border-radius: 8px;
    margin: auto;
    background-color: white; /* Optional: set the background color */
    position: relative;
    z-index: 1; /* Higher stacking order for the circle */
    display: block!important;
}

table {
    border: 0px!important; /* No border */
    table-layout: fixed;
    width:100%;
}

th, td {
    font-weight: normal;
    width: 7em!important;
    text-align: center!important;
    border: 0px!important;
}

tr {
    border: 0px!important;
}

.dashed-cell {
    position: relative;
    width: 50px; /* Adjust width of the table cell */
}

.dashed-cell::before {
    content: "";
    position: absolute;
    top: 0;
    bottom: 0;
    left: 50%; /* Center the dashed line horizontally */
    width: 0; /* No width, just a vertical line */
    border-left: 1px dashed black; /* Dashed vertical line */
    transform: translateX(-50%); /* Center the line exactly in the middle */
}

.dashed-cell-horizontal::after {
    content: "";
    position: absolute;
    left: 0;
    right: 0;
    top: 50%; /* Center the dashed horizontal line vertically */
    height: 0; /* No height, just a horizontal line */
    border-top: 1px dashed black; /* Dashed horizontal line */
    transform: translateY(-50%); /* Center the line exactly in the middle */
}

.arrowtip {
    width: 0;
    height: 0;
    border-left: 4px solid transparent;
    border-right: 4px solid transparent;
    border-bottom: 8px solid black; /* The triangle color */
    bottom: 8px; /* Nudge the triangle upward */
    position: relative;
}

.span-cell::after {
    content: '';
    position: absolute;
    top: 50%;
    left: -1px;
    width: 1px;
    height: calc(100% * 6.5); /* Adjust the height as needed to reach the yellow circle */
    background-color: red;
}

</style>"""


def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer, new_tags):

    # spanwise annotated text
    annotated = []
    span_ends = -1
    in_span = False

    out_of_span_tokens = []
    for i in reversed(range(len(tokenwise_preds))):

        if in_span:
            if i >= span_ends:
                continue
            else:
                in_span = False

        predicted_class = ""
        style = ""

        span = None
        for s in spans:
            if s[1] == i+1:
                span = s

        if tokenwise_preds[i] != 0 and span is not None:
            predicted_class = f"highlight spanhighlight"
            style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}"
            if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
                annotated.append("Ġ")

            span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
            span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
            annotated.extend(out_of_span_tokens)
            out_of_span_tokens = []
            span_ends = span[0]
            in_span = True
            annotated.append(span_end)
            annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
            annotated.append(span_opener)
        else:
            out_of_span_tokens.append(token_strings[i])

    annotated.extend(out_of_span_tokens)

    return [x for x in reversed(annotated)]
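# Note on the streamer items consumed below (inferred from the indexing in
# gen_json/generate_text; STOKEStreamer itself is not defined in this file):
# each item appears to carry the tokenwise predictions at index 1, the decoded
# token strings so far at index 2, and the candidate spans with scores at index -1.
# The "Ġ" used in generate_html_spanwise looks like the byte-level BPE marker for a
# leading space, so convert_tokens_to_string restores spacing around the injected HTML.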

def gen_json(input_text, max_new_tokens):
    streamer = STOKEStreamer(tok, classifier_token, classifier_span)

    new_tags = label_map

    inputs = tok([f" {input_text}"], return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        repetition_penalty=1.2, do_sample=False
    )

    def generate_async():
        model.generate(**generation_kwargs)

    thread = Thread(target=generate_async)
    thread.start()

    # Display generated text as it becomes available
    output_text = ""
    text_tokenwise = ""
    text_spans = ""
    removed_spans = ""
    tags = []
    spans = []
    for new_text in streamer:
        if new_text[1] is not None and new_text[2] != ['']:
            text_tokenwise = ""
            output_text = ""
            tags.extend(new_text[1])
            spans.extend(new_text[-1])

            # Tokenwise Classification
            for tk, pred in zip(new_text[2],tags):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    if tk.startswith(" "):
                        text_tokenwise += " "
                    text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
                    output_text += tk
                else:
                    text_tokenwise += tk
                    output_text += tk

            # Span Classification
            text_spans = ""
            if len(spans) > 0:
                filtered_spans = remove_overlapping_spans(spans)
                text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
                if len(spans) - len(filtered_spans) > 0:
                    removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
            else:
                for tk in new_text[2]:
                    text_spans += f"{tk}"

            # Spanwise Classification
            annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
            generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")

            output = f"{css}<br>"
            output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
            #output += "<h5>Show tokenwise classification</h5>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
            #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
            #if removed_spans != "":
            #    output += f"<br><br><i>({removed_spans})</i>"
            list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]

            out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}

            yield out_dict
    return

# Gradio app function to generate text using the assigned model instance
def generate_text(input_text, max_new_tokens=2):
    if input_text == "":
        yield "Please enter some text first."
        return

    # Select the next model instance in a round-robin manner
    model, tok = next(model_round_robin)

    streamer = STOKEStreamer(tok, classifier_token, classifier_span)

    new_tags = label_map

    inputs = tok([f"{input_text[:200]}"], return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        repetition_penalty=1.2, do_sample=False, temperature=None, top_p=None
    )

    def generate_async():
        model.generate(**generation_kwargs)

    thread = Thread(target=generate_async)
    thread.start()

    # Display generated text as it becomes available
    output_text = ""
    text_tokenwise = ""
    text_spans = ""
    removed_spans = ""
    tags = []
    spans = []
    for new_text in streamer:
        if new_text[1] is not None and new_text[2] != ['']:
            text_tokenwise = ""
            output_text = ""
            tags.extend(new_text[1])
            spans.extend(new_text[-1])

            # Tokenwise Classification
            for tk, pred in zip(new_text[2],tags):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    if tk.startswith(" "):
                        text_tokenwise += " "
                    text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
                    output_text += tk
                else:
                    text_tokenwise += tk
                    output_text += tk

            # Span Classification
            text_spans = ""
            if len(spans) > 0:
                filtered_spans = remove_overlapping_spans(spans)
                text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
                if len(spans) - len(filtered_spans) > 0:
                    removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
            else:
                for tk in new_text[2]:
                    text_spans += f"{tk}"

            # Spanwise Classification
            annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
            generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")

            output = f"{css}<div class=\"generated-content\"><br>"
            output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"

            list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]

            out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}

            output_tokenwise = f"""{css}<div class=\"generated-content\">
            <table>"""

            # Row: detected spans with their propagated entity labels
            output_tokenwise += """<tr><th style="width: 10em!important;">Span detection + label propagation</th>"""
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                span = ""

                if i in [x[0][1]-2 for x in spans] and pred != 0:
                    top_span = [x for x in spans if x[0][1]-2 == i][0]
                    spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]])
                    color = map_value_to_color((pred-1)/(len(new_tags)-1)) + "88"
                    span = f"<span class='highlight spanhighlight spantext' style='background-color: {color}; position: absolute; transform: translateX(-50%); white-space: nowrap; top: 0.6em;'>{spanstring}<span class='small-text'>{new_tags[pred]}</span></span>"
                    output_tokenwise += f"<td class='span-cell-2' style='position:relative;'>{span}</td>"
                else:
                    output_tokenwise += f"<td style='position:relative;'></td>"
            output_tokenwise += "</tr><tr><td></td>"

            # Row: raw span detections (no labels yet)
            output_tokenwise += """<tr><td>Span detection</td>"""
            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[:])):
                span = ""
                if i in [x[0][1]-1 for x in spans]:
                    top_span = [x for x in spans if x[0][1]-1 == i][0]
                    spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]])
                    span = f"<span class='highlight spanhighlight spantext' style='border: 1px solid red; background-color: lightgrey; position: absolute; left: 0; transform: translateX(-100%); white-space: nowrap; top: 0.6em;'>{spanstring}</span>"
                    output_tokenwise += f"<td class='span-cell' style='position:relative;'>{span}</td>"
                else:
                    output_tokenwise += f"<td style='position:relative;'></td>"
            output_tokenwise += "</tr><tr><td></td>"

            # Row: tokenwise entity-typing predictions
            output_tokenwise += """<tr><td style='width: 10em;'>Tokenwise<br>entity typing</td>"""
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                style = "background-color: lightgrey;"
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))};"
                    output_tokenwise += f"<td><span class='highlight spanhighlight' style='{style} font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>{new_tags[pred]}</span></td>"
                else:
                    output_tokenwise += "<td></td>"
            #output_tokenwise += f"<th><span class='arrowtip'></span></th>"
            output_tokenwise += "<td></td></tr><tr style='line-height: 0px!important;'><td></td>"

            # Connector rows: arrows, squares and circles linking each token to its predictions
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td><span class='arrowtip'></span></td>"
            output_tokenwise += "</tr><tr><td></td>"

            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"

            for tk, pred in zip(new_text[2][1:],tags[1:]):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
                    output_tokenwise += f"<td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='{new_tags[pred]}' style='{style}'></div></td>"
                else:
                    output_tokenwise += f"<td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"

            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr style='height: 36px;'><td></td>"

            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td class='dashed-cell'></td>"
            output_tokenwise += "</tr><tr><td></td>"

            for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])):
                style = "border-color: lightgray;background-color: transparent;"
                if i in [x[0][1]-1 for x in spans]:
                    style = "background-color: yellow;"
                output_tokenwise += f"<td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='{style}margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td>"
            output_tokenwise += "</tr><tr><td></td>"

            # Final row: the generated tokens themselves
            for tk, pred in zip(new_text[2][1:],tags[1:]):
                output_tokenwise += f"<td><span class='highlight spanhighlight' style='background-color: lightgrey;'>{tk}</span></td>"
            output_tokenwise += "</tr>"


            #yield output + "</div>"
            yield output_tokenwise + "</table></div>"
            #time.sleep(0.5)

    return
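# generate_text is a generator: Gradio re-renders the HTML output on every yield,
# so the table above grows step by step while the model is still generating.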


# Load datasets and models for the Gradio app
datasets = find_datasets_and_model_ids("data/")
available_models = list(datasets.keys())
available_datasets = {model: list(datasets[model].keys()) for model in available_models}
available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models}

def update_datasets(model_name):
    return available_datasets[model_name]

def update_configs(model_name, dataset_name):
    return available_configs[model_name][dataset_name]

# Set the model ID and data configurations
model_id = "meta-llama/Llama-3.2-1B"
data_id = "STOKE_100"
config_id = "default"

# Load n_instances separate instances of the model and tokenizer
model_instances = get_multiple_model_and_tokenizer(model_id, n_instances)

# Set up the round-robin iterator to distribute the requests across model instances
model_round_robin = itertools.cycle(model_instances)


# Load model classifiers (first assuming GPT-2-style config attribute names,
# then falling back to Llama-style names if those attributes do not exist)
try:
    classifier_span, classifier_token, label_map = get_classifiers_for_model(
        model_instances[0][0].config.n_head * model_instances[0][0].config.n_layer, model_instances[0][0].config.n_embd, model_instances[0][0].device,
        datasets[model_id][data_id][config_id]
    )
except:
    classifier_span, classifier_token, label_map = get_classifiers_for_model(
        model_instances[0][0].config.num_attention_heads * model_instances[0][0].config.num_hidden_layers, model_instances[0][0].config.hidden_size, model_instances[0][0].device,
        datasets[model_id][data_id][config_id]
    )

# Placeholder output: pre-rendered HTML for the initial view, plus an example entity dict.
initial_output = (css+"""<div class="generated-content">
<table><tr><th style="width: 10em!important;">Span detection + label propagation</th><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td class='span-cell-2' style='position:relative;'><span class='highlight spanhighlight spantext' style='background-color: #9ecae188; position: absolute; transform: translateX(-50%); white-space: nowrap; top: 0.6em;'>The New York Film Festival<span class='small-text'>EVENT</span></span></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td></tr><tr><td></td><tr><td>Span detection</td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td><td class='span-cell' style='position:relative;'><span class='highlight spanhighlight spantext' style='border: 1px solid red; background-color: lightgrey; position: absolute; left: 0; transform: translateX(-100%); white-space: nowrap; top: 0.6em;'>The New York Film Festival</span></td><td style='position:relative;'></td><td style='position:relative;'></td><td style='position:relative;'></td></tr><tr><td></td><tr><td style='width: 10em;'>Tokenwise<br>entity typing</td><td></td><td><span class='highlight spanhighlight' style='background-color: #e6550d; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>GPE</span></td><td><span class='highlight spanhighlight' style='background-color: #756bb1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>ORG</span></td><td><span class='highlight spanhighlight' style='background-color: #756bb1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>ORG</span></td><td><span class='highlight spanhighlight' style='background-color: #9ecae1; font-weight:normal; font-size: xx-small; border: 1px solid red; color: white;'>EVENT</span></td><td></td><td></td><td></td><td></td></tr><tr style='line-height: 0px!important;'><td></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td><td><span class='arrowtip'></span></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell 
dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr><td></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='GPE' style='background-color: #e6550d'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='ORG' style='background-color: #756bb1'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='ORG' style='background-color: #756bb1'></div></td><td class='dashed-cell'><div class='circle tooltip' data-tooltip-text='EVENT' style='background-color: #9ecae1'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td><td class='dashed-cell'><div class='circle' style='border-color: lightgray;background-color: transparent;'></div></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 
8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr style='height: 36px;'><td></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td><td class='dashed-cell'></td></tr><tr><td></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='background-color: yellow;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td><td class='dashed-cell dashed-cell-horizontal'><div class='square'></div><div class='circle' style='border-color: lightgray;background-color: transparent;margin-top: -14px!important; margin-left: -19px; width: 8px; height: 8px; margin-bottom: 6px!important;'></div></td></tr><tr><td></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'>The</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> New</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> York</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> Film</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> Festival</span></td><td><span class='highlight 
spanhighlight' style='background-color: lightgrey;'> is</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> an</span></td><td><span class='highlight spanhighlight' style='background-color: lightgrey;'> annual</span></td></tr></table></div>""", {'text': 'Miami is a city in the U.S. state of Florida, and it\'s also known as "The Magic City." It was founded by Henry Flagler on October 28th, 1896.', 'entites': [{'name': 'Miami', 'type': 'GPE'}, {'name': 'U.S.', 'type': 'GPE'}, {'name': 'Florida', 'type': 'GPE'}, {'name': 'The Magic City', 'type': 'WORK_OF_ART'}, {'name': 'Henry Flagler', 'type': 'PERSON'}, {'name': 'October 28th, 1896', 'type': 'DATE'}]})


with gr.Blocks(css="footer{display:none !important} .gradio-container {padding: 0!important; height:400px;}", fill_width=True, fill_height=True) as demo:
    with gr.Tab("EMBER Demo"):
        with gr.Row():
            output_text = gr.HTML(label="Generated Text", value=initial_output[0])
        with gr.Group():
            with gr.Row():
                input_text = gr.Textbox(label="Try with your own text!", value="The New York Film Festival is an", max_length=200, submit_btn=True)
                # New HTML output for model info
                model_info_html = gr.HTML(
                    label="Model Info",
                    value=f'<div style="font-weight: lighter; text-align: center; font-size: x-small;">{model_id} running on {gpu_name}</div>'
                )


    input_text.submit(
        fn=generate_text,
        inputs=[input_text],
        outputs=[output_text],
        concurrency_limit=n_instances,
        concurrency_id="queue"
    )

    # Function to refresh the model info HTML
    def refresh_model_info():
        return f'<div style="overflow: visible; font-weight: lighter; text-align: center; font-size: x-small;">{model_id} running on {gpu_name}</div>'

    # Update the model info HTML on each submit
    input_text.submit(
        fn=refresh_model_info,
        inputs=[],
        outputs=[model_info_html],
        queue=False
    )


demo.queue()

demo.launch()