jgyasu committed on
Commit
2493822
1 Parent(s): ee305a4

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. tree.py +198 -3
app.py CHANGED
@@ -90,6 +90,7 @@ def model(prompt):
90
  highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
91
  highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
92
 
 
93
  # Initialize empty list to hold the trees
94
  trees = []
95
 
 
90
  highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
91
  highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
92
 
93
+
94
  # Initialize empty list to hold the trees
95
  trees = []
96
 
tree.py CHANGED
@@ -1,3 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import plotly.graph_objects as go
2
  import textwrap
3
  import re
@@ -105,6 +263,25 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
105
  colored_parts.append(part)
106
  return ''.join(colored_parts)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Create figure
109
  fig = go.Figure()
110
 
@@ -134,8 +311,8 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
134
  width=150
135
  )
136
 
137
- # Add edges to the figure
138
- for edge in edges:
139
  x0, y0 = positions[edge[0]]
140
  x1, y1 = positions[edge[1]]
141
  fig.add_trace(go.Scatter(
@@ -145,6 +322,23 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
145
  line=dict(color='black', width=1)
146
  ))
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  fig.update_layout(
149
  showlegend=False,
150
  margin=dict(t=20, b=20, l=20, r=20),
@@ -154,4 +348,5 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
154
  height=1000 # Adjusted height to accommodate more levels
155
  )
156
 
157
- return fig
 
 
1
+ # import plotly.graph_objects as go
2
+ # import textwrap
3
+ # import re
4
+ # from collections import defaultdict
5
+
6
+ # def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
7
+ # # Combine nodes into one list with appropriate labels
8
+ # nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
9
+ # nodes[0] += ' L0' # Paraphrased sentence is level 0
10
+ # para_len = len(scheme_sentences)
11
+ # for i in range(1, para_len + 1):
12
+ # nodes[i] += ' L1' # Scheme sentences are level 1
13
+ # for i in range(para_len + 1, len(nodes)):
14
+ # nodes[i] += ' L2' # Sampled sentences are level 2
15
+
16
+ # # Define the highlight_words function
17
+ # def highlight_words(sentence, color_map):
18
+ # for word, color in color_map.items():
19
+ # sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
20
+ # return sentence
21
+
22
+ # # Clean and wrap nodes, and highlight specified words globally
23
+ # cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
24
+ # global_color_map = dict(highlight_info)
25
+ # highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
26
+ # wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
27
+
28
+ # # Function to determine tree levels and create edges dynamically
29
+ # def get_levels_and_edges(nodes):
30
+ # levels = {}
31
+ # edges = []
32
+ # for i, node in enumerate(nodes):
33
+ # level = int(node.split()[-1][1])
34
+ # levels[i] = level
35
+
36
+ # # Add edges from L0 to all L1 nodes
37
+ # root_node = next(i for i, level in levels.items() if level == 0)
38
+ # for i, level in levels.items():
39
+ # if level == 1:
40
+ # edges.append((root_node, i))
41
+
42
+ # # Add edges from each L1 node to their corresponding L2 nodes
43
+ # l1_indices = [i for i, level in levels.items() if level == 1]
44
+ # l2_indices = [i for i, level in levels.items() if level == 2]
45
+
46
+ # for i, l1_node in enumerate(l1_indices):
47
+ # l2_start = i * 4
48
+ # for j in range(4):
49
+ # l2_index = l2_start + j
50
+ # if l2_index < len(l2_indices):
51
+ # edges.append((l1_node, l2_indices[l2_index]))
52
+
53
+ # # Add edges from each L2 node to their corresponding L3 nodes
54
+ # l2_indices = [i for i, level in levels.items() if level == 2]
55
+ # l3_indices = [i for i, level in levels.items() if level == 3]
56
+
57
+ # l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
58
+
59
+ # # Map L3 nodes to L2 nodes
60
+ # for l3_node in l3_indices:
61
+ # l2_node = l3_node % len(l2_indices)
62
+ # l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
63
+
64
+ # for l2_node, l3_nodes in l2_to_l3_map.items():
65
+ # for l3_node in l3_nodes:
66
+ # edges.append((l2_node, l3_node))
67
+
68
+ # return levels, edges
69
+
70
+ # # Get levels and dynamic edges
71
+ # levels, edges = get_levels_and_edges(nodes)
72
+ # max_level = max(levels.values(), default=0)
73
+
74
+ # # Calculate positions
75
+ # positions = {}
76
+ # level_heights = defaultdict(int)
77
+ # for node, level in levels.items():
78
+ # level_heights[level] += 1
79
+
80
+ # y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
81
+ # x_gap = 2
82
+ # l1_y_gap = 10
83
+ # l2_y_gap = 6
84
+
85
+ # for node, level in levels.items():
86
+ # if level == 1:
87
+ # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
88
+ # elif level == 2:
89
+ # positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
90
+ # else:
91
+ # positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
92
+ # y_offsets[level] += 1
93
+
94
+ # # Function to highlight words in a wrapped node string
95
+ # def color_highlighted_words(node, color_map):
96
+ # parts = re.split(r'(\{\{.*?\}\})', node)
97
+ # colored_parts = []
98
+ # for part in parts:
99
+ # match = re.match(r'\{\{(.*?)\}\}', part)
100
+ # if match:
101
+ # word = match.group(1)
102
+ # color = color_map.get(word, 'black')
103
+ # colored_parts.append(f"<span style='color: {color};'>{word}</span>")
104
+ # else:
105
+ # colored_parts.append(part)
106
+ # return ''.join(colored_parts)
107
+
108
+ # # Create figure
109
+ # fig = go.Figure()
110
+
111
+ # # Add nodes to the figure
112
+ # for i, node in enumerate(wrapped_nodes):
113
+ # colored_node = color_highlighted_words(node, global_color_map)
114
+ # x, y = positions[i]
115
+ # fig.add_trace(go.Scatter(
116
+ # x=[-x], # Reflect the x coordinate
117
+ # y=[y],
118
+ # mode='markers',
119
+ # marker=dict(size=10, color='blue'),
120
+ # hoverinfo='none'
121
+ # ))
122
+ # fig.add_annotation(
123
+ # x=-x, # Reflect the x coordinate
124
+ # y=y,
125
+ # text=colored_node,
126
+ # showarrow=False,
127
+ # xshift=15,
128
+ # align="center",
129
+ # font=dict(size=8),
130
+ # bordercolor='black',
131
+ # borderwidth=1,
132
+ # borderpad=2,
133
+ # bgcolor='white',
134
+ # width=150
135
+ # )
136
+
137
+ # # Add edges to the figure
138
+ # for edge in edges:
139
+ # x0, y0 = positions[edge[0]]
140
+ # x1, y1 = positions[edge[1]]
141
+ # fig.add_trace(go.Scatter(
142
+ # x=[-x0, -x1], # Reflect the x coordinates
143
+ # y=[y0, y1],
144
+ # mode='lines',
145
+ # line=dict(color='black', width=1)
146
+ # ))
147
+
148
+ # fig.update_layout(
149
+ # showlegend=False,
150
+ # margin=dict(t=20, b=20, l=20, r=20),
151
+ # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
152
+ # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
153
+ # width=1200, # Adjusted width to accommodate more levels
154
+ # height=1000 # Adjusted height to accommodate more levels
155
+ # )
156
+
157
+ # return fig
158
+
159
  import plotly.graph_objects as go
160
  import textwrap
161
  import re
 
263
  colored_parts.append(part)
264
  return ''.join(colored_parts)
265
 
266
+ # Define the text for each edge; entry i labels edges[i], so the order and count must match the edge list
267
+ edge_texts = [
268
+ "Highest Entropy Masking",
269
+ "Pseudo-random Masking",
270
+ "Random Masking",
271
+ "Greedy Sampling",
272
+ "Temperature Sampling",
273
+ "Exponential Minimum Sampling",
274
+ "Inverse Transform Sampling",
275
+ "Greedy Sampling",
276
+ "Temperature Sampling",
277
+ "Exponential Minimum Sampling",
278
+ "Inverse Transform Sampling",
279
+ "Greedy Sampling",
280
+ "Temperature Sampling",
281
+ "Exponential Minimum Sampling",
282
+ "Inverse Transform Sampling"
283
+ ]
284
+
285
  # Create figure
286
  fig = go.Figure()
287
 
 
311
  width=150
312
  )
313
 
314
+ # Add edges and text above each edge
315
+ for i, edge in enumerate(edges):
316
  x0, y0 = positions[edge[0]]
317
  x1, y1 = positions[edge[1]]
318
  fig.add_trace(go.Scatter(
 
322
  line=dict(color='black', width=1)
323
  ))
324
 
325
+ # Calculate the midpoint of the edge
326
+ mid_x = (-x0 + -x1) / 2
327
+ mid_y = (y0 + y1) / 2
328
+
329
+ # Adjust y position to shift text upwards
330
+ text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
331
+
332
+ # Add text annotation above the edge
333
+ fig.add_annotation(
334
+ x=mid_x,
335
+ y=text_y_position,
336
+ text=edge_texts[i], # Use the text specific to this edge
337
+ showarrow=False,
338
+ font=dict(size=10),
339
+ align="center"
340
+ )
341
+
342
  fig.update_layout(
343
  showlegend=False,
344
  margin=dict(t=20, b=20, l=20, r=20),
 
348
  height=1000 # Adjusted height to accommodate more levels
349
  )
350
 
351
+ return fig
352
+