Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -30,14 +30,12 @@ STYLE = """
 .prose table {
     margin-bottom: 0!important;
 }
-
 .prose td, th {
     padding-left: 2px;
     padding-right: 2px;
     padding-top: 0;
     padding-bottom: 0;
 }
-
 .tree {
     padding: 0px;
     margin: 0!important;
@@ -48,13 +46,11 @@ STYLE = """
     text-align: center;
     display:inline-block;
 }
-
 #root {
     display: inline-grid!important;
     width:auto!important;
     min-width: 220px;
 }
-
 .tree ul {
     padding-left: 20px;
     position: relative;
@@ -75,7 +71,6 @@ STYLE = """
     justify-content: start;
     align-items: center;
 }
-
 .tree li::before, .tree li::after {
     content: "";
     position: absolute;
@@ -96,7 +91,6 @@ STYLE = """
 .tree li:only-child::after, li:only-child::before {
     display: none;
 }
-
 .tree li:first-child::before, .tree li:last-child::after {
     border: 0 none;
 }
@@ -111,7 +105,6 @@ STYLE = """
     -webkit-border-radius: 5px 0 0 0;
     -moz-border-radius: 5px 0 0 0;
 }
-
 .tree ul ul::before {
     content: "";
     position: absolute;
@@ -124,7 +117,6 @@ STYLE = """
 .tree ul:has(> li:only-child)::before {
     width:40px;
 }
-
 a:before {
     border-right: 1px solid var(--body-text-color);
     border-bottom: 1px solid var(--body-text-color);
@@ -138,8 +130,6 @@ a:before {
     margin-left: 6px;
     transform: rotate(315deg);
 }
-
-
 .tree li a {
     border: 1px solid var(--body-text-color);
     padding: 5px;
@@ -155,7 +145,6 @@ a:before {
 .tree li a span {
     padding: 5px;
     font-size: 12px;
-    text-transform: uppercase;
     letter-spacing: 1px;
     font-weight: 500;
 }
@@ -166,7 +155,7 @@ a:before {
 .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
     border-color: #7c2d12;
 }
-.chosen {
+.end-of-text, .chosen {
    background-color: #ea580c;
    width:auto!important;
 }
@@ -206,33 +195,37 @@ def generate_markdown_table(
 def generate_nodes(token_ix, node, step):
     """Recursively generate HTML for the tree nodes."""
     token = tokenizer.decode([token_ix])
-
+
+    if node.is_final:
+        return f"<li> <a href='#' class='end-of-text'> <span> <b>{token_ix}:<br>{clean(token)}</b> <br> Total score: {node.total_score:.2f} </span> </a> </li>"
+
+    html_content = (
+        f"<li> <a href='#'> <span> <b>{token_ix}:<br>{clean(token)}</b> </span>"
+    )
     if node.table is not None:
         html_content += node.table
     html_content += "</a>"
+
     if len(node.children.keys()) > 0:
         html_content += "<ul> "
         for token_ix, subnode in node.children.items():
             html_content += generate_nodes(token_ix, subnode, step=step + 1)
         html_content += "</ul>"
     html_content += "</li>"
+
     return html_content
 
 
 def generate_html(start_sentence, original_tree):
-
     html_output = f"""<div class="custom-container">
     <div class="tree">
-    <ul>
-
-
-    html_output +=
-
-        html_output += generate_nodes(token_ix, subnode, step=1)
-    html_output += "</ul>"
-
+    <ul> <li> <a href='#' id='root'> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
+    html_output += "<ul> "
+    for token_ix, subnode in original_tree.children.items():
+        html_output += generate_nodes(token_ix, subnode, step=1)
+    html_output += "</ul>"
     html_output += """
-    </ul>
+    </li> </ul>
     </div>
 </body>
 """
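For reference, here is a standalone sketch (not part of the app) of the nested <ul>/<li> markup that the updated generate_nodes produces: final beams short-circuit to a single <li> carrying the end-of-text class, other nodes open an <li> and recurse into their children. Plain dicts stand in for BeamNode and the tokenizer.

def render(token, node):
    # Final nodes close immediately with the 'end-of-text' class, as in the diff.
    if node.get("is_final"):
        return f"<li> <a href='#' class='end-of-text'> <span>{token}</span> </a> </li>"
    html = f"<li> <a href='#'> <span>{token}</span> </a>"
    if node.get("children"):
        html += "<ul> "
        for child_token, child in node["children"].items():
            html += render(child_token, child)
        html += "</ul>"
    return html + "</li>"

# Toy two-level tree with made-up tokens.
toy_tree = {"children": {"The": {"children": {"cat": {"is_final": True}, "dog": {"is_final": True}}}}}
print("<ul>" + "".join(render(t, n) for t, n in toy_tree["children"].items()) + "</ul>")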
@@ -246,11 +239,14 @@ from dataclasses import dataclass
 
 @dataclass
 class BeamNode:
+    current_token_ix: int
     cumulative_score: float
     children_score_divider: float
     table: str
     current_sentence: str
     children: Dict[int, "BeamNode"]
+    total_score: float
+    is_final: bool
 
 
 def generate_beams(start_sentence, scores, sequences, length_penalty):
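To make the new fields concrete, here is the BeamNode dataclass from this commit instantiated by hand with made-up values (a sketch, not app code): current_token_ix records which token produced the node, total_score holds the length-normalized cumulative log-probability, and is_final flags leaves (last step or EOS).

from dataclasses import dataclass
from typing import Dict


@dataclass
class BeamNode:
    current_token_ix: int
    cumulative_score: float
    children_score_divider: float
    table: str
    current_sentence: str
    children: Dict[int, "BeamNode"]
    total_score: float
    is_final: bool


# Root node: holds only the prompt, so token index and total score are None,
# mirroring how generate_beams builds original_tree.
root = BeamNode(
    current_token_ix=None,
    cumulative_score=0,
    children_score_divider=1.0,
    table=None,
    current_sentence="The quick",
    children={},
    total_score=None,
    is_final=False,
)
# Hypothetical child after one decoding step, with a made-up log-probability of -1.3.
root.children[1170] = BeamNode(
    current_token_ix=1170,
    cumulative_score=-1.3,
    children_score_divider=1.0,
    table=None,
    current_sentence="The quick brown",
    children={},
    total_score=-1.3,
    is_final=False,
)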
@@ -258,13 +254,19 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
     input_length = len(tokenizer([start_sentence], return_tensors="pt"))
     original_tree = BeamNode(
         cumulative_score=0,
+        current_token_ix=None,
         table=None,
         current_sentence=start_sentence,
         children={},
         children_score_divider=((input_length + 1) ** length_penalty),
+        total_score=None,
+        is_final=False,
     )
     n_beams = len(scores[0])
     beam_trees = [original_tree] * n_beams
+
+    candidate_nodes = []
+
     for step, step_scores in enumerate(scores):
         (
             top_token_indexes,
@@ -273,8 +275,13 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
             current_completions,
             top_tokens,
         ) = ([], [], [], [], [])
-        for beam_ix in range(n_beams):
+        for beam_ix in range(n_beams):  # Get possible descendants for each beam
             current_beam = beam_trees[beam_ix]
+
+            # skip if the beam is already final
+            if current_beam.is_final:
+                continue
+
             # Get top cumulative scores for the current beam
             current_top_token_indexes = list(
                 np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
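The argsort indexing above keeps only the n_beams best token ids for each live beam. A standalone numpy illustration of the same pattern, with made-up log-probabilities:

import numpy as np

n_beams = 3
# Made-up per-token log-probabilities for one beam over a 5-token toy vocabulary.
one_beam_scores = np.array([-4.2, -0.7, -2.9, -1.3, -6.0])
# Same indexing as in the loop: take the top n_beams indexes, best first.
top_token_indexes = list(np.array(one_beam_scores.argsort()[-n_beams:])[::-1])
print(top_token_indexes)  # [1, 3, 2]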
@@ -337,14 +344,31 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
                 + scores[step][source_beam_ix][current_token_choice_ix].numpy()
             )
             beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
+                current_token_ix=current_token_choice_ix,
                 table=None,
                 children={},
                 current_sentence=beam_trees[source_beam_ix].current_sentence
                 + current_token_choice,
                 cumulative_score=cumulative_score,
+                total_score=cumulative_score
+                / ((input_length + step - 1) ** length_penalty),
                 children_score_divider=((input_length + step + 1) ** length_penalty),
+                is_final=(
+                    step == len(scores) - 1
+                    or current_token_choice_ix == tokenizer.eos_token_id
+                ),
             )
 
+            # Check this child should be selected as a top beam.
+            # Is it a final step or an EOS token?
+            if (
+                step == len(scores) - 1
+                or current_token_choice_ix == tokenizer.eos_token_id
+            ):
+                candidate_nodes.append(
+                    beam_trees[source_beam_ix].children[current_token_choice_ix]
+                )
+
         # Reassign all beams at once
         beam_trees = [
             beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
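The new total_score field applies the usual length normalization: the cumulative log-probability is divided by (input_length + step - 1) ** length_penalty, so a positive penalty costs longer hypotheses less per extra token. A quick worked example with made-up numbers:

input_length, length_penalty = 4, 1.0

def total_score(cumulative_score, step):
    # Same formula as in the diff; all numbers here are invented.
    return cumulative_score / ((input_length + step - 1) ** length_penalty)

print(total_score(-3.0, step=2))  # -3.0 / 5 = -0.60  (shorter hypothesis)
print(total_score(-4.5, step=5))  # -4.5 / 8 = -0.56  (longer hypothesis now ranks higher)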
@@ -355,6 +379,7 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
         for beam_ix in range(n_beams):
             current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
             beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
+    print("Final nodes", candidate_nodes)
 
     return original_tree
 
@@ -373,9 +398,10 @@ def get_beam_search_html(input_text, number_steps, number_beams, length_penalty)
         do_sample=False,
     )
     markdown = "Output sequences:"
+    # Sequences are padded anyway so you can batch decode them
     decoded_sequences = tokenizer.batch_decode(outputs.sequences)
     for i, sequence in enumerate(decoded_sequences):
-        markdown += f"\n- {clean(sequence.replace('<s> ', ''))} (score {outputs.sequences_scores[i]:.2f})"
+        markdown += f"\n- '{clean(sequence.replace('<s> ', ''))}' (score {outputs.sequences_scores[i]:.2f})"
 
     original_tree = generate_beams(
         input_text,
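Only the tail of the model.generate call appears in this hunk. As a sketch (the model name and argument values are assumptions, not taken from the Space), a transformers beam search call that populates the fields used above (outputs.sequences, outputs.scores and outputs.sequences_scores) looks like this:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["The quick brown"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=3,
    num_return_sequences=3,
    max_new_tokens=4,
    length_penalty=1.0,
    do_sample=False,
    return_dict_in_generate=True,  # structured output instead of a bare tensor
    output_scores=True,            # one score tensor per generated step
)
print(tokenizer.batch_decode(outputs.sequences))
print(outputs.sequences_scores)    # length-penalized log-probability of each beam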
@@ -393,7 +419,8 @@ with gr.Blocks(
     ),
     css=STYLE,
 ) as demo:
-    gr.Markdown(
+    gr.Markdown(
+        """# Beam search visualizer
 
 Play with the parameters below to understand how beam search decoding works!
 
@@ -402,15 +429,29 @@ Play with the parameters below to understand how beam search decoding works!
 - **Number of steps**: the number of tokens to generate
 - **Number of beams**: the number of beams to use
 - **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
-    """
-
+    """
+    )
+    text = gr.Textbox(
+        label="Sentence to decode from",
+        value="Conclusion: thanks a lot. This article was originally published on",
+    )
     with gr.Row():
-        steps = gr.Slider(
-
-
+        steps = gr.Slider(
+            label="Number of steps", minimum=1, maximum=8, step=1, value=4
+        )
+        beams = gr.Slider(
+            label="Number of beams", minimum=2, maximum=4, step=1, value=3
+        )
+        length_penalty = gr.Slider(
+            label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1
+        )
     button = gr.Button()
     out_html = gr.Markdown()
     out_markdown = gr.Markdown()
-    button.click(
+    button.click(
+        get_beam_search_html,
+        inputs=[text, steps, beams, length_penalty],
+        outputs=[out_html, out_markdown],
+    )
 
 demo.launch()
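The UI wiring added here follows the standard Gradio pattern: input components feed button.click, which writes into two Markdown outputs. A self-contained sketch with a dummy callback (the real app passes get_beam_search_html instead):

import gradio as gr

def fake_beam_search(sentence, steps, beams, length_penalty):
    # Placeholder callback with the same signature shape as get_beam_search_html.
    html = f"<b>{sentence}</b> ({int(steps)} steps, {int(beams)} beams)"
    markdown = f"Output sequences:\n- example (length_penalty={length_penalty})"
    return html, markdown

with gr.Blocks() as demo:
    text = gr.Textbox(label="Sentence to decode from", value="The quick brown")
    with gr.Row():
        steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
        beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
        length_penalty = gr.Slider(label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1)
    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(fake_beam_search, inputs=[text, steps, beams, length_penalty], outputs=[out_html, out_markdown])

demo.launch()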