# Page banner captured along with this file (not code) - Spaces: Runtime error
import plotly.graph_objects as go | |
import textwrap | |
import re | |
from collections import defaultdict | |
def apply_lcs_numbering(sentence, common_grams):
    """Prefix every whole-word occurrence of each common gram with its index.

    Args:
        sentence: Text to annotate.
        common_grams: Iterable of ``(idx, lcs)`` pairs. ``lcs`` is the literal
            text to look for and ``idx`` is the number placed in front of it.

    Returns:
        The sentence with each match rewritten as ``(idx)lcs``. Pairs are
        applied in order, so later pairs see the output of earlier ones.
    """
    for idx, lcs in common_grams:
        # Escape the gram so characters such as "." or "+" are matched
        # literally instead of being interpreted as regex syntax.
        pattern = rf"\b{re.escape(lcs)}\b"
        numbered = f"({idx}){lcs}"
        # A callable replacement stops re.sub from treating backslashes in
        # the gram as escapes or group references.
        sentence = re.sub(pattern, lambda _match, text=numbered: text, sentence)
    return sentence
def highlight_words(sentence, color_map):
    """Wrap each word named in ``color_map`` in ``{{...}}`` markers.

    Args:
        sentence: Text to mark up.
        color_map: Mapping whose keys are the words to mark. The values are
            not used here.

    Returns:
        The sentence with every whole-word, case-insensitive occurrence of a
        key replaced by ``{{key}}``. The marker carries the key's own
        spelling, not the casing found in the sentence.
    """
    for word in color_map:
        marker = f"{{{{{word}}}}}"
        # Escape the key so punctuation in it is matched literally, and use a
        # callable replacement so backslashes in the key are inserted as-is.
        sentence = re.sub(
            rf"\b{re.escape(word)}\b",
            lambda _match, text=marker: text,
            sentence,
            flags=re.IGNORECASE,
        )
    return sentence
def clean_and_wrap_nodes(nodes, highlight_info):
    """Strip level labels, mark highlighted words, and wrap text for display.

    Each node loses a trailing single-digit `` L<n>`` label, has the words
    named in ``highlight_info`` marked via ``highlight_words``, and is then
    wrapped at 55 columns with ``<br>`` joining the wrapped lines.

    Args:
        nodes: Iterable of labelled node strings.
        highlight_info: Anything ``dict()`` accepts, mapping word to color.

    Returns:
        A list with one wrapped, marked-up string per input node.
    """
    word_colors = dict(highlight_info)
    wrapped = []
    for raw in nodes:
        unlabeled = re.sub(r'\sL[0-9]$', '', raw)
        marked = highlight_words(unlabeled, word_colors)
        wrapped.append('<br>'.join(textwrap.wrap(marked, width=55)))
    return wrapped
def get_levels_and_edges(nodes):
    """Read each node's level label and link the root to the level-1 nodes.

    Args:
        nodes: Sequence of strings whose last whitespace-separated token is a
            level label such as ``L0`` or ``L1``. The character right after
            the leading letter is parsed as the level.

    Returns:
        A ``(levels, edges)`` tuple. ``levels`` maps node index to level.
        ``edges`` lists ``(root_index, child_index)`` pairs from the first
        level-0 node to every level-1 node, and is empty when no level-0
        node exists.
    """
    levels = {i: int(node.split()[-1][1]) for i, node in enumerate(nodes)}
    # Supplying a default keeps StopIteration from escaping when there is no
    # root (for example, an empty node list).
    root_node = next((i for i, level in levels.items() if level == 0), None)
    if root_node is None:
        return levels, []
    # Create edges from the level 0 root to the level 1 nodes
    edges = [(root_node, i) for i, level in levels.items() if level == 1]
    return levels, edges
def calculate_positions(levels):
    """Calculate x, y positions for each node based on its level.

    Nodes on the same level share an x coordinate and are stacked vertically,
    centred around y = 0, in the iteration order of ``levels``.

    Args:
        levels: Mapping of node index to level number.

    Returns:
        Dict mapping node index to an ``(x, y)`` tuple, where
        ``x = -level * 2`` and consecutive nodes on a level are 10 apart in y.
    """
    x_gap = 2
    l1_y_gap = 10

    # First pass: count the nodes on each level. The offsets below depend on
    # these totals, so they must be complete before any position is assigned.
    level_heights = defaultdict(int)
    for level in levels.values():
        level_heights[level] += 1

    # Starting offset per level, chosen so each column is centred on y = 0.
    y_offsets = {level: -(height - 1) / 2 for level, height in level_heights.items()}

    # Second pass: hand out positions, stepping each level's offset by one.
    positions = {}
    for node, level in levels.items():
        positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
        y_offsets[level] += 1
    return positions
def color_highlighted_words(node, color_map):
    """Turn ``{{word}}`` markers in ``node`` into colored ``<span>`` elements.

    Args:
        node: Text that may contain ``{{word}}`` markers.
        color_map: Mapping of word to CSS color; words that are missing from
            it are colored ``black``.

    Returns:
        ``node`` with each marker replaced by
        ``<span style='color: ...;'>word</span>`` and all other text unchanged.
    """
    pieces = []
    # Splitting on a capturing group keeps the markers in the result list.
    for chunk in re.split(r'(\{\{.*?\}\})', node):
        marker = re.match(r'\{\{(.*?)\}\}', chunk)
        if marker is None:
            pieces.append(chunk)
            continue
        word = marker.group(1)
        shade = color_map.get(word, 'black')
        pieces.append(f"<span style='color: {shade};'>{word}</span>")
    return ''.join(pieces)
def generate_subplot(paraphrased_sentence, scheme_sentences, highlight_info, common_grams, subplot_number):
    """Build a Plotly figure linking a paraphrased sentence to scheme sentences.

    The paraphrased sentence becomes the level-0 node and every scheme
    sentence a level-1 node. Each node is drawn as a marker plus a boxed text
    annotation; each edge is drawn as a line with a label above its midpoint.

    Args:
        paraphrased_sentence: Text of the level-0 node.
        scheme_sentences: Iterable of strings, one per level-1 node.
        highlight_info: Anything ``dict()`` accepts, mapping word to color.
        common_grams: Iterable of ``(idx, lcs)`` pairs passed to
            ``apply_lcs_numbering``.
        subplot_number: Not read by this function; kept in the signature so
            existing callers keep working.

    Returns:
        The populated ``go.Figure``.
    """
    # Combine nodes into one list with appropriate labels
    nodes = [paraphrased_sentence + ' L0'] + [s + ' L1' for s in scheme_sentences]

    # Apply LCS numbering and clean/wrap nodes
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
    wrapped_nodes = clean_and_wrap_nodes(nodes, highlight_info)

    # Get levels and edges
    levels, edges = get_levels_and_edges(nodes)
    positions = calculate_positions(levels)

    # Build the word -> color lookup once instead of once per node.
    color_map = dict(highlight_info)

    # Create figure
    fig = go.Figure()

    # Add nodes to the figure
    for i, node in enumerate(wrapped_nodes):
        colored_node = color_highlighted_words(node, color_map)
        x, y = positions[i]
        fig.add_trace(go.Scatter(
            x=[-x],  # Reflect the x coordinate
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig.add_annotation(
            x=-x,  # Reflect the x coordinate
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=300,
            height=120
        )

    # Add edges and edge annotations
    edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]
    for i, edge in enumerate(edges):
        x0, y0 = positions[edge[0]]
        x1, y1 = positions[edge[1]]
        fig.add_trace(go.Scatter(
            x=[-x0, -x1],  # Reflect the x coordinates
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        if i >= len(edge_texts):
            # More edges than predefined labels: keep the line, skip the
            # label instead of failing with IndexError.
            continue
        # Add text annotation above the edge
        mid_x = (-x0 + -x1) / 2
        mid_y = (y0 + y1) / 2
        fig.add_annotation(
            x=mid_x,
            y=mid_y + 0.8,  # Adjust y position to shift text upwards
            text=edge_texts[i],  # Use the text specific to this edge
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )
    return fig
def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info, common_grams):
    """Return the figure from ``generate_subplot`` with ``subplot_number`` set to 1."""
    figure = generate_subplot(
        paraphrased_sentence,
        scheme_sentences,
        highlight_info,
        common_grams,
        subplot_number=1,
    )
    return figure
def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info, common_grams):
    """Build a Plotly figure linking scheme sentences to sampled sentences.

    Every scheme sentence is labelled level 0 and every sampled sentence
    level 1. Each node is drawn as a marker plus a boxed text annotation;
    each edge returned by ``get_levels_and_edges`` is drawn as a line with a
    label above its midpoint.

    Args:
        scheme_sentences: Iterable of strings, one per level-0 node.
        sampled_sentence: Iterable of strings, one per level-1 node.
        highlight_info: Anything ``dict()`` accepts, mapping word to color.
        common_grams: Iterable of ``(idx, lcs)`` pairs passed to
            ``apply_lcs_numbering``.

    Returns:
        The populated ``go.Figure``.
    """
    # Label scheme sentences as level 0 and sampled sentences as level 1
    nodes = [s + ' L0' for s in scheme_sentences] + [s + ' L1' for s in sampled_sentence]

    # Apply LCS numbering and clean/wrap nodes
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
    wrapped_nodes = clean_and_wrap_nodes(nodes, highlight_info)

    # Get levels and edges
    levels, edges = get_levels_and_edges(nodes)
    positions = calculate_positions(levels)

    # Build the word -> color lookup once instead of once per node.
    color_map = dict(highlight_info)

    # Create figure
    fig2 = go.Figure()

    # Add nodes to the figure
    for i, node in enumerate(wrapped_nodes):
        colored_node = color_highlighted_words(node, color_map)
        x, y = positions[i]
        fig2.add_trace(go.Scatter(
            x=[-x],  # Reflect the x coordinate
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig2.add_annotation(
            x=-x,  # Reflect the x coordinate
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=450,
            height=65
        )

    # Add edges and text above each edge
    edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]
    for i, edge in enumerate(edges):
        x0, y0 = positions[edge[0]]
        x1, y1 = positions[edge[1]]
        fig2.add_trace(go.Scatter(
            x=[-x0, -x1],  # Reflect the x coordinates
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        if i >= len(edge_texts):
            # More edges than predefined labels: keep the line, skip the
            # label instead of failing with IndexError.
            continue
        # Add text annotation above the edge
        mid_x = (-x0 + -x1) / 2
        mid_y = (y0 + y1) / 2
        fig2.add_annotation(
            x=mid_x,
            y=mid_y + 0.8,  # Adjust y position to shift text upwards
            text=edge_texts[i],  # Use the text specific to this edge
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )
    return fig2