npops committed on
Commit 0cd6d05 · 1 Parent(s): 765e08e
Files changed (1)
  1. app.py +442 -0
app.py ADDED
@@ -0,0 +1,442 @@
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer
from threading import Thread
import json
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import numpy as np
import os
import urllib.request
import zipfile
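
# Note: STOKEStreamer is not part of the standard transformers release
# (which provides TextStreamer / TextIteratorStreamer); this app assumes a
# custom build that yields token strings together with span and token
# classifications while generating.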

class MLP(torch.nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)  # Input layer to hidden layer
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)  # Hidden layer to output layer
        self.layer_id = layer_id
        if cuda:
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.to(self.device)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc3(x)

        return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()
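
# Minimal usage sketch (illustrative only; the real dimensions come from the
# probe checkpoints loaded in get_classifiers_for_model below):
#   probe = MLP(input_dim=768, output_dim=5)
#   preds, probs = probe(torch.randn(1, 768))  # class ids and softmax scores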

def map_value_to_color(value, colormap_name='tab20c'):
    """
    Map a value between 0 and 1 to a CSS color using a Matplotlib colormap.

    Args:
        value (float): A value between 0 and 1.
        colormap_name (str): The name of the colormap to use (e.g., 'viridis').

    Returns:
        str: A CSS hex color string with an alpha suffix, e.g. '#1f77b488'.
    """
    # Ensure the value is within the range [0, 1]
    value = np.clip(value, 0.0, 1.0)

    # Get the colormap
    colormap = plt.get_cmap(colormap_name)

    # Map the value to a color
    rgba_color = colormap(value)

    # Convert the RGBA color to CSS hex format
    css_color = to_hex(rgba_color)

    # Append "88" (~53% alpha) so highlights stay readable over text
    return css_color + "88"
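
# Example call (assumed; the exact hex depends on Matplotlib's tab20c palette):
#   map_value_to_color(0.0)  # -> roughly '#3182bd88'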

@st.cache_resource
def get_model_and_tokenizer(name):
    # Load pre-trained model and tokenizer
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    return model, tok
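
# st.cache_resource keeps one model/tokenizer instance per `name` across
# Streamlit reruns, so interacting with widgets does not reload the weights.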

@st.cache_resource
def get_classifiers_for_model(att_size, emb_size, device, config_paths):
    config = {
        "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
        "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
    }

    layer_id = config["classifier_token"]["layer"]

    classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
    classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device))

    classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
    classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device))

    print(sum(p.numel() for p in classifier_span.parameters()), sum(p.numel() for p in classifier_token.parameters()))

    return classifier_span, classifier_token, config["classifier_token"]["label_map"]
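
# Each classifier directory is expected to hold a config.json (with "layer",
# "classifier_dim" and "label_map" entries) next to a checkpoint.pt state
# dict, matching the reads above.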

def get_available_models():
    available_models = []
    for model_name in ["gpt2", "gpt2-xl"]:
        if os.path.isfile(f"checkpoints/{model_name}/config.json"):
            available_models.append(model_name)
    return available_models

def get_available_datasets(model_name):
    available_datasets = []
    config_path = f"checkpoints/{model_name}/config.json"
    if os.path.isfile(config_path):
        with open(config_path, "r") as f:
            config = json.load(f)
        # Assuming datasets are keys in config.json
        available_datasets = list(config.keys())
    return available_datasets

def download_and_extract_zip(url, extract_dir):
    # Determine the parent directory; os.path.split returns a (head, tail)
    # pair, so take the head
    parent_dir = os.path.split(os.path.dirname(extract_dir))[0]
    print(parent_dir)

    # Download the zip file to the parent directory
    zip_file_path = os.path.join(parent_dir, "data.zip")
    urllib.request.urlretrieve(url, zip_file_path)

    # Extract the zip file
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(parent_dir)

    # Remove the zip file
    os.remove(zip_file_path)
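
# With the extract_dir "data/" used below, parent_dir resolves to "" (the
# current working directory), so data.zip is fetched and unpacked there.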

def find_datasets_and_model_ids(root_dir):
    datasets = {}

    # Check if the root directory exists
    if not os.path.exists(root_dir):
        # If the root directory doesn't exist, download a zip file and unpack it
        print("Root directory doesn't exist. Downloading zip file...")
        url = "https://drive.usercontent.google.com/download?id=1dHjH_J0zuPS-SDVrh49tMpIx5ramu_hc&export=download&authuser=0&confirm=t&uuid=4efcec77-571c-44c7-82f1-f39ddae50eb5&at=APZUnTW8g-Ab4PUT0-B9mh4jQSc-%3A1711040271924"  # Replace with your actual download URL
        download_and_extract_zip(url, root_dir)
        print("Zip file downloaded and unpacked successfully.")

    for root, dirs, files in os.walk(root_dir):
        if 'config.json' in files and 'stoke_config.json' in files:
            config_path = os.path.join(root, 'config.json')
            stoke_config_path = os.path.join(root, 'stoke_config.json')

            with open(config_path, 'r') as f:
                config_data = json.load(f)
            model_id = config_data.get('model_id')

            if model_id:
                dataset_name = os.path.basename(root)
                with open(stoke_config_path, 'r') as f:
                    stoke_config_data = json.load(f)
                datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data

    return datasets
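
# Expected layout under data/ (inferred from the walk above):
#   data/<dataset_name>/config.json        -- {"model_id": "<hub model name>", ...}
#   data/<dataset_name>/stoke_config.json  -- maps config names to the classifier
#                                             checkpoint paths used further down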

# Main content
st.title("Playground")

# Sidebar for model and dataset selection
with st.sidebar:
    st.subheader("Model and Dataset Selection")
    datasets = find_datasets_and_model_ids("data/")
    available_models = datasets.keys()
    print(datasets)
    if available_models:
        model_selection = st.selectbox("Select Model", available_models)
    else:
        st.error("No models available. Please check the file paths.")

    # Select dataset based on selected model
    available_datasets = datasets[model_selection]
    if available_datasets:
        dataset_selection = st.selectbox("Select Dataset", available_datasets)
    else:
        st.error("No datasets available for the selected model.")

    # Select config based on selected dataset
    available_configs = datasets[model_selection][dataset_selection]
    if available_configs:
        config_selection = st.selectbox("Select Config", available_configs.keys())
    else:
        st.error("No configs available for the selected dataset.")

# Load model and streamer based on selections
model, tok = get_model_and_tokenizer(model_selection)
if torch.cuda.is_available():
    model.cuda()
classifier_span, classifier_token, label_map = get_classifiers_for_model(
    model.config.n_head * model.config.n_layer, model.config.n_embd,
    model.device, datasets[model_selection][dataset_selection][config_selection])
streamer = STOKEStreamer(tok, classifier_token, classifier_span)

new_tags = label_map
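
# The span probe takes one attention feature per head across all layers
# (n_head * n_layer inputs); the token probe takes an n_embd-sized hidden
# state from the layer named in its config.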

def filter_spans(spans_and_values):
    if spans_and_values == []:
        return [], []
    # Create a dictionary to store spans based on their second index values
    span_dict = {}

    spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]

    # Iterate through the spans and update the dictionary with the highest value
    for span, value in zip(spans, values):
        start, end = span
        if start > end or end - start > 15 or start == 0:
            continue
        current_value = span_dict.get(end, None)

        if current_value is None or current_value[1] < value:
            span_dict[end] = (span, value)

    if span_dict == {}:
        return [], []
    # Extract the filtered spans and values
    filtered_spans, filtered_values = zip(*span_dict.values())

    return list(filtered_spans), list(filtered_values)
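
# Illustrative behaviour (assumed inputs): spans sharing an end index keep
# only the highest-scoring candidate; degenerate or >15-token spans are dropped:
#   filter_spans([((2, 5), 0.9), ((3, 5), 0.4)])  # -> ([(2, 5)], [0.9])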

def remove_overlapping_spans(spans):
    # Sort the spans based on their end points
    sorted_spans = sorted(spans, key=lambda x: x[0][1])

    non_overlapping_spans = []
    last_end = float('-inf')

    # Iterate through the sorted spans
    for span in sorted_spans:
        start, end = span[0]
        value = span[1]

        # If the current span does not overlap with the previous one
        if start >= last_end:
            non_overlapping_spans.append(span)
            last_end = end
        else:
            # If it overlaps, choose the one with the highest value
            existing_span_index = -1
            for i, existing_span in enumerate(non_overlapping_spans):
                if existing_span[0][1] <= start:
                    existing_span_index = i
                    break
            if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
                non_overlapping_spans[existing_span_index] = span

    return non_overlapping_spans
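
# Illustrative behaviour (assumed inputs): of two overlapping spans, the
# earlier-ending one survives unless the replacement rule above fires:
#   remove_overlapping_spans([((0, 4), 0.9), ((2, 6), 0.3)])  # -> [((0, 4), 0.9)]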

def generate_html_no_overlap(tokenized_text, spans):
    current_index = 0
    html_content = ""

    for (span_start, span_end), value in spans:
        # Add text before the span
        html_content += "".join(tokenized_text[current_index:span_start])

        # Add the span with underlining
        html_content += "<b><u>"
        html_content += "".join(tokenized_text[span_start:span_end])
        html_content += "</u></b> "

        current_index = span_end

    # Add any remaining text after the last span
    html_content += "".join(tokenized_text[current_index:])

    return html_content
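
# For example (assumed inputs):
#   generate_html_no_overlap(["The", " cat", " sat"], [((1, 3), 0.8)])
#   # -> 'The<b><u> cat sat</u></b> '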

css = """
<style>
.highlight {
    display: inline;
}
.highlight::after {
    /* expects a --data-color custom property; the app currently sets
       background colors via inline styles instead */
    background-color: var(--data-color);
}
.spanhighlight {
    padding: 2px 5px;
    border-radius: 5px;
}
.tooltip {
    position: relative;
    display: inline-block;
}

.tooltip::after {
    content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
    display: none;
    position: absolute;
    background-color: #333;
    color: #fff;
    padding: 5px;
    border-radius: 5px;
    bottom: 100%; /* Position it above the element */
    left: 50%;
    transform: translateX(-50%);
    width: auto;
    min-width: 120px;
    margin: 0 auto;
    text-align: center;
}

.tooltip:hover::after {
    display: block; /* Show the tooltip on hover */
}

.small-text {
    padding: 2px 5px;
    background-color: white;
    border-radius: 5px;
    font-size: xx-small;
    margin-left: 0.5em;
    vertical-align: 0.2em;
    font-weight: bold;
    color: grey;
}
</style>"""
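
# The tooltip text is carried in the data-tooltip-text attribute and rendered
# by the .tooltip::after rule above, so no JavaScript is needed.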

def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer):
    # spanwise annotated text
    annotated = []
    span_ends = -1
    in_span = False

    out_of_span_tokens = []
    for i in reversed(range(len(tokenwise_preds))):

        if in_span:
            if i >= span_ends:
                continue
            else:
                in_span = False

        predicted_class = ""
        style = ""

        span = None
        for s in spans:
            if s[1] == i + 1:
                span = s

        if tokenwise_preds[i] != 0 and span is not None:
            predicted_class = "highlight spanhighlight"
            style = f"background-color: {map_value_to_color((tokenwise_preds[i] - 1) / (len(new_tags) - 1))}"
            if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
                annotated.append("Ġ")

            span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
            span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
            annotated.extend(out_of_span_tokens)
            out_of_span_tokens = []
            span_ends = span[0]
            in_span = True
            annotated.append(span_end)
            annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
            annotated.append(span_opener)
        else:
            out_of_span_tokens.append(token_strings[i])

    annotated.extend(out_of_span_tokens)

    return [x for x in reversed(annotated)]
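
# "Ġ" is the leading-space marker of GPT-2's byte-level BPE vocabulary;
# appending it (and replacing spaces inside the injected HTML with it) lets
# tok.convert_tokens_to_string() restore correct spacing afterwards.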

# Define function to generate text based on input
def generate_text(generation_kwargs, output_field):

    # Function to generate text in a separate thread
    def generate_async():
        model.generate(**generation_kwargs)

    # Start text generation in a separate thread
    thread = Thread(target=generate_async)
    thread.start()

    # Display generated text as it becomes available
    text_tokenwise = ""
    text_spans = ""
    removed_spans = ""
    tags = []
    spans = []
    for new_text in streamer:
        if new_text[1] is not None and new_text[2] != ['']:
            text_tokenwise = ""
            tags.extend(new_text[1])
            spans.extend(new_text[-1])

            # Tokenwise Classification
            for tk, pred in zip(new_text[2], tags):
                if pred != 0:
                    style = f"background-color: {map_value_to_color((pred - 1) / (len(new_tags) - 1))}"
                    if tk.startswith(" "):
                        text_tokenwise += " "
                    text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
                else:
                    text_tokenwise += tk

            # Span Classification
            text_spans = ""
            if len(spans) > 0:
                filtered_spans = remove_overlapping_spans(spans)
                text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
                if len(spans) - len(filtered_spans) > 0:
                    removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
            else:
                for tk in new_text[2]:
                    text_spans += f"{tk}"

            # Spanwise Classification
            annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok)
            generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "")

            output_field.empty()
            output = f"{css}"
            output += generated_text_spanwise.replace("\n", " ").replace("$", "\\$") + "\n<br>"
            output += "<details><summary>Show tokenwise classification</summary>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$")
            #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
            if removed_spans != "":
                output += f"<br><br><i>({removed_spans})</i>"
            output += "</details>"
            output_field.write(output, unsafe_allow_html=True)
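
# This follows the usual transformers streaming pattern: generate() runs in a
# background thread while the streamer is consumed as an iterator on the main
# thread, letting Streamlit repaint the output field as tokens arrive.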

# Input field
input_text = st.text_area("Enter prompt for completion", "")

# Sidebar for customizing generation parameters
with st.sidebar:
    st.subheader("Generation Parameters")
    max_new_tokens = st.slider("Max New Tokens", min_value=1, max_value=100, value=30)
    repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2)
    do_sample = st.checkbox("Do Sample", value=True)
    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0)
    top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.3)
    top_k = st.slider("Top-k", min_value=10, max_value=100, value=50)
    typical_p = st.slider("Typical P", min_value=0.1, max_value=1.0, value=1.0)
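
# Each control maps one-to-one onto a model.generate() keyword argument
# (max_new_tokens, repetition_penalty, temperature, top_p, top_k, do_sample,
# typical_p); they are assembled into generation_kwargs below.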

# Button to generate text
if st.button("Generate"):
    if input_text:
        output_field = st.empty()
        inputs = tok([" " + input_text], return_tensors="pt").to(model.device)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens,
                                 repetition_penalty=repetition_penalty, temperature=temperature,
                                 top_p=top_p, top_k=top_k, do_sample=do_sample, typical_p=typical_p)
        generate_text(generation_kwargs, output_field)
    else:
        st.warning("Please enter some text first.")