import nltk
nltk.download('stopwords')
nltk.download('punkt_tab')
# from transformers import AutoTokenizer
# from transformers import AutoModelForSeq2SeqLM
import plotly.graph_objs as go
from transformers import pipeline
import random
import gradio as gr
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot
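
# Pipeline overview: model() paraphrases the prompt, keeps the paraphrases
# that entail it (threshold 0.7), finds the n-grams they all share with the
# prompt ("non-melting points"), masks and re-samples the remaining words,
# paraphrases the sampled sentences again, and scores the results on
# distortion, detectability, and Euclidean distance for the 3D "sweet spot" plot.
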
# Function for the Gradio interface
def model(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
    print(analyzed_paraphrased_sentences, selected_sentences, discarded_sentences)  # debug: show the entailment split without re-running the model

    common_grams = find_common_subsequences(user_prompt, selected_sentences)
    subsequences = [subseq for _, subseq in common_grams]
    common_grams_position = find_common_gram_positions(selected_sentences, subsequences)
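
    # common_grams holds the n-grams shared between the prompt and the accepted
    # paraphrases (the "non-melting points" highlighted in the UI);
    # common_grams_position records where they occur in each selected sentence.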
    masked_sentences = []
    masked_words = []
    masked_logits = []
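
    # Each paraphrase is masked three different ways (non-stopword, pseudorandom
    # non-stopword, high-entropy), so every sentence contributes three
    # consecutive entries to each of the lists above.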
    for sentence in paraphrased_sentences:
        masked_sent, logits, words = mask_non_stopword(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)
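
    # Fill each masked sentence back in with four sampling strategies; a single
    # paraphrase therefore yields 3 masked variants x 4 strategies = 12 sampled
    # sentences, in the order the tree plots below rely on.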
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
    colors = ["red", "blue", "brown", "green"]

    def select_color():
        return random.choice(colors)

    highlight_info = [(word, select_color()) for _, word in common_grams]
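
    # highlight_info pairs each non-melting-point n-gram with a random color; it
    # is passed to the tree plots so the same n-gram keeps one color throughout.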
    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
    trees1 = []
    trees2 = []

    masked_index = 0
    sampled_index = 0
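
    # Walk the flat lists in lockstep with the paraphrases: each paraphrase owns
    # the next 3 masked sentences and the next 12 sampled sentences.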
    for i, sentence in enumerate(paraphrased_sentences):
        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]

        tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
        trees1.append(tree1)

        tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
        trees2.append(tree2)

        masked_index += 3
        sampled_index += 12
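
    # Paraphrase every sampled sentence once more; the resulting re-paraphrased
    # sentences are rendered as HTML in fixed-size batches of 10 below.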
    reparaphrased_sentences = generate_paraphrase(sampled_sentences)
    len_reparaphrased_sentences = len(reparaphrased_sentences)

    reparaphrased_sentences_list = []

    # Process the sentences in batches of 10
    for i in range(0, len_reparaphrased_sentences, 10):
        batch = reparaphrased_sentences[i:i + 10]
        # Only render complete batches of 10; an incomplete trailing batch is skipped
        if len(batch) == 10:
            html_block = reparaphrased_sentences_html(batch)
            reparaphrased_sentences_list.append(html_block)
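
    # Score the re-paraphrased sentences against the original prompt on three
    # axes (distortion, detectability, Euclidean distance); the normalized
    # scores become the coordinates of the 3D "sweet spot" plot.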
    distortion_list = []
    detectability_list = []
    euclidean_dist_list = []

    distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
    distortion_calculator.calculate_all_metrics()
    distortion_calculator.normalize_metrics()
    distortion_calculator.calculate_combined_distortion()

    distortion = distortion_calculator.get_combined_distortions()
    distortion_list.extend(distortion.values())
    detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
    detectability_calculator.calculate_all_metrics()
    detectability_calculator.normalize_metrics()
    detectability_calculator.calculate_combined_detectability()

    detectability = detectability_calculator.get_combined_detectabilities()
    detectability_list.extend(detectability.values())
    euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
    euclidean_dist_calculator.calculate_all_metrics()
    euclidean_dist_calculator.normalize_metrics()

    # get_normalized_metrics() is assumed to return a dict of per-sentence
    # normalized distances, mirroring the other two calculators
    euclidean_dist = euclidean_dist_calculator.get_normalized_metrics()
    euclidean_dist_list.extend(euclidean_dist.values())
    three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)

    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2 + reparaphrased_sentences_list + [three_D_plot]
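
# Note: the list returned by model() must line up one-to-one with the `outputs`
# wired to submit_button.click below: 3 HTML views, the tree plots, the
# re-paraphrase batches, and the 3D plot. The fixed tab counts in the UI
# (10, 10, 120) therefore assume a fixed number of paraphrases per prompt.
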
# Logic for the new "Paraphrase and Discarded Sentence Generator" button
def generate_paraphrase_and_discarded_sentences(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)

    # Combine discarded sentences with their entailment scores
    discarded_sentences_with_scores = [
        f"{sentence} (Entailment Score: {score:.2f})"
        for sentence, score in discarded_sentences.items()
    ]

    # Prepare paraphrased sentences for display
    paraphrased_sentences_html = highlight_common_words_dict([], selected_sentences, "Paraphrased Sentences")
    discarded_sentences_html = "<br>".join(discarded_sentences_with_scores)

    return paraphrased_sentences_html, discarded_sentences_html
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")
        generate_non_melting_point_button = gr.Button("Generate Non-Melting Point")  # New button
        paraphrase_discard_button = gr.Button("Paraphrase and Discarded Sentence Generator")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()
    # Adding labels before the tree plots
    with gr.Row():
        gr.Markdown("### Where to Watermark?")  # Label for masked sentences trees
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_tabs.append(tree1)

    with gr.Row():
        gr.Markdown("### How to Watermark?")  # Label for sampled sentences trees
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2 = gr.Plot()
                    tree2_tabs.append(tree2)
    # Adding the "Re-paraphrased Sentences" section
    with gr.Row():
        gr.Markdown("### Re-paraphrased Sentences")  # Label for re-paraphrased sentences

    # Adding tabs for the re-paraphrased sentences
    with gr.Row():
        with gr.Tabs():
            reparaphrased_sentences_tabs = []
            for i in range(120):  # 120 tabs for 120 batches of sentences
                with gr.TabItem(f"Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()  # Placeholder for each batch
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
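
    # The 120 tabs assume model() emits 120 batches of 10 re-paraphrased
    # sentences (e.g. 10 paraphrases x 3 maskings x 4 samplings = 120 sampled
    # sentences, each re-paraphrased into one batch of 10); adjust together
    # with the tree tab ranges above if those counts change.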
    with gr.Row():
        gr.Markdown("### 3D Plot for Sweet Spot")
    with gr.Row():
        three_D_plot = gr.Plot()
    # Logic for the new button
    def generate_non_melting_points_only(prompt):
        user_prompt = prompt
        paraphrased_sentences = generate_paraphrase(user_prompt)
        analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)

        common_grams = find_common_subsequences(user_prompt, selected_sentences)
        highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")

        return highlighted_user_prompt
    # Connect buttons to functions
    submit_button.click(
        model,
        inputs=user_input,
        outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot]
    )

    generate_non_melting_point_button.click(
        generate_non_melting_points_only,
        inputs=user_input,
        outputs=highlighted_user_prompt
    )

    paraphrase_discard_button.click(
        generate_paraphrase_and_discarded_sentences,
        inputs=user_input,
        outputs=[highlighted_accepted_sentences, highlighted_discarded_sentences]
    )

    clear_button.click(lambda: "", inputs=None, outputs=user_input)

    # Clearing several components at once requires one value per output component
    all_output_components = [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot]
    clear_button.click(lambda: [None] * len(all_output_components), inputs=None, outputs=all_output_components)
demo.launch(share=True)