Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -272,7 +272,7 @@ class BeamNode:
|
|
272 |
is_selected_sequence: bool
|
273 |
|
274 |
|
275 |
-
def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences
|
276 |
original_tree = BeamNode(
|
277 |
cumulative_score=0,
|
278 |
current_token_ix=None,
|
@@ -415,8 +415,6 @@ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequ
|
|
415 |
current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
|
416 |
beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
|
417 |
|
418 |
-
print(f"Step {step}, beams kept: {beams_to_keep}")
|
419 |
-
|
420 |
return original_tree
|
421 |
|
422 |
@spaces.GPU
|
@@ -445,14 +443,23 @@ def get_beam_search_html(
|
|
445 |
for i, sequence in enumerate(decoded_sequences):
|
446 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
447 |
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
html = generate_html(input_text, original_tree)
|
457 |
return html, markdown
|
458 |
|
|
|
272 |
is_selected_sequence: bool
|
273 |
|
274 |
|
275 |
+
def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences):
|
276 |
original_tree = BeamNode(
|
277 |
cumulative_score=0,
|
278 |
current_token_ix=None,
|
|
|
415 |
current_token_choice_ix = top_df_selected_filtered.iloc[beam_ix]["token_index"]
|
416 |
beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
|
417 |
|
|
|
|
|
418 |
return original_tree
|
419 |
|
420 |
@spaces.GPU
|
|
|
443 |
for i, sequence in enumerate(decoded_sequences):
|
444 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
445 |
|
446 |
+
if number_beams > 1:
|
447 |
+
original_tree = generate_beams(
|
448 |
+
number_beams,
|
449 |
+
input_text,
|
450 |
+
outputs.scores[:],
|
451 |
+
length_penalty,
|
452 |
+
decoded_sequences,
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
original_tree = generate_beams(
|
456 |
+
n_beams,
|
457 |
+
start_sentence,
|
458 |
+
outputs.logits,
|
459 |
+
0,
|
460 |
+
decoded_sequences,
|
461 |
+
)
|
462 |
+
|
463 |
html = generate_html(input_text, original_tree)
|
464 |
return html, markdown
|
465 |
|