jgyasu committed on
Commit
38620f7
1 Parent(s): 00be330

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -27
app.py CHANGED
@@ -6,12 +6,11 @@ Automatically generated by Colab.
6
  Original file is located at
7
  https://colab.research.google.com/drive/1pFGR4uvXMMWVJFQeFmn--arumSxqa5Yy
8
  """
9
- import os
10
- os.system('apt install graphviz')
11
 
12
  from transformers import AutoTokenizer
13
  from transformers import AutoModelForSeq2SeqLM
14
- import plotly.graph_objects as go
 
15
  from transformers import pipeline
16
  import re
17
  import time
@@ -436,32 +435,119 @@ def mask(sentence):
436
  return masked_sentences
437
 
438
  # Function to generate the tree and return the Graphviz source
439
- def generate_tree(original_sentence: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  paraphrased_sentences = generate_paraphrase(original_sentence)
441
  first_paraphrased_sentence = paraphrased_sentences[0]
442
  masked_sentence = mask_non_stopword(first_paraphrased_sentence)
443
  masked_versions = mask(masked_sentence)
444
- dot = graphviz.Digraph()
445
- dot.attr(rankdir='LR', size='10,10', dpi=' 2743')
446
-
447
- existing_nodes = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
- def add_paraphrases(parent, paraphrases):
450
- if parent not in existing_nodes:
451
- dot.node(parent, parent, shape='box')
452
- existing_nodes.add(parent)
453
 
454
- for paraphrase in paraphrases:
455
- if paraphrase not in existing_nodes:
456
- dot.node(paraphrase, paraphrase, shape='box')
457
- existing_nodes.add(paraphrase)
458
- dot.edge(parent, paraphrase)
459
 
460
- add_paraphrases(original_sentence, paraphrased_sentences) #whenever a new branch is to be created call this function along with the original sentence and the list of its paraphrases
461
- add_paraphrases(paraphrased_sentences[0], masked_versions)
462
 
463
- graph_path = dot.render(filename='paraphrase_tree_dynamic', format='png')
464
- return masked_sentence, masked_versions, dot.source
465
 
466
  # Function for the Gradio interface
467
  def model(prompt):
@@ -472,11 +558,8 @@ def model(prompt):
472
  for i in range(len(common_subs)):
473
  common_subs[i]["Paraphrased Sentence"] = res[i]
474
  result = highlight_phrases_with_colors(res, common_grams)
475
- masked_sentence, masked_versions, tree_source = generate_tree(sentence)
476
- graph = graphviz.Source(tree_source)
477
- png_content = graph.render(filename='paraphrase_tree_dynamic', format='png')
478
- # tree = f'<div style="width: 100%; overflow-x: auto;">{svg_content}</div>'
479
- return generated, generated, result, masked_sentence, masked_versions, png_content
480
 
481
  with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
482
  gr.Markdown("# Paraphrases the Text and Highlights the Non-melting Points")
@@ -504,7 +587,7 @@ with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
504
  masked_versions = gr.Textbox(label="Sentence Generated by Masking Model")
505
 
506
  with gr.Row():
507
- tree = gr.Image(label="Paraphrase Tree")
508
 
509
  submit_button.click(model, inputs=user_input, outputs=[ai_output, selected_sentence, html_output, masked_sentence, masked_versions, tree])
510
  clear_button.click(lambda: "", inputs=None, outputs=user_input)
 
6
  Original file is located at
7
  https://colab.research.google.com/drive/1pFGR4uvXMMWVJFQeFmn--arumSxqa5Yy
8
  """
 
 
9
 
10
  from transformers import AutoTokenizer
11
  from transformers import AutoModelForSeq2SeqLM
12
+ import plotly.graph_objs as go
13
+ import textwrap
14
  from transformers import pipeline
15
  import re
16
  import time
 
435
  return masked_sentences
436
 
437
  # Function to generate the tree and return the Graphviz source
438
+ # def generate_tree(original_sentence: str) -> str:
439
+ # paraphrased_sentences = generate_paraphrase(original_sentence)
440
+ # first_paraphrased_sentence = paraphrased_sentences[0]
441
+ # masked_sentence = mask_non_stopword(first_paraphrased_sentence)
442
+ # masked_versions = mask(masked_sentence)
443
+ # dot = graphviz.Digraph()
444
+ # dot.attr(rankdir='LR', size='10,10', dpi=' 2743')
445
+
446
+ # existing_nodes = set()
447
+
448
+ # def add_paraphrases(parent, paraphrases):
449
+ # if parent not in existing_nodes:
450
+ # dot.node(parent, parent, shape='box')
451
+ # existing_nodes.add(parent)
452
+
453
+ # for paraphrase in paraphrases:
454
+ # if paraphrase not in existing_nodes:
455
+ # dot.node(paraphrase, paraphrase, shape='box')
456
+ # existing_nodes.add(paraphrase)
457
+ # dot.edge(parent, paraphrase)
458
+
459
+ # add_paraphrases(original_sentence, paraphrased_sentences) #whenever a new branch is to be created call this function along with the original sentence and the list of its paraphrases
460
+ # add_paraphrases(paraphrased_sentences[0], masked_versions)
461
+
462
+ # graph_path = dot.render(filename='paraphrase_tree_dynamic', format='png')
463
+ # return masked_sentence, masked_versions, dot.source
464
+
465
+
466
+
467
+ #plotly tree
468
def _tree_layout(n_paraphrases, n_masked):
    """Compute node coordinates and edges for the three-level paraphrase tree.

    Node indices follow the order in which ``generate_plot`` lists its nodes:
    index 0 is the root, indices ``1..n_paraphrases`` are the paraphrases and
    the remaining ``n_masked`` indices are the masked versions.

    Args:
        n_paraphrases: Number of paraphrase nodes (second level).
        n_masked: Number of masked-version nodes (third level).

    Returns:
        A ``(positions, edges)`` tuple. ``positions`` is a list of ``(x, y)``
        pairs indexed by node number; ``edges`` is a list of
        ``(parent_index, child_index)`` pairs.
    """
    # Root sits at the origin.
    positions = [(0.0, 0.0)]

    # Paraphrases: one row at y = -4, spaced 2 apart and centred on x = 0.
    # For five paraphrases this yields x = -4, -2, 0, 2, 4.
    first_x = -float(n_paraphrases - 1)
    positions.extend((first_x + 2.0 * i, -4.0) for i in range(n_paraphrases))

    # Masked versions: one row at y = -8, spaced 1.5 apart, starting slightly
    # left of the first paraphrase (x = -4.5, -3, -1.5, 0 for the 5/4 case).
    masked_start = first_x - 0.5
    positions.extend((masked_start + 1.5 * j, -8.0) for j in range(n_masked))

    # The root fans out to every paraphrase; only the first paraphrase
    # (node 1) has children, because only it is masked.
    edges = [(0, child) for child in range(1, n_paraphrases + 1)]
    first_masked = n_paraphrases + 1
    edges.extend((1, first_masked + j) for j in range(n_masked))
    return positions, edges


def generate_plot(original_sentence):
    """Build a Plotly tree of a sentence, its paraphrases and masked variants.

    The tree has three levels: the original sentence at the root, every
    paraphrase returned by ``generate_paraphrase`` beneath it, and every
    masked version of the first paraphrase beneath that paraphrase. Layout
    and edges are derived from the actual number of paraphrases and masked
    versions, so any count is rendered without missing or dangling nodes.

    Args:
        original_sentence: The sentence to paraphrase and plot.

    Returns:
        A ``(masked_sentence, masked_versions, fig)`` tuple, where
        ``masked_sentence`` is the result of ``mask_non_stopword`` on the
        first paraphrase, ``masked_versions`` is the result of ``mask`` on
        that masked sentence, and ``fig`` is the Plotly figure.

    Raises:
        IndexError: If ``generate_paraphrase`` returns no paraphrases.
    """
    paraphrased_sentences = generate_paraphrase(original_sentence)
    first_paraphrased_sentence = paraphrased_sentences[0]
    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
    masked_versions = mask(masked_sentence)

    # NOTE(review): assumes ``mask`` returns an iterable of strings (one per
    # masked variant) -- confirm against its definition.
    paraphrase_nodes = list(paraphrased_sentences)
    masked_nodes = list(masked_versions)
    nodes = [original_sentence, *paraphrase_nodes, *masked_nodes]

    positions, edges = _tree_layout(len(paraphrase_nodes), len(masked_nodes))

    # Plotly annotations use HTML-like markup, so wrap long sentences with
    # <br> to keep each label box narrow.
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in nodes]

    fig = go.Figure()

    # All edges go into a single trace; a None entry breaks the line between
    # consecutive segments. Drawn first so the markers sit on top.
    edge_x = []
    edge_y = []
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    fig.add_trace(go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(color='black', width=2),
        hoverinfo='none'
    ))

    # All node markers share one trace as well.
    fig.add_trace(go.Scatter(
        x=[x for x, _ in positions],
        y=[y for _, y in positions],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='none'
    ))

    # One boxed text label per node, shifted up so it does not cover the marker.
    for (x, y), label in zip(positions, wrapped_nodes):
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            yshift=20,
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800
    )

    return masked_sentence, masked_versions, fig
 
 
 
548
 
 
 
 
 
 
549
 
 
 
550
 
 
 
551
 
552
  # Function for the Gradio interface
553
  def model(prompt):
 
558
  for i in range(len(common_subs)):
559
  common_subs[i]["Paraphrased Sentence"] = res[i]
560
  result = highlight_phrases_with_colors(res, common_grams)
561
+ masked_sentence, masked_versions, tree = generate_plot(sentence)
562
+ return generated, generated, result, masked_sentence, masked_versions, tree
 
 
 
563
 
564
  with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
565
  gr.Markdown("# Paraphrases the Text and Highlights the Non-melting Points")
 
587
  masked_versions = gr.Textbox(label="Sentence Generated by Masking Model")
588
 
589
  with gr.Row():
590
+ tree = gr.Plot()
591
 
592
  submit_button.click(model, inputs=user_input, outputs=[ai_output, selected_sentence, html_output, masked_sentence, masked_versions, tree])
593
  clear_button.click(lambda: "", inputs=None, outputs=user_input)