# aiisc-watermarking-model / distortion.py
# Uploaded by BheemaShankerNeyigapula ("Upload folder using huggingface_hub").
# Commit ea6afa4 (verified).
# Import necessary libraries
import nltk
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.special import rel_entr
from collections import Counter
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Download NLTK data if not already present
nltk.download('punkt', quiet=True)
class SentenceDistortionCalculator:
    """
    Calculate and analyze distortion metrics between an original sentence and modified sentences.

    Four metrics are computed per modified sentence and stored in
    ``self.metrics`` under the keys ``'levenshtein'``, ``'word_level_changes'``,
    ``'kl_divergences'`` and ``'perplexities'``. Each inner dictionary is keyed
    by ``"Sentence_<n>"`` where ``n`` is the 1-based position of the sentence in
    ``modified_sentences``.

    Typical usage::

        calc = SentenceDistortionCalculator(original, candidates)
        calc.calculate_all_metrics()
        calc.normalize_metrics()
        calc.calculate_combined_distortion()
        scores = calc.get_combined_distortions()
    """

    # Additive smoothing applied to word counts before computing the KL
    # divergence. Without it, any word of the original that is absent from the
    # modified sentence makes the divergence infinite.
    KL_SMOOTHING = 1e-10

    def __init__(self, original_sentence, modified_sentences, tokenizer=None, model=None):
        """
        Store the sentences and prepare the language model used for perplexity.

        Args:
            original_sentence: Reference sentence (string).
            modified_sentences: Sequence of sentences to compare against the
                reference.
            tokenizer: Optional pre-loaded tokenizer. When omitted, the
                pretrained ``"gpt2"`` fast tokenizer is loaded.
            model: Optional pre-loaded causal language model. When omitted,
                the pretrained ``"gpt2"`` model is loaded and switched to
                evaluation mode. A caller-supplied model is used as given.
        """
        self.original_sentence = original_sentence
        self.modified_sentences = modified_sentences

        # Loading GPT-2 is expensive; accepting pre-built objects lets several
        # calculators share one tokenizer/model pair.
        if tokenizer is None:
            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        if model is None:
            model = GPT2LMHeadModel.from_pretrained("gpt2").eval()  # evaluation mode
        self.tokenizer = tokenizer
        self.model = model

        # Raw metric dictionaries
        self.metrics = {
            'levenshtein': {},
            'word_level_changes': {},
            'kl_divergences': {},
            'perplexities': {},
        }
        # Combined distortion dictionary
        self.combined_distortions = {}

    def calculate_all_metrics(self):
        """Calculate all distortion metrics for each modified sentence."""
        for idx, modified_sentence in enumerate(self.modified_sentences):
            key = f"Sentence_{idx + 1}"
            self.metrics['levenshtein'][key] = self._calculate_levenshtein_distance(modified_sentence)
            self.metrics['word_level_changes'][key] = self._calculate_word_level_change(modified_sentence)
            self.metrics['kl_divergences'][key] = self._calculate_kl_divergence(modified_sentence)
            self.metrics['perplexities'][key] = self._calculate_perplexity(modified_sentence)

    def normalize_metrics(self):
        """Min-max normalize every stored metric in place so values lie between 0 and 1."""
        for metric in self.metrics:
            self.metrics[metric] = self._normalize_dict(self.metrics[metric])

    def calculate_combined_distortion(self):
        """
        Calculate the combined distortion as the root mean square of the normalized metrics.

        The computation always runs on a normalized copy of ``self.metrics``,
        so the result is the same whether or not ``normalize_metrics()`` was
        called beforehand (min-max normalization is idempotent). Results are
        written to ``self.combined_distortions``.
        """
        normalized = self.get_normalized_metrics()
        metric_count = len(normalized)
        for key in normalized['levenshtein']:
            squared_sum = sum(normalized[metric][key] ** 2 for metric in normalized)
            self.combined_distortions[key] = np.sqrt(squared_sum / metric_count)

    def plot_metrics(self):
        """
        Plot each normalized metric and the combined distortion in separate graphs.

        The combined-distortion graph is drawn only when
        ``calculate_combined_distortion()`` has produced values. Every figure
        is displayed with ``plt.show()``.
        """
        # Plot normalized values so the axis label/title are always truthful,
        # even if normalize_metrics() has not been called.
        normalized = self.get_normalized_metrics()
        indices = np.arange(len(normalized['levenshtein']))

        for metric_name, values in normalized.items():
            self._plot_series(
                indices,
                list(values.values()),
                label=metric_name,
                ylabel='Normalized Value (0-1)',
                title=f'Normalized {metric_name.replace("_", " ").title()}',
            )

        if self.combined_distortions:
            combined = list(self.combined_distortions.values())
            self._plot_series(
                np.arange(len(combined)),
                combined,
                label='combined_distortion',
                ylabel='Combined Distortion (RMS)',
                title='Combined Distortion',
            )

    @staticmethod
    def _plot_series(indices, values, label, ylabel, title):
        """Draw one line chart of ``values`` against ``indices`` and show it."""
        plt.figure(figsize=(12, 6))
        plt.plot(indices, values, marker='o', label=label)
        plt.xlabel('Sentence Index')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

    # Private methods for metric calculations
    def _calculate_levenshtein_distance(self, modified_sentence):
        """Calculate the character-level Levenshtein distance between the original and modified sentence."""
        return nltk.edit_distance(self.original_sentence, modified_sentence)

    def _calculate_word_level_change(self, modified_sentence):
        """
        Calculate the proportion of word-level changes between the original and modified sentence.

        Words are compared position by position after whitespace splitting;
        any surplus words in the longer sentence count as changes. The result
        is divided by the longer sentence's word count (0 when both are empty).
        """
        original_words = self.original_sentence.split()
        modified_words = modified_sentence.split()
        total_words = max(len(original_words), len(modified_words))
        if total_words == 0:
            return 0

        mismatched = sum(o != m for o, m in zip(original_words, modified_words))
        length_gap = abs(len(original_words) - len(modified_words))
        return (mismatched + length_gap) / total_words

    def _calculate_kl_divergence(self, modified_sentence):
        """
        Calculate the KL divergence between the word distributions of the original and modified sentence.

        Word counts are taken over lower-cased, whitespace-split tokens of the
        joint vocabulary. ``KL_SMOOTHING`` is added to every count before
        normalizing so that a word missing from one sentence yields a large
        but finite contribution instead of ``inf``.

        Returns:
            The divergence KL(original || modified) as a float; 0.0 when both
            sentences contain no words or have identical word counts.
        """
        original_counts = Counter(self.original_sentence.lower().split())
        modified_counts = Counter(modified_sentence.lower().split())

        # Sorted for a deterministic summation order.
        vocabulary = sorted(set(original_counts) | set(modified_counts))
        if not vocabulary:
            return 0.0

        original_probs = np.array([original_counts[word] for word in vocabulary], dtype=float)
        modified_probs = np.array([modified_counts[word] for word in vocabulary], dtype=float)

        original_probs += self.KL_SMOOTHING
        modified_probs += self.KL_SMOOTHING
        original_probs /= original_probs.sum()
        modified_probs /= modified_probs.sum()

        return float(np.sum(rel_entr(original_probs, modified_probs)))

    def _calculate_perplexity(self, sentence):
        """
        Calculate the perplexity of a sentence using the configured language model.

        The token sequence is processed in non-overlapping windows of
        ``model.config.n_positions`` tokens. Each window's loss is weighted by
        its number of prediction targets (tokens minus one) so that a short
        trailing window does not skew the average; windows with no target are
        skipped.

        Returns:
            ``exp`` of the token-weighted mean loss, or ``nan`` when the
            sentence yields no prediction target (fewer than two tokens).
        """
        encodings = self.tokenizer(sentence, return_tensors='pt')
        all_ids = encodings.input_ids
        window = self.model.config.n_positions

        total_nll = 0.0
        total_targets = 0
        for start in range(0, all_ids.size(1), window):
            input_ids = all_ids[:, start:start + window]
            # With labels == inputs the model predicts token t+1 from token t,
            # so a window of k tokens provides k - 1 targets.
            target_count = input_ids.size(1) - 1
            if target_count < 1:
                continue
            with torch.no_grad():
                outputs = self.model(input_ids, labels=input_ids)
            total_nll += outputs.loss.item() * target_count
            total_targets += target_count

        if total_targets == 0:
            return float('nan')
        return float(np.exp(total_nll / total_targets))

    def _normalize_dict(self, metric_dict):
        """
        Min-max normalize the values in a dictionary to be between 0 and 1.

        Returns an empty dictionary for empty input, and all zeros when every
        value is identical. Keys and their order are preserved.
        """
        if not metric_dict:
            return {}

        values = np.array(list(metric_dict.values()), dtype=float)
        min_val, max_val = values.min(), values.max()
        if max_val > min_val:
            normalized_values = (values - min_val) / (max_val - min_val)
        else:
            normalized_values = np.zeros_like(values)
        return dict(zip(metric_dict.keys(), normalized_values))

    def get_normalized_metrics(self):
        """Get all metrics, min-max normalized, as a new dictionary (``self.metrics`` is not modified)."""
        return {metric: self._normalize_dict(values) for metric, values in self.metrics.items()}

    def get_combined_distortions(self):
        """Get the dictionary of combined distortion values."""
        return self.combined_distortions