# tree.py (commit d3347e0 "Update tree.py", author: jgyasu, verified; 3.51 kB)
import plotly.graph_objs as go
import textwrap
import re
from collections import defaultdict
from paraphraser import generate_paraphrase
from masking_methods import mask, mask_non_stopword
def _tree_levels_and_edges(num_paraphrases, num_masked):
    """Return ``(levels, edges)`` for the three-tier paraphrase tree.

    Node indices follow the order in which ``generate_plot`` lists its
    sentences:

    * index 0 is the original sentence (level 0);
    * indices ``1..num_paraphrases`` are the paraphrases (level 1);
    * the next ``num_masked`` indices are the masked versions (level 2).

    ``levels`` maps node index -> level, in ascending index order.
    ``edges`` is a list of ``(parent, child)`` index pairs. The root links to
    every paraphrase, and the first paraphrase links to every masked version.
    """
    first_masked = num_paraphrases + 1
    end = first_masked + num_masked

    levels = {0: 0}
    levels.update((i, 1) for i in range(1, first_masked))
    levels.update((i, 2) for i in range(first_masked, end))

    edges = [(0, i) for i in range(1, first_masked)]
    if num_paraphrases:
        # The masked versions are all derived from the first paraphrase,
        # so they hang off node 1 only.
        edges.extend((1, i) for i in range(first_masked, end))
    return levels, edges


def _tree_positions(levels, y_gap=4):
    """Map each node index in ``levels`` to an ``(x, y)`` plot coordinate.

    Nodes sharing a level are spaced one unit apart, centred on x = 0 and
    placed in index order. Each level sits ``y_gap`` units below the
    previous one.
    """
    level_widths = defaultdict(int)
    for level in levels.values():
        level_widths[level] += 1

    # Leftmost x for each level, so that the level is centred on 0.
    next_x = {level: -(width - 1) / 2 for level, width in level_widths.items()}

    positions = {}
    for node, level in levels.items():
        positions[node] = (next_x[level], -level * y_gap)
        next_x[level] += 1
    return positions


def generate_plot(original_sentence, selected_sentences):
    """Build a Plotly tree of a sentence, its paraphrases and masked variants.

    The tree has three levels:

    * level 0 -- ``original_sentence``;
    * level 1 -- every sentence in ``selected_sentences``;
    * level 2 -- the masked versions obtained by running ``mask_non_stopword``
      and then ``mask`` on ``selected_sentences[0]``.

    The root is connected to every paraphrase. The first paraphrase is
    connected to every masked version.

    Args:
        original_sentence: Text shown at the root node.
        selected_sentences: Non-empty sequence of paraphrase strings.

    Returns:
        A ``plotly.graph_objs.Figure`` containing the tree.

    Raises:
        IndexError: if ``selected_sentences`` is empty.
    """
    first_paraphrased_sentence = selected_sentences[0]
    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
    # NOTE(review): assumes ``mask`` returns an iterable of strings -- confirm.
    masked_versions = list(mask(masked_sentence))

    # Node texts in index order: root, paraphrases, masked versions.
    texts = [original_sentence, *selected_sentences, *masked_versions]
    levels, edges = _tree_levels_and_edges(
        len(selected_sentences), len(masked_versions)
    )
    positions = _tree_positions(levels)

    fig = go.Figure()

    # Edges go first so that the node markers are drawn on top of the lines.
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='none',
        ))

    for i, text in enumerate(texts):
        x, y = positions[i]
        # Plotly annotation text uses its HTML subset, so '<br>' breaks lines.
        # NOTE(review): tags inside the sentences themselves would also be
        # interpreted rather than shown literally; escape them if that matters.
        label = '<br>'.join(textwrap.wrap(text, width=30))
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none',
        ))
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            yshift=20,  # lift the label box above its marker to avoid overlap
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800,  # tall enough to give the three levels room
    )
    return fig