Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -437,45 +437,75 @@ def mask(sentence):
|
|
437 |
|
438 |
|
439 |
|
440 |
-
|
441 |
#plotly tree
|
|
|
|
|
|
|
|
|
|
|
442 |
def generate_plot(original_sentence):
|
443 |
paraphrased_sentences = generate_paraphrase(original_sentence)
|
444 |
first_paraphrased_sentence = paraphrased_sentences[0]
|
445 |
masked_sentence = mask_non_stopword(first_paraphrased_sentence)
|
446 |
masked_versions = mask(masked_sentence)
|
|
|
447 |
nodes = []
|
448 |
nodes.append(original_sentence)
|
449 |
nodes.extend(paraphrased_sentences)
|
450 |
nodes.extend(masked_versions)
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
464 |
fig = go.Figure()
|
465 |
-
positions = {
|
466 |
-
0: (0, 0),
|
467 |
-
1: (-4, -4),
|
468 |
-
2: (-2, -4),
|
469 |
-
3: (0, -4),
|
470 |
-
4: (2, -4),
|
471 |
-
5: (4, -4),
|
472 |
-
6: (-4.5, -8),
|
473 |
-
7: (-3, -8),
|
474 |
-
8: (-1.5, -8),
|
475 |
-
9: (0, -8),
|
476 |
-
10: (2, -8) # Example addition for index 10
|
477 |
-
}
|
478 |
|
|
|
479 |
for i, node in enumerate(wrapped_nodes):
|
480 |
x, y = positions[i]
|
481 |
fig.add_trace(go.Scatter(
|
@@ -500,6 +530,7 @@ def generate_plot(original_sentence):
|
|
500 |
width=200
|
501 |
)
|
502 |
|
|
|
503 |
for edge in edges:
|
504 |
x0, y0 = positions[edge[0]]
|
505 |
x1, y1 = positions[edge[1]]
|
|
|
437 |
|
438 |
|
439 |
|
|
|
440 |
#plotly tree
|
441 |
+
import plotly.graph_objs as go
|
442 |
+
import textwrap
|
443 |
+
import re
|
444 |
+
from collections import defaultdict
|
445 |
+
|
446 |
def generate_plot(original_sentence):
|
447 |
paraphrased_sentences = generate_paraphrase(original_sentence)
|
448 |
first_paraphrased_sentence = paraphrased_sentences[0]
|
449 |
masked_sentence = mask_non_stopword(first_paraphrased_sentence)
|
450 |
masked_versions = mask(masked_sentence)
|
451 |
+
|
452 |
nodes = []
|
453 |
nodes.append(original_sentence)
|
454 |
nodes.extend(paraphrased_sentences)
|
455 |
nodes.extend(masked_versions)
|
456 |
+
nodes[0] += ' L0'
|
457 |
+
para_len = len(paraphrased_sentences)
|
458 |
+
for i in range(1, para_len+1):
|
459 |
+
nodes[i] += ' L1'
|
460 |
+
for i in range(para_len+1, len(nodes)):
|
461 |
+
nodes[i] += ' L2'
|
462 |
+
|
463 |
+
cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
|
464 |
+
wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in cleaned_nodes]
|
465 |
+
|
466 |
+
def get_levels_and_edges(nodes):
|
467 |
+
levels = {}
|
468 |
+
edges = []
|
469 |
+
for i, node in enumerate(nodes):
|
470 |
+
level = int(node.split()[-1][1])
|
471 |
+
levels[i] = level
|
472 |
+
|
473 |
+
# Add edges from L0 to all L1 nodes
|
474 |
+
root_node = next(i for i, level in levels.items() if level == 0)
|
475 |
+
for i, level in levels.items():
|
476 |
+
if level == 1:
|
477 |
+
edges.append((root_node, i))
|
478 |
+
|
479 |
+
# Identify the first L1 node
|
480 |
+
first_l1_node = next(i for i, level in levels.items() if level == 1)
|
481 |
+
# Add edges from the first L1 node to all L2 nodes
|
482 |
+
for i, level in levels.items():
|
483 |
+
if level == 2:
|
484 |
+
edges.append((first_l1_node, i))
|
485 |
+
|
486 |
+
return levels, edges
|
487 |
+
|
488 |
+
# Get levels and dynamic edges
|
489 |
+
levels, edges = get_levels_and_edges(nodes)
|
490 |
+
max_level = max(levels.values())
|
491 |
+
|
492 |
+
# Calculate positions
|
493 |
+
positions = {}
|
494 |
+
level_widths = defaultdict(int)
|
495 |
+
for node, level in levels.items():
|
496 |
+
level_widths[level] += 1
|
497 |
+
|
498 |
+
x_offsets = {level: - (width - 1) / 2 for level, width in level_widths.items()}
|
499 |
+
y_gap = 4
|
500 |
+
|
501 |
+
for node, level in levels.items():
|
502 |
+
positions[node] = (x_offsets[level], -level * y_gap)
|
503 |
+
x_offsets[level] += 1
|
504 |
+
|
505 |
+
# Create figure
|
506 |
fig = go.Figure()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
|
508 |
+
# Add nodes to the figure
|
509 |
for i, node in enumerate(wrapped_nodes):
|
510 |
x, y = positions[i]
|
511 |
fig.add_trace(go.Scatter(
|
|
|
530 |
width=200
|
531 |
)
|
532 |
|
533 |
+
# Add edges to the figure
|
534 |
for edge in edges:
|
535 |
x0, y0 = positions[edge[0]]
|
536 |
x1, y1 = positions[edge[1]]
|