|
import plotly.graph_objs as go |
|
import textwrap |
|
import re |
|
from collections import defaultdict |
|
from paraphraser import generate_paraphrase |
|
from masking_methods import mask, mask_non_stopword |
|
|
|
def _tree_edges(para_len, masked_len):
    """Return the (parent, child) index pairs of the three-level node tree.

    Index 0 is the root. Indices ``1..para_len`` are paraphrases, each a child
    of the root. The next ``masked_len`` indices are masked variants, each a
    child of node 1 (the first paraphrase, which is the one that was masked).
    Edges to paraphrases come first, then edges to masked variants, both in
    index order.
    """
    first_masked = para_len + 1
    root_edges = [(0, child) for child in range(1, first_masked)]
    masked_edges = [(1, child) for child in range(first_masked, first_masked + masked_len)]
    return root_edges + masked_edges


def _layout_positions(levels, y_gap):
    """Return an ``(x, y)`` position for every node, indexed like ``levels``.

    All nodes on one level share ``y = -level * y_gap``. Within a level they
    are placed one unit apart in x, in index order, centred on ``x = 0``.
    """
    counts = defaultdict(int)
    for level in levels:
        counts[level] += 1

    # Leftmost x of each level, so that the level's row is centred on 0.
    next_x = {level: -(count - 1) / 2 for level, count in counts.items()}

    positions = []
    for level in levels:
        positions.append((next_x[level], -level * y_gap))
        next_x[level] += 1
    return positions


def generate_plot(original_sentence, selected_sentences):
    """Build a Plotly tree figure of a sentence, its paraphrases and masked variants.

    The tree has three levels:

    * level 0: ``original_sentence`` (the root);
    * level 1: every entry of ``selected_sentences``, each linked to the root;
    * level 2: the items returned by
      ``mask(mask_non_stopword(selected_sentences[0]))``, each linked to the
      first paraphrase, because only that paraphrase is masked.

    Args:
        original_sentence: Text shown at the root node.
        selected_sentences: Non-empty sequence of paraphrase strings. Only the
            first one is passed through the masking functions.

    Returns:
        A ``plotly.graph_objs.Figure`` containing one line trace per edge,
        followed by one marker trace per node. Each node also gets a boxed
        text annotation wrapped at 30 characters. Axes, grid and legend are
        hidden, and the figure is 1470x800 px.

    Raises:
        IndexError: If ``selected_sentences`` is empty.
    """
    first_paraphrase = selected_sentences[0]
    # NOTE(review): assumes mask() returns an iterable of strings. Each item
    # becomes one level-2 node and is passed to textwrap.wrap below.
    masked_versions = mask(mask_non_stopword(first_paraphrase))

    # Node texts in index order: root, then paraphrases, then masked variants.
    texts = [original_sentence, *selected_sentences, *masked_versions]
    para_len = len(selected_sentences)
    masked_len = len(texts) - 1 - para_len

    # Tree level of each node, indexed like ``texts``. The levels are tracked
    # directly, rather than by tagging each sentence with " L<n>" and parsing
    # the tag back out of the text.
    levels = [0] + [1] * para_len + [2] * masked_len
    edges = _tree_edges(para_len, masked_len)
    positions = _layout_positions(levels, y_gap=4)

    # NOTE(review): the text goes to Plotly as-is, and Plotly interprets
    # HTML-like tags (it is also how <br> works here). A mask token such as
    # "<mask>" may therefore not render literally. Confirm against mask().
    wrapped_texts = ['<br>'.join(textwrap.wrap(text, width=30)) for text in texts]

    fig = go.Figure()

    # Plotly draws traces in the order they are added. Adding the edges first
    # keeps the connecting lines underneath the node markers.
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='none',
        ))

    for (x, y), text in zip(positions, wrapped_texts):
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none',
        ))
        # Label box placed just above its marker (yshift is in pixels).
        fig.add_annotation(
            x=x,
            y=y,
            text=text,
            showarrow=False,
            yshift=20,
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800,
    )

    return fig