# Hugging Face Spaces status: Running
import nltk | |
nltk.download('stopwords') | |
import random | |
import gradio as gr | |
import time | |
from tree import generate_subplot1, generate_subplot2 | |
from paraphraser import generate_paraphrase | |
# from lcs import find_common_subsequences, find_common_gram_positions | |
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html | |
from entailment import analyze_entailment | |
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words | |
from sampling_methods import sample_word | |
from detectability import SentenceDetectabilityCalculator | |
from distortion import SentenceDistortionCalculator | |
from euclidean_distance import SentenceEuclideanDistanceCalculator | |
from threeD_plot import gen_three_D_plot | |
from twokenize import tokenize_sentences, tokenize_sentence | |
from non_melting_points import find_non_melting_points | |
class WatermarkingPipeline:
    """Stateful five-step watermarking demo driven by the Gradio UI.

    Each ``stepN_*`` method reads the state stored by the previous step, stores
    its own results on ``self``, and returns exactly as many values as there
    are Gradio output components wired to it in ``create_gradio_interface``:

    1. paraphrase the prompt and split the paraphrases by entailment;
    2. mask every paraphrase three different ways;
    3. fill every masked sentence with four sampling techniques;
    4. re-paraphrase (a capped number of) the sampled sentences;
    5. compute distortion / detectability / distance scores and a 3D plot.

    NOTE(review): ``create_gradio_interface`` binds the methods of a single
    instance, so it appears all browser sessions share (and overwrite) this
    state.
    """

    # Output slots per step; must match the component counts built in
    # create_gradio_interface (10 plot tabs in steps 2 and 3, 120 HTML tabs in
    # step 4).
    NUM_TREE_TABS = 10
    NUM_REPARAPHRASE_TABS = 120
    # At most this many sampled sentences are re-paraphrased in step 4 (the
    # value the original loop hard-coded).
    REPARAPHRASE_LIMIT = 13
    # Sentences rendered per HTML block in step 4.
    REPARAPHRASE_BATCH_SIZE = 10
    # Step 2 applies this many masking functions per sentence; step 3 relies on
    # the same grouping when it walks masked_sentences.
    MASKS_PER_SENTENCE = 3
    HIGHLIGHT_COLORS = ("red", "blue", "brown", "green")
    # (sampling technique name, temperature) pairs passed to sample_word.
    SAMPLING_TECHNIQUES = (
        ("inverse_transform", 1.0),
        ("exponential_minimum", 1.0),
        ("temperature", 1.0),
        ("greedy", 1.0),
    )
    # Attributes produced by steps 2-5. Re-running a step clears the outputs of
    # every later step, so stale results from a previous prompt are never mixed
    # with new ones.
    _STEP_STATE = (
        (2, ("masked_sentences", "masked_words", "masked_logits")),
        (3, ("sampled_sentences",)),
        (4, ("reparaphrased_sentences",)),
        (5, ("distortion_list", "detectability_list", "euclidean_dist_list")),
    )

    def __init__(self):
        # Step 1 state
        self.user_prompt = None
        self.paraphrased_sentences = None
        self.analyzed_paraphrased_sentences = None
        self.selected_sentences = None
        self.discarded_sentences = None
        self.user_prompt_tokenized = None
        self.selected_sentences_tokenized = None
        self.discarded_sentences_tokenized = None
        self.common_grams = None
        # Never assigned by the steps below; kept for compatibility.
        self.subsequences = None
        self.common_grams_position = None
        # Step 2 state
        self.masked_sentences = None
        self.masked_words = None
        self.masked_logits = None
        # Step 3 state
        self.sampled_sentences = None
        # Step 4 state
        self.reparaphrased_sentences = None
        # Step 5 state
        self.distortion_list = None
        self.detectability_list = None
        self.euclidean_dist_list = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _pad(items, size, fill=None):
        """Return ``items`` truncated or right-padded with ``fill`` to exactly ``size`` entries.

        Gradio requires a handler to return one value per output component.
        """
        items = list(items)[:size]
        return items + [fill] * (size - len(items))

    def _invalidate_after(self, step):
        """Reset to None every attribute produced by a step later than ``step``."""
        for produced_by, attrs in self._STEP_STATE:
            if produced_by > step:
                for attr in attrs:
                    setattr(self, attr, None)

    def _highlight_info(self):
        """Pair each common gram's word with a randomly chosen highlight colour.

        ``self.common_grams`` is iterated as 2-item pairs whose second item is
        the word (the first item is ignored).
        """
        return [(word, random.choice(self.HIGHLIGHT_COLORS)) for _, word in self.common_grams]

    # -------------------------------------------------------------------- steps

    def step1_paraphrasing(self, prompt, threshold=0.7):
        """Step 1: paraphrase ``prompt``, split by entailment, find common grams.

        Args:
            prompt: text entered by the user.
            threshold: forwarded to ``analyze_entailment``.

        Returns:
            A list of four values, one per wired output (also on the error
            path): highlighted prompt HTML, accepted-paraphrases HTML,
            discarded-paraphrases HTML, and the execution-time text.
        """
        start_time = time.perf_counter()
        # A new prompt invalidates everything computed by later steps.
        self._invalidate_after(1)
        self.user_prompt = prompt
        self.paraphrased_sentences = generate_paraphrase(prompt)
        if self.paraphrased_sentences is None:
            return [
                "Error in generating paraphrases",
                None,
                None,
                "Error: Could not complete step",
            ]

        (
            self.analyzed_paraphrased_sentences,
            self.selected_sentences,
            self.discarded_sentences,
        ) = analyze_entailment(self.user_prompt, self.paraphrased_sentences, threshold)

        self.user_prompt_tokenized = tokenize_sentence(self.user_prompt)
        self.selected_sentences_tokenized = tokenize_sentences(self.selected_sentences)
        self.discarded_sentences_tokenized = tokenize_sentences(self.discarded_sentences)

        # Common grams are computed over the prompt plus the accepted
        # paraphrases only; discarded ones are excluded.
        all_tokenized_sentences = [self.user_prompt_tokenized, *self.selected_sentences_tokenized]
        self.common_grams = find_non_melting_points(all_tokenized_sentences)

        highlighted_user_prompt = highlight_common_words(
            self.common_grams, [self.user_prompt], "Highlighted LCS in the User Prompt"
        )
        highlighted_accepted_sentences = highlight_common_words_dict(
            self.common_grams, self.selected_sentences, "Paraphrased Sentences"
        )
        highlighted_discarded_sentences = highlight_common_words_dict(
            self.common_grams, self.discarded_sentences, "Discarded Sentences"
        )

        execution_time = time.perf_counter() - start_time
        time_info = f"Step 1 completed in {execution_time:.2f} seconds"
        return [
            highlighted_user_prompt,
            highlighted_accepted_sentences,
            highlighted_discarded_sentences,
            time_info,
        ]

    def step2_masking(self):
        """Step 2: mask each paraphrase with three masking strategies and plot trees.

        Returns:
            ``NUM_TREE_TABS`` tree plots (padded with None / truncated) followed
            by the execution-time text.
        """
        start_time = time.perf_counter()
        if self.paraphrased_sentences is None:
            return [None] * self.NUM_TREE_TABS + ["Error: Please complete step 1 first"]
        self._invalidate_after(2)

        # Must contain MASKS_PER_SENTENCE functions; each returns
        # (masked sentence, logits, masked words).
        mask_funcs = (
            mask_non_stopword,
            mask_non_stopword_pseudorandom,
            lambda s: high_entropy_words(s, self.common_grams),
        )
        self.masked_sentences = []
        self.masked_words = []
        self.masked_logits = []
        # NOTE(review): this masks every paraphrase, including those discarded
        # by entailment in step 1; confirm that is intended.
        for sentence in self.paraphrased_sentences:
            for mask_func in mask_funcs:
                masked_sent, logits, words = mask_func(sentence)
                self.masked_sentences.append(masked_sent)
                self.masked_words.append(words)
                self.masked_logits.append(logits)

        # Colours are drawn after masking to keep the original call order.
        highlight_info = self._highlight_info()
        stride = self.MASKS_PER_SENTENCE
        trees = []
        for index, sentence in enumerate(self.paraphrased_sentences):
            start = index * stride
            masked_group = self.masked_sentences[start:start + stride]
            trees.append(generate_subplot1(sentence, masked_group, highlight_info, self.common_grams))

        execution_time = time.perf_counter() - start_time
        time_info = f"Step 2 completed in {execution_time:.2f} seconds"
        # Gradio needs exactly NUM_TREE_TABS plot values, whatever the number
        # of paraphrases.
        return self._pad(trees, self.NUM_TREE_TABS) + [time_info]

    def step3_sampling(self):
        """Step 3: fill every masked sentence with each sampling technique.

        Masked sentences are processed in groups of ``MASKS_PER_SENTENCE`` (one
        group per paraphrase from step 2), producing one tree per group.

        Returns:
            ``NUM_TREE_TABS`` tree plots (padded with None / truncated) followed
            by the execution-time text.
        """
        start_time = time.perf_counter()
        if self.masked_sentences is None:
            return [None] * self.NUM_TREE_TABS + ["Error: Please complete step 2 first"]
        self._invalidate_after(3)

        self.sampled_sentences = []
        trees = []
        highlight_info = self._highlight_info()
        stride = self.MASKS_PER_SENTENCE
        for start in range(0, len(self.masked_sentences), stride):
            current_masked = self.masked_sentences[start:start + stride]
            current_words = self.masked_words[start:start + stride]
            current_logits = self.masked_logits[start:start + stride]
            batch_samples = []
            for masked_sent, words, logits in zip(current_masked, current_words, current_logits):
                for technique, temp in self.SAMPLING_TECHNIQUES:
                    sampled = sample_word(
                        masked_sent,
                        words,
                        logits,
                        sampling_technique=technique,
                        temperature=temp,
                    )
                    batch_samples.append(sampled)
            self.sampled_sentences.extend(batch_samples)
            trees.append(
                generate_subplot2(current_masked, batch_samples, highlight_info, self.common_grams)
            )

        execution_time = time.perf_counter() - start_time
        time_info = f"Step 3 completed in {execution_time:.2f} seconds"
        return self._pad(trees, self.NUM_TREE_TABS) + [time_info]

    def step4_reparaphrase(self):
        """Step 4: re-paraphrase up to ``REPARAPHRASE_LIMIT`` sampled sentences.

        Returns:
            ``NUM_REPARAPHRASE_TABS`` HTML blocks (padded with None / truncated)
            followed by the execution-time text.
        """
        start_time = time.perf_counter()
        if self.sampled_sentences is None:
            message = "Error: Please complete step 3 first"
            return [message] * (self.NUM_REPARAPHRASE_TABS + 1)
        self._invalidate_after(4)

        # Slicing (instead of indexing a fixed range) means fewer sampled
        # sentences than the limit no longer raises IndexError.
        to_reparaphrase = self.sampled_sentences[:self.REPARAPHRASE_LIMIT]
        self.reparaphrased_sentences = [generate_paraphrase(s) for s in to_reparaphrase]

        size = self.REPARAPHRASE_BATCH_SIZE
        html_blocks = []
        for i in range(0, len(self.reparaphrased_sentences), size):
            batch = self.reparaphrased_sentences[i:i + size]
            # NOTE(review): a trailing partial batch is not rendered (with the
            # default limit, 3 of 13 results are never shown); confirm intended.
            if len(batch) == size:
                html_blocks.append(reparaphrased_sentences_html(batch))

        execution_time = time.perf_counter() - start_time
        time_info = f"Step 4 completed in {execution_time:.2f} seconds"
        return self._pad(html_blocks, self.NUM_REPARAPHRASE_TABS) + [time_info]

    def step5_metrics(self):
        """Step 5: score the re-paraphrased sentences and build a 3D plot.

        Returns:
            ``(plot, time_info)``; the plot is None on the error path.
        """
        start_time = time.perf_counter()
        if self.reparaphrased_sentences is None:
            # The first output is a gr.Plot, which expects a figure or None,
            # so the error message goes to the text output only.
            return None, "Error: Please complete step 4 first"

        distortion_calculator = SentenceDistortionCalculator(self.user_prompt, self.reparaphrased_sentences)
        distortion_calculator.calculate_all_metrics()
        distortion_calculator.normalize_metrics()
        distortion_calculator.calculate_combined_distortion()
        distortion = distortion_calculator.get_combined_distortions()
        self.distortion_list = list(distortion.values())

        detectability_calculator = SentenceDetectabilityCalculator(self.user_prompt, self.reparaphrased_sentences)
        detectability_calculator.calculate_all_metrics()
        detectability_calculator.normalize_metrics()
        detectability_calculator.calculate_combined_detectability()
        detectability = detectability_calculator.get_combined_detectabilities()
        self.detectability_list = list(detectability.values())

        euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(self.user_prompt, self.reparaphrased_sentences)
        euclidean_dist_calculator.calculate_all_metrics()
        euclidean_dist_calculator.normalize_metrics()
        # NOTE(review): this reads the *detectability* calculator, so the
        # Euclidean metrics computed just above are unused and the plot's
        # third axis duplicates detectability. It looks like a copy-paste slip:
        # switch to the Euclidean calculator's accessor once its name is
        # confirmed in euclidean_distance.py.
        euclidean_dist = detectability_calculator.get_combined_detectabilities()
        self.euclidean_dist_list = list(euclidean_dist.values())

        three_D_plot = gen_three_D_plot(
            self.detectability_list,
            self.distortion_list,
            self.euclidean_dist_list,
        )

        execution_time = time.perf_counter() - start_time
        time_info = f"Step 5 completed in {execution_time:.2f} seconds"
        return three_D_plot, time_info
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire each step button to the pipeline.

    A single ``WatermarkingPipeline`` instance backs every button. Each step
    gets a header, a trigger button, its output components and an
    "Execution Time" textbox.

    Returns:
        The assembled, not-yet-launched ``gr.Blocks`` app.
    """
    pipeline = WatermarkingPipeline()

    def tabbed_components(title_prefix, count, make_component):
        # One tab per component, titled "<prefix> 1" .. "<prefix> count";
        # the components come back in tab order.
        components = []
        with gr.Tabs():
            for number in range(1, count + 1):
                with gr.TabItem(f"{title_prefix} {number}"):
                    components.append(make_component())
        return components

    def execution_time_box():
        return gr.Textbox(label="Execution Time", interactive=False)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# **AIISC Watermarking Model**")
        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(label="Enter Your Prompt")

            gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            paraphrase_button = gr.Button("Generate Paraphrases")
            highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
            with gr.Tabs():
                with gr.TabItem("Accepted Paraphrased Sentences"):
                    highlighted_accepted_sentences = gr.HTML()
                with gr.TabItem("Discarded Paraphrased Sentences"):
                    highlighted_discarded_sentences = gr.HTML()
            step1_time = execution_time_box()

            gr.Markdown("## Step 2: Where to Mask?")
            masking_button = gr.Button("Apply Masking")
            gr.Markdown("### Masked Sentence Trees")
            tree1_tabs = tabbed_components("Masked Sentence", 10, gr.Plot)
            step2_time = execution_time_box()

            gr.Markdown("## Step 3: How to Mask?")
            sampling_button = gr.Button("Sample Words")
            gr.Markdown("### Sampled Sentence Trees")
            tree2_tabs = tabbed_components("Sampled Sentence", 10, gr.Plot)
            step3_time = execution_time_box()

            gr.Markdown("## Step 4: Re-paraphrasing")
            reparaphrase_button = gr.Button("Re-paraphrase")
            gr.Markdown("### Reparaphrased Sentences")
            reparaphrased_sentences_tabs = tabbed_components("Reparaphrased Batch", 120, gr.HTML)
            step4_time = execution_time_box()

            gr.Markdown("## Step 5: Finding Sweet Spot")
            metrics_button = gr.Button("Calculate Metrics")
            gr.Markdown("### 3D Visualization of Metrics")
            three_D_plot = gr.Plot()
            step5_time = execution_time_box()

        # Step 1 is the only step that takes UI input (the prompt).
        paraphrase_button.click(
            pipeline.step1_paraphrasing,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time,
            ],
        )

        # Steps 2-5 read only the pipeline's stored state.
        stateful_steps = (
            (masking_button, pipeline.step2_masking, tree1_tabs + [step2_time], {}),
            (sampling_button, pipeline.step3_sampling, tree2_tabs + [step3_time], {"show_progress": True}),
            (reparaphrase_button, pipeline.step4_reparaphrase, reparaphrased_sentences_tabs + [step4_time], {}),
            (metrics_button, pipeline.step5_metrics, [three_D_plot, step5_time], {}),
        )
        for button, handler, step_outputs, extra_options in stateful_steps:
            button.click(handler, inputs=None, outputs=step_outputs, **extra_options)

    return demo
if __name__ == "__main__":
    # Running the file directly builds the Blocks app and serves it;
    # share=True asks Gradio for a public share link.
    demo = create_gradio_interface()
    demo.launch(share=True)