Spaces:
Running
Running
import plotly.graph_objects as go | |
import textwrap | |
import re | |
from collections import defaultdict | |
def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info, common_grams):
    """Draw a two-level tree from the paraphrased sentence to each scheme sentence.

    The paraphrased sentence is the single level-0 (root) node, drawn at x = 0.
    Every entry of ``scheme_sentences`` is a level-1 node drawn at x = 2 and
    joined to the root by an edge labelled from ``edge_texts`` (falling back to
    ``"Edge <n>"`` once that list is exhausted).

    Args:
        paraphrased_sentence: Text of the root node.
        scheme_sentences: List of texts for the level-1 nodes.
        highlight_info: Iterable of ``(word, color)`` pairs. Whole-word,
            case-insensitive occurrences of ``word`` are drawn in ``color``.
        common_grams: Iterable of ``(idx, lcs)`` pairs. Every whole-word,
            case-sensitive occurrence of ``lcs`` is prefixed with ``"(idx)"``.

    Returns:
        The ``plotly.graph_objects.Figure`` containing the tree.
    """
    # Level 0: the paraphrased sentence; level 1: every scheme sentence.
    # Levels come from list position, not from tags embedded in the text, so
    # no common gram or highlight can alter a node's level.
    raw_nodes = [paraphrased_sentence] + scheme_sentences
    levels = {i: 0 if i == 0 else 1 for i in range(len(raw_nodes))}
    edges = [(0, i) for i in range(1, len(raw_nodes))]

    # Materialise once so every node is numbered, even if a one-shot iterator
    # was passed in.
    grams = [(idx, str(lcs)) for idx, lcs in common_grams]

    def apply_lcs_numbering(sentence):
        """Prefix each whole-word occurrence of every common gram with "(idx)"."""
        for idx, lcs in grams:
            # re.escape: grams are raw sentence text and may contain regex
            # metacharacters ('.', '?', '(' ...). The callable replacement keeps
            # backslashes in the gram from being read as group references.
            sentence = re.sub(
                rf"\b{re.escape(lcs)}\b",
                lambda m, prefix=f"({idx})": prefix + m.group(0),
                sentence,
            )
        return sentence

    global_color_map = dict(highlight_info)
    # Matching ignores case but keeps the sentence's own casing, so the colour
    # lookup is keyed by the lower-cased word.
    color_by_word = {str(word).lower(): color for word, color in global_color_map.items()}
    # A single alternation, longest phrase first, applied in one pass: a phrase
    # wins over the words inside it and an already-marked span is never re-marked.
    highlight_terms = sorted(
        (str(word) for word in global_color_map if str(word)), key=len, reverse=True
    )
    highlight_re = (
        re.compile(
            r"\b(?:" + "|".join(map(re.escape, highlight_terms)) + r")\b",
            re.IGNORECASE,
        )
        if highlight_terms
        else None
    )

    def highlight_words(sentence):
        """Wrap every highlighted word/phrase in ``{{...}}`` markers."""
        if highlight_re is None:
            return sentence
        return highlight_re.sub(lambda m: "{{" + m.group(0) + "}}", sentence)

    def color_highlighted_words(node):
        """Turn ``{{...}}`` markers into coloured HTML spans (black if unknown)."""
        def to_span(match):
            word = match.group(1)
            # textwrap may have broken a multi-word phrase with "<br>".
            color = color_by_word.get(word.replace("<br>", " ").lower(), "black")
            return f"<span style='color: {color};'>{word}</span>"
        return re.sub(r"\{\{(.*?)\}\}", to_span, node)

    wrapped_nodes = [
        "<br>".join(textwrap.wrap(highlight_words(apply_lcs_numbering(text)), width=55))
        for text in raw_nodes
    ]

    # Each level is a column at x = level * x_gap; within a column nodes are
    # y_gap apart, in node order, and centred on y = 0.
    x_gap = 2
    y_gap = 10
    level_sizes = defaultdict(int)
    for level in levels.values():
        level_sizes[level] += 1
    next_slot = {level: -(size - 1) / 2 for level, size in level_sizes.items()}
    positions = {}
    for node, level in levels.items():
        positions[node] = (level * x_gap, next_slot[level] * y_gap)
        next_slot[level] += 1

    # Label for each edge, in edge order.
    edge_texts = [
        "Highest Entropy Masking",
        "Pseudo-random Masking",
        "Random Masking",
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling",
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling",
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]

    fig1 = go.Figure()

    # Nodes: a marker plus a boxed, coloured text annotation.
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]
        fig1.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=x,
            y=y,
            text=color_highlighted_words(node),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=300,
            height=120
        )

    # Edges: a straight line plus a label slightly above its midpoint.
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]
        fig1.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        # Generic label once there are more edges than predefined labels.
        label = edge_texts[i] if i < len(edge_texts) else f"Edge {i+1}"
        fig1.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.8,  # raise the label above the line
            text=label,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig1.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )
    return fig1
def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info, common_grams):
    """Draw a two-level tree from the scheme sentences to the sampled sentences.

    Scheme sentences are level-0 nodes (x = 0); sampled sentences are level-1
    nodes (x = 2). Level-1 nodes attach in groups of four: the first four to the
    first scheme sentence, the next four to the second, and all remaining ones
    to the third.

    Args:
        scheme_sentences: List of texts for the level-0 nodes (at least 3).
        sampled_sentence: List of texts for the level-1 nodes.
        highlight_info: Iterable of ``(word, color)`` pairs. Whole-word,
            case-insensitive occurrences of ``word`` are drawn in ``color``.
        common_grams: Iterable of ``(idx, lcs)`` pairs. Every whole-word,
            case-sensitive occurrence of ``lcs`` is prefixed with ``"(idx)"``.

    Returns:
        The ``plotly.graph_objects.Figure`` containing the tree.

    Raises:
        ValueError: if fewer than 3 scheme sentences are given.
    """
    # Level 0: scheme sentences; level 1: sampled sentences. Levels come from
    # list position, not from tags embedded in the text, so no common gram or
    # highlight can alter a node's level.
    raw_nodes = scheme_sentences + sampled_sentence
    para_len = len(scheme_sentences)
    levels = {i: 0 if i < para_len else 1 for i in range(len(raw_nodes))}
    l0_indices = list(range(para_len))
    l1_indices = list(range(para_len, len(raw_nodes)))
    if len(l0_indices) < 3:
        raise ValueError("There should be at least 3 L0 nodes to attach edges correctly.")
    # Groups of four L1 nodes per L0 node; everything past the eighth goes to
    # the third L0 node.
    edges = [(l0_indices[min(pos // 4, 2)], node) for pos, node in enumerate(l1_indices)]

    # Materialise once so every node is numbered, even if a one-shot iterator
    # was passed in.
    grams = [(idx, str(lcs)) for idx, lcs in common_grams]

    def apply_lcs_numbering(sentence):
        """Prefix each whole-word occurrence of every common gram with "(idx)"."""
        for idx, lcs in grams:
            # re.escape: grams are raw sentence text and may contain regex
            # metacharacters ('.', '?', '(' ...). The callable replacement keeps
            # backslashes in the gram from being read as group references.
            sentence = re.sub(
                rf"\b{re.escape(lcs)}\b",
                lambda m, prefix=f"({idx})": prefix + m.group(0),
                sentence,
            )
        return sentence

    global_color_map = dict(highlight_info)
    # Matching ignores case but keeps the sentence's own casing, so the colour
    # lookup is keyed by the lower-cased word.
    color_by_word = {str(word).lower(): color for word, color in global_color_map.items()}
    # A single alternation, longest phrase first, applied in one pass: a phrase
    # wins over the words inside it and an already-marked span is never re-marked.
    highlight_terms = sorted(
        (str(word) for word in global_color_map if str(word)), key=len, reverse=True
    )
    highlight_re = (
        re.compile(
            r"\b(?:" + "|".join(map(re.escape, highlight_terms)) + r")\b",
            re.IGNORECASE,
        )
        if highlight_terms
        else None
    )

    def highlight_words(sentence):
        """Wrap every highlighted word/phrase in ``{{...}}`` markers."""
        if highlight_re is None:
            return sentence
        return highlight_re.sub(lambda m: "{{" + m.group(0) + "}}", sentence)

    def color_highlighted_words(node):
        """Turn ``{{...}}`` markers into coloured HTML spans (black if unknown)."""
        def to_span(match):
            word = match.group(1)
            # textwrap may have broken a multi-word phrase with "<br>".
            color = color_by_word.get(word.replace("<br>", " ").lower(), "black")
            return f"<span style='color: {color};'>{word}</span>"
        return re.sub(r"\{\{(.*?)\}\}", to_span, node)

    wrapped_nodes = [
        "<br>".join(textwrap.wrap(highlight_words(apply_lcs_numbering(text)), width=80))
        for text in raw_nodes
    ]

    # Each level is a column at x = level * x_gap; within a column nodes are
    # y_gap apart, in node order, and centred on y = 0.
    x_gap = 2
    y_gap = 10
    level_sizes = defaultdict(int)
    for level in levels.values():
        level_sizes[level] += 1
    next_slot = {level: -(size - 1) / 2 for level, size in level_sizes.items()}
    positions = {}
    for node, level in levels.items():
        positions[node] = (level * x_gap, next_slot[level] * y_gap)
        next_slot[level] += 1

    sampling_labels = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling",
    ]
    # NOTE(review): every edge here runs from a scheme sentence to one of its
    # sampled sentences, so only the sampling-method labels apply, one round per
    # group of four; the masking labels label the paraphrase -> scheme edges.
    # Confirm the label order matches the order of the caller's sampled sentences.
    edge_texts = sampling_labels * 3

    fig2 = go.Figure()

    # Nodes: a marker plus a boxed, coloured text annotation.
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig2.add_annotation(
            x=x,
            y=y,
            text=color_highlighted_words(node),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=450,
            height=65
        )

    # Edges: a straight line plus a label slightly above its midpoint.
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))
        # Generic label once there are more edges than predefined labels.
        label = edge_texts[i] if i < len(edge_texts) else f"Edge {i+1}"
        fig2.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.8,  # raise the label above the line
            text=label,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )
    return fig2