import random

import gradio as gr
import nltk

from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from entailment import analyze_entailment
from euclidean_distance import SentenceEuclideanDistanceCalculator
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from lcs import find_common_subsequences, find_common_gram_positions
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from paraphraser import generate_paraphrase
from sampling_methods import sample_word
from threeD_plot import gen_three_D_plot
from tree import generate_subplot1, generate_subplot2

# Download the NLTK data used by the stopword-masking and tokenization helpers.
nltk.download('stopwords')
nltk.download('punkt_tab')
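

# Pipeline: paraphrase the prompt, keep the paraphrases that entail it, find
# the n-grams shared with the prompt ("non-melting points"), mask and
# re-sample the remaining words, re-paraphrase the results, then score
# distortion, detectability, and Euclidean distance to locate the sweet spot.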
def model(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
    print(analyzed_paraphrased_sentences, selected_sentences, discarded_sentences)  # debug: inspect the entailment split

    common_grams = find_common_subsequences(user_prompt, selected_sentences)
    subsequences = [subseq for _, subseq in common_grams]
    common_grams_position = find_common_gram_positions(selected_sentences, subsequences)
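
    # Build three masked variants per paraphrase, one per masking strategy
    # (non-stopword, pseudorandom non-stopword, high-entropy), so the lists
    # below grow in lockstep, three entries per sentence.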
    masked_sentences = []
    masked_words = []
    masked_logits = []

    for sentence in paraphrased_sentences:
        masked_sent, logits, words = mask_non_stopword(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)
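
    # Re-sample each masked slot with four strategies, giving
    # 3 masks x 4 samplers = 12 sampled sentences per paraphrase; the subplot
    # loop below consumes them in strides of 3 and 12.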
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))

    colors = ["red", "blue", "brown", "green"]

    def select_color():
        return random.choice(colors)

    highlight_info = [(word, select_color()) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")

    trees1 = []
    trees2 = []

    masked_index = 0
    sampled_index = 0

    for i, sentence in enumerate(paraphrased_sentences):
        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]

        tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
        trees1.append(tree1)

        tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
        trees2.append(tree2)

        masked_index += 3
        sampled_index += 12
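
    # Re-paraphrase every sampled sentence and group the results into blocks
    # of 10 for display. This assumes generate_paraphrase returns 10
    # candidates per input; incomplete trailing batches are dropped.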
    reparaphrased_sentences = generate_paraphrase(sampled_sentences)

    reparaphrased_sentences_list = []
    for i in range(0, len(reparaphrased_sentences), 10):
        batch = reparaphrased_sentences[i:i + 10]
        if len(batch) == 10:
            html_block = reparaphrased_sentences_html(batch)
            reparaphrased_sentences_list.append(html_block)
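
    # Score each re-paraphrased sentence on three axes; every calculator
    # normalizes its raw metrics before combining them into one value.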
    distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
    distortion_calculator.calculate_all_metrics()
    distortion_calculator.normalize_metrics()
    distortion_calculator.calculate_combined_distortion()

    distortion = distortion_calculator.get_combined_distortions()
    distortion_list = list(distortion.values())

    detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
    detectability_calculator.calculate_all_metrics()
    detectability_calculator.normalize_metrics()
    detectability_calculator.calculate_combined_detectability()

    detectability = detectability_calculator.get_combined_detectabilities()
    detectability_list = list(detectability.values())

    euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
    euclidean_dist_calculator.calculate_all_metrics()
    euclidean_dist_calculator.normalize_metrics()

    # get_normalized_metrics() is assumed to return a dict keyed by sentence,
    # mirroring the combined-metric getters above.
    euclidean_dist = euclidean_dist_calculator.get_normalized_metrics()
    euclidean_dist_list = list(euclidean_dist.values())

    three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)

    # The return order must match the outputs list wired to submit_button below.
    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2 + reparaphrased_sentences_list + [three_D_plot]


def generate_paraphrase_and_discarded_sentences(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)

    discarded_sentences_with_scores = [
        f"{sentence} (Entailment Score: {score:.2f})"
        for sentence, score in discarded_sentences.items()
    ]

    paraphrased_sentences_html = highlight_common_words_dict([], selected_sentences, "Paraphrased Sentences")
    discarded_sentences_html = "<br>".join(discarded_sentences_with_scores)

    return paraphrased_sentences_html, discarded_sentences_html
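

# The layout mirrors model()'s output shape: three HTML panels, 10 + 10 plot
# tabs, 120 re-paraphrase tabs, and one 3D plot. The tab counts assume 10
# paraphrases per prompt and 10 re-paraphrases per sampled sentence.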
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")
        generate_non_melting_point_button = gr.Button("Generate Non-Melting Points")
        paraphrase_discard_button = gr.Button("Generate Paraphrases and Discarded Sentences")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()

    with gr.Row():
        gr.Markdown("### Where to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_tabs.append(tree1)

    with gr.Row():
        gr.Markdown("### How to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2 = gr.Plot()
                    tree2_tabs.append(tree2)

    with gr.Row():
        gr.Markdown("### Re-paraphrased Sentences")
    with gr.Row():
        with gr.Tabs():
            reparaphrased_sentences_tabs = []
            for i in range(120):
                with gr.TabItem(f"Sentence {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)

    with gr.Row():
        gr.Markdown("### 3D Plot for Sweet Spot")
    with gr.Row():
        three_D_plot = gr.Plot()
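
    # Lightweight variant of model(): compute and highlight only the
    # non-melting points, skipping masking, sampling, and metrics.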
    def generate_non_melting_points_only(prompt):
        user_prompt = prompt
        paraphrased_sentences = generate_paraphrase(user_prompt)
        analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
        common_grams = find_common_subsequences(user_prompt, selected_sentences)
        highlighted_prompt_html = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
        return highlighted_prompt_html
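
    # Collect every output component once so the submit and clear handlers
    # stay in sync with model()'s return order.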
    all_outputs = [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot]

    submit_button.click(
        model,
        inputs=user_input,
        outputs=all_outputs
    )

    generate_non_melting_point_button.click(
        generate_non_melting_points_only,
        inputs=user_input,
        outputs=highlighted_user_prompt
    )

    paraphrase_discard_button.click(
        generate_paraphrase_and_discarded_sentences,
        inputs=user_input,
        outputs=[highlighted_accepted_sentences, highlighted_discarded_sentences]
    )

    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    # A list of outputs needs one return value per component; None clears each.
    clear_button.click(lambda: [None] * len(all_outputs), inputs=None, outputs=all_outputs)

# share=True additionally exposes the app through a temporary public gradio.live link.
demo.launch(share=True)