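# NOTE: the six items below are residue from the Hugging Face Hub file viewer,
# kept here as a comment so the module parses; they are not part of the program.
#   - avatar alt text: "jgyasu's picture"
#   - commit message: "Upload folder using huggingface_hub"
#   - commit: 38c3a0a (verified)
#   - viewer links: "raw", "history blame"
#   - file size: 13.5 kB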
import nltk
nltk.download('stopwords')
import random
import gradio as gr
import time
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
# from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot
from twokenize import tokenize_sentences, tokenize_sentence
from non_melting_points import find_non_melting_points
class WatermarkingPipeline:
    """Stateful, five-step watermarking demo pipeline.

    Each ``stepN_*`` method stores its results on the instance and the next
    step reads them, so the steps must be run in order. Every step returns
    the values for its UI output components, ending with a status string
    (timing on success, an error message otherwise). The number of values a
    step returns is fixed, so it always matches the number of output
    components wired to it.
    """

    # Number of plot slots the UI allocates for step 2 and for step 3.
    TREE_SLOTS = 10
    # Number of HTML slots the UI allocates for step 4.
    REPARAPHRASE_SLOTS = 120
    # Upper bound on how many sampled sentences step 4 re-paraphrases.
    MAX_REPARAPHRASE_INPUTS = 13
    # Step 4 only renders complete batches of this many entries.
    REPARAPHRASE_BATCH_SIZE = 10
    # Each sentence is masked by three strategies (see _mask_variants).
    MASKS_PER_SENTENCE = 3
    # Colours randomly assigned to highlighted common words in the trees.
    HIGHLIGHT_COLORS = ("red", "blue", "brown", "green")

    def __init__(self):
        # Step 1 state.
        self.user_prompt = None
        self.paraphrased_sentences = None
        self.analyzed_paraphrased_sentences = None
        self.selected_sentences = None
        self.discarded_sentences = None
        self.user_prompt_tokenized = None
        self.selected_sentences_tokenized = None
        self.discarded_sentences_tokenized = None
        self.common_grams = None
        self.subsequences = None
        self.common_grams_position = None
        # Step 2 state: three parallel lists, one entry per (sentence, mask strategy).
        self.masked_sentences = None
        self.masked_words = None
        self.masked_logits = None
        # Step 3 / step 4 state.
        self.sampled_sentences = None
        self.reparaphrased_sentences = None
        # Step 5 state.
        self.distortion_list = None
        self.detectability_list = None
        self.euclidean_dist_list = None

    @staticmethod
    def _fit_to_slots(values, slot_count):
        """Return *values* truncated or padded with ``None`` to exactly *slot_count* items."""
        fitted = list(values[:slot_count])
        fitted.extend([None] * (slot_count - len(fitted)))
        return fitted

    def _highlight_info(self):
        """Pair the word of every common gram with a randomly chosen colour."""
        return [(word, random.choice(self.HIGHLIGHT_COLORS)) for _, word in self.common_grams]

    def _mask_variants(self, sentence):
        """Yield the result of each of the three masking strategies for *sentence*."""
        yield mask_non_stopword(sentence)
        yield mask_non_stopword_pseudorandom(sentence)
        yield high_entropy_words(sentence, self.common_grams)

    def step1_paraphrasing(self, prompt, threshold=0.7):
        """Paraphrase *prompt*, split paraphrases by entailment and highlight common grams.

        Args:
            prompt: The user's input text.
            threshold: Passed through to ``analyze_entailment``.

        Returns:
            A list of four values: highlighted prompt, highlighted accepted
            sentences, highlighted discarded sentences, and a status string.
            On paraphrasing failure the same four positions carry an error
            message, two empty strings and an error status.
        """
        start_time = time.time()
        self.user_prompt = prompt
        self.paraphrased_sentences = generate_paraphrase(prompt)
        if self.paraphrased_sentences is None:
            # Four values, matching the four outputs of the success path.
            return ["Error in generating paraphrases", "", "", "Error: Could not complete step"]

        (
            self.analyzed_paraphrased_sentences,
            self.selected_sentences,
            self.discarded_sentences,
        ) = analyze_entailment(self.user_prompt, self.paraphrased_sentences, threshold)

        self.user_prompt_tokenized = tokenize_sentence(self.user_prompt)
        self.selected_sentences_tokenized = tokenize_sentences(self.selected_sentences)
        self.discarded_sentences_tokenized = tokenize_sentences(self.discarded_sentences)

        # Common grams are computed over the prompt plus the accepted paraphrases only.
        all_tokenized_sentences = [self.user_prompt_tokenized, *self.selected_sentences_tokenized]
        self.common_grams = find_non_melting_points(all_tokenized_sentences)

        highlighted_user_prompt = highlight_common_words(
            self.common_grams, [self.user_prompt], "Highlighted LCS in the User Prompt"
        )
        highlighted_accepted_sentences = highlight_common_words_dict(
            self.common_grams, self.selected_sentences, "Paraphrased Sentences"
        )
        highlighted_discarded_sentences = highlight_common_words_dict(
            self.common_grams, self.discarded_sentences, "Discarded Sentences"
        )

        execution_time = time.time() - start_time
        time_info = f"Step 1 completed in {execution_time:.2f} seconds"
        return [
            highlighted_user_prompt,
            highlighted_accepted_sentences,
            highlighted_discarded_sentences,
            time_info,
        ]

    def step2_masking(self):
        """Mask every paraphrased sentence three ways and build one tree per sentence.

        Returns:
            ``TREE_SLOTS`` tree values (padded with ``None`` or truncated)
            followed by a status string.
        """
        start_time = time.time()
        if self.paraphrased_sentences is None:
            return [None] * self.TREE_SLOTS + ["Error: Please complete step 1 first"]

        self.masked_sentences = []
        self.masked_words = []
        self.masked_logits = []
        for sentence in self.paraphrased_sentences:
            for masked_sent, logits, words in self._mask_variants(sentence):
                self.masked_sentences.append(masked_sent)
                self.masked_words.append(words)
                self.masked_logits.append(logits)

        highlight_info = self._highlight_info()
        trees = []
        for index, sentence in enumerate(self.paraphrased_sentences):
            # The masked variants of sentence `index` are stored contiguously.
            first = index * self.MASKS_PER_SENTENCE
            variants = self.masked_sentences[first:first + self.MASKS_PER_SENTENCE]
            trees.append(generate_subplot1(sentence, variants, highlight_info, self.common_grams))

        execution_time = time.time() - start_time
        time_info = f"Step 2 completed in {execution_time:.2f} seconds"
        # Always return exactly TREE_SLOTS trees, as step 3 does.
        return self._fit_to_slots(trees, self.TREE_SLOTS) + [time_info]

    def step3_sampling(self):
        """Fill each masked sentence using four sampling techniques and build the trees.

        Returns:
            ``TREE_SLOTS`` tree values (padded with ``None`` or truncated)
            followed by a status string.
        """
        start_time = time.time()
        if self.masked_sentences is None:
            return [None] * self.TREE_SLOTS + ["Error: Please complete step 2 first"]

        self.sampled_sentences = []
        trees = []
        highlight_info = self._highlight_info()
        sampling_techniques = [
            ('inverse_transform', 1.0),
            ('exponential_minimum', 1.0),
            ('temperature', 1.0),
            ('greedy', 1.0),
        ]

        # Walk the parallel masked lists one original sentence (three variants) at a time.
        for first in range(0, len(self.masked_sentences), self.MASKS_PER_SENTENCE):
            last = first + self.MASKS_PER_SENTENCE
            current_masked = self.masked_sentences[first:last]
            current_words = self.masked_words[first:last]
            current_logits = self.masked_logits[first:last]

            batch_samples = []
            for masked_sent, words, logits in zip(current_masked, current_words, current_logits):
                for technique, temp in sampling_techniques:
                    batch_samples.append(
                        sample_word(
                            masked_sent,
                            words,
                            logits,
                            sampling_technique=technique,
                            temperature=temp,
                        )
                    )
            self.sampled_sentences.extend(batch_samples)

            trees.append(
                generate_subplot2(current_masked, batch_samples, highlight_info, self.common_grams)
            )

        execution_time = time.time() - start_time
        time_info = f"Step 3 completed in {execution_time:.2f} seconds"
        return self._fit_to_slots(trees, self.TREE_SLOTS) + [time_info]

    def step4_reparaphrase(self):
        """Re-paraphrase the first sampled sentences and render them in batches.

        At most ``MAX_REPARAPHRASE_INPUTS`` sampled sentences are paraphrased;
        fewer are used when fewer are available. Only complete batches of
        ``REPARAPHRASE_BATCH_SIZE`` results are rendered.

        Returns:
            ``REPARAPHRASE_SLOTS`` values (rendered batches padded with
            ``None``) followed by a status string.
        """
        start_time = time.time()
        if self.sampled_sentences is None:
            return ["Error: Please complete step 3 first"] * (self.REPARAPHRASE_SLOTS + 1)

        # Slicing instead of indexing avoids an IndexError on short sample lists.
        self.reparaphrased_sentences = [
            generate_paraphrase(sentence)
            for sentence in self.sampled_sentences[:self.MAX_REPARAPHRASE_INPUTS]
        ]

        html_blocks = []
        batch_size = self.REPARAPHRASE_BATCH_SIZE
        for first in range(0, len(self.reparaphrased_sentences), batch_size):
            batch = self.reparaphrased_sentences[first:first + batch_size]
            # A trailing partial batch is not rendered.
            if len(batch) == batch_size:
                html_blocks.append(reparaphrased_sentences_html(batch))

        execution_time = time.time() - start_time
        time_info = f"Step 4 completed in {execution_time:.2f} seconds"
        return self._fit_to_slots(html_blocks, self.REPARAPHRASE_SLOTS) + [time_info]

    def step5_metrics(self):
        """Compute distortion, detectability and distance scores and plot them in 3D.

        Returns:
            A ``(plot, status)`` pair. ``plot`` is ``None`` when step 4 has
            not been run, so the plot output is cleared rather than being
            handed a string.
        """
        start_time = time.time()
        if self.reparaphrased_sentences is None:
            return None, "Error: Please complete step 4 first"

        distortion_calculator = SentenceDistortionCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        distortion_calculator.calculate_all_metrics()
        distortion_calculator.normalize_metrics()
        distortion_calculator.calculate_combined_distortion()
        distortion = distortion_calculator.get_combined_distortions()
        self.distortion_list = list(distortion.values())

        detectability_calculator = SentenceDetectabilityCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        detectability_calculator.calculate_all_metrics()
        detectability_calculator.normalize_metrics()
        detectability_calculator.calculate_combined_detectability()
        detectability = detectability_calculator.get_combined_detectabilities()
        self.detectability_list = list(detectability.values())

        euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        euclidean_dist_calculator.calculate_all_metrics()
        euclidean_dist_calculator.normalize_metrics()
        # NOTE(review): the values below are read from detectability_calculator,
        # not from euclidean_dist_calculator, so the third plot axis presumably
        # repeats the detectability scores. This looks like a copy-paste slip,
        # but the accessor SentenceEuclideanDistanceCalculator provides is not
        # visible from this file -- confirm and switch to it.
        euclidean_dist = detectability_calculator.get_combined_detectabilities()
        self.euclidean_dist_list = list(euclidean_dist.values())

        three_D_plot = gen_three_D_plot(
            self.detectability_list,
            self.distortion_list,
            self.euclidean_dist_list,
        )

        execution_time = time.time() - start_time
        time_info = f"Step 5 completed in {execution_time:.2f} seconds"
        return three_D_plot, time_info
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire each button to a pipeline step.

    A single ``WatermarkingPipeline`` instance is created here and its bound
    methods are used as the click handlers, so all five steps share the same
    state object.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    pipeline = WatermarkingPipeline()

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# **AIISC Watermarking Model**")

        with gr.Column():
            # --- Prompt input ---
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(label="Enter Your Prompt")

            # --- Step 1: paraphrases, accepted vs. discarded ---
            gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            paraphrase_button = gr.Button("Generate Paraphrases")
            highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
            with gr.Tabs():
                with gr.TabItem("Accepted Paraphrased Sentences"):
                    highlighted_accepted_sentences = gr.HTML()
                with gr.TabItem("Discarded Paraphrased Sentences"):
                    highlighted_discarded_sentences = gr.HTML()
            step1_time = gr.Textbox(label="Execution Time", interactive=False)

            # --- Step 2: ten plot tabs for the masked-sentence trees ---
            gr.Markdown("## Step 2: Where to Mask?")
            masking_button = gr.Button("Apply Masking")
            gr.Markdown("### Masked Sentence Trees")
            tree1_tabs = []
            with gr.Tabs():
                for i in range(10):
                    with gr.TabItem(f"Masked Sentence {i+1}"):
                        tree1_tabs.append(gr.Plot())
            step2_time = gr.Textbox(label="Execution Time", interactive=False)

            # --- Step 3: ten plot tabs for the sampled-sentence trees ---
            gr.Markdown("## Step 3: How to Mask?")
            sampling_button = gr.Button("Sample Words")
            gr.Markdown("### Sampled Sentence Trees")
            tree2_tabs = []
            with gr.Tabs():
                for i in range(10):
                    with gr.TabItem(f"Sampled Sentence {i+1}"):
                        tree2_tabs.append(gr.Plot())
            step3_time = gr.Textbox(label="Execution Time", interactive=False)

            # --- Step 4: 120 HTML tabs for the re-paraphrased batches ---
            gr.Markdown("## Step 4: Re-paraphrasing")
            reparaphrase_button = gr.Button("Re-paraphrase")
            gr.Markdown("### Reparaphrased Sentences")
            reparaphrased_sentences_tabs = []
            with gr.Tabs():
                for i in range(120):
                    with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                        reparaphrased_sentences_tabs.append(gr.HTML())
            step4_time = gr.Textbox(label="Execution Time", interactive=False)

            # --- Step 5: single 3D plot of the metrics ---
            gr.Markdown("## Step 5: Finding Sweet Spot")
            metrics_button = gr.Button("Calculate Metrics")
            gr.Markdown("### 3D Visualization of Metrics")
            three_D_plot = gr.Plot()
            step5_time = gr.Textbox(label="Execution Time", interactive=False)

        # Output lists: each handler's return values are mapped onto these in order.
        step1_outputs = [
            highlighted_user_prompt,
            highlighted_accepted_sentences,
            highlighted_discarded_sentences,
            step1_time,
        ]
        step2_outputs = tree1_tabs + [step2_time]
        step3_outputs = tree2_tabs + [step3_time]
        step4_outputs = reparaphrased_sentences_tabs + [step4_time]
        step5_outputs = [three_D_plot, step5_time]

        paraphrase_button.click(
            pipeline.step1_paraphrasing,
            inputs=user_input,
            outputs=step1_outputs,
        )
        masking_button.click(
            pipeline.step2_masking,
            inputs=None,
            outputs=step2_outputs,
        )
        sampling_button.click(
            pipeline.step3_sampling,
            inputs=None,
            outputs=step3_outputs,
            show_progress=True,
        )
        reparaphrase_button.click(
            pipeline.step4_reparaphrase,
            inputs=None,
            outputs=step4_outputs,
        )
        metrics_button.click(
            pipeline.step5_metrics,
            inputs=None,
            outputs=step5_outputs,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then start the server with a public
    # share link enabled (share=True).
    demo = create_gradio_interface()
    demo.launch(share=True)