Spaces:
Runtime error
Runtime error
# Import necessary libraries | |
import nltk | |
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
from scipy.special import rel_entr | |
from collections import Counter | |
from transformers import GPT2LMHeadModel, GPT2TokenizerFast | |
# Make sure NLTK's 'punkt' resource is available locally (the download is
# skipped by NLTK when it is already present); runs once at import time.
# NOTE(review): the code visible in this file only calls nltk.edit_distance,
# not an NLTK tokenizer, so this download may be unnecessary — confirm
# against the rest of the project before removing it.
nltk.download(info_or_id='punkt', quiet=True)
class SentenceDistortionCalculator:
    """
    Calculate and analyze distortion metrics between an original sentence and modified sentences.

    Typical usage::

        calc = SentenceDistortionCalculator(original, [variant_a, variant_b])
        calc.calculate_all_metrics()
        calc.normalize_metrics()
        calc.calculate_combined_distortion()
        calc.plot_metrics()

    For every modified sentence four metrics are stored in ``self.metrics`` under the keys
    ``"Sentence_1"``, ``"Sentence_2"``, ...:

    * ``levenshtein`` -- character-level edit distance (``nltk.edit_distance``).
    * ``word_level_changes`` -- fraction of word positions that differ.
    * ``kl_divergences`` -- KL(original || modified) over lower-cased word frequencies.
    * ``perplexities`` -- GPT-2 perplexity of the modified sentence.
    """

    def __init__(self, original_sentence, modified_sentences, model_name="gpt2"):
        """
        Args:
            original_sentence: Reference sentence (str).
            modified_sentences: Iterable of sentences (str) compared against the reference.
            model_name: Model id/path passed to ``from_pretrained`` for the GPT-2 tokenizer
                and language model used for perplexity. Defaults to ``"gpt2"``.
        """
        self.original_sentence = original_sentence
        self.modified_sentences = modified_sentences
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name).eval()  # Set model to evaluation mode
        # Metric name -> {sentence key -> value}; raw until normalize_metrics() is called.
        self.metrics = {
            'levenshtein': {},
            'word_level_changes': {},
            'kl_divergences': {},
            'perplexities': {},
        }
        # Sentence key -> root mean square of the normalized metrics.
        self.combined_distortions = {}

    def calculate_all_metrics(self):
        """Calculate all distortion metrics for each modified sentence (stored as raw values)."""
        for idx, modified_sentence in enumerate(self.modified_sentences):
            key = f"Sentence_{idx + 1}"
            self.metrics['levenshtein'][key] = self._calculate_levenshtein_distance(modified_sentence)
            self.metrics['word_level_changes'][key] = self._calculate_word_level_change(modified_sentence)
            self.metrics['kl_divergences'][key] = self._calculate_kl_divergence(modified_sentence)
            self.metrics['perplexities'][key] = self._calculate_perplexity(modified_sentence)

    def normalize_metrics(self):
        """Min-max normalize every metric in place so its values lie between 0 and 1."""
        for metric in self.metrics:
            self.metrics[metric] = self._normalize_dict(self.metrics[metric])

    def calculate_combined_distortion(self):
        """
        Calculate the combined distortion as the root mean square of the normalized metrics.

        Normalization is applied here (via ``get_normalized_metrics``) so the result does not
        depend on whether ``normalize_metrics()`` was called first. Min-max normalization is
        idempotent, so already-normalized metrics are unaffected.
        """
        normalized = self.get_normalized_metrics()
        n_metrics = len(normalized)
        for key in normalized['levenshtein']:
            mean_square = sum(normalized[metric][key] ** 2 for metric in normalized) / n_metrics
            self.combined_distortions[key] = np.sqrt(mean_square)

    def plot_metrics(self):
        """
        Plot each metric and, when available, the combined distortion in separate figures.

        Intended to be called after ``normalize_metrics()``; otherwise the y-axis label
        ("Normalized Value (0-1)") does not describe the raw values drawn. The combined
        distortion figure is drawn only if ``calculate_combined_distortion()`` has populated
        ``self.combined_distortions``.
        """
        keys = list(self.metrics['levenshtein'].keys())
        indices = np.arange(len(keys))
        for metric_name, values in self.metrics.items():
            # Look values up by key so every series is aligned with `keys`.
            self._plot_series(
                indices,
                [values[key] for key in keys],
                metric_name,
                f'Normalized {metric_name.replace("_", " ").title()}',
            )
        if self.combined_distortions:
            self._plot_series(
                indices,
                [self.combined_distortions.get(key, np.nan) for key in keys],
                'combined_distortion',
                'Combined Distortion (RMS of Normalized Metrics)',
            )

    @staticmethod
    def _plot_series(indices, values, label, title):
        """Draw one series in its own figure with the shared axis labels and styling."""
        plt.figure(figsize=(12, 6))
        plt.plot(indices, values, marker='o', label=label)
        plt.xlabel('Sentence Index')
        plt.ylabel('Normalized Value (0-1)')
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

    # Private methods for metric calculations

    def _calculate_levenshtein_distance(self, modified_sentence):
        """Return the character-level Levenshtein distance between the original and modified sentence."""
        return nltk.edit_distance(self.original_sentence, modified_sentence)

    def _calculate_word_level_change(self, modified_sentence):
        """
        Return the proportion of word-level changes between the original and modified sentence.

        Words are compared position by position (whitespace split); every extra or missing
        trailing word also counts as a change. Returns 0 when both sentences are empty.
        """
        original_words = self.original_sentence.split()
        modified_words = modified_sentence.split()
        total_words = max(len(original_words), len(modified_words))
        changed_words = sum(o != m for o, m in zip(original_words, modified_words)) + abs(len(original_words) - len(modified_words))
        return changed_words / total_words if total_words > 0 else 0

    def _calculate_kl_divergence(self, modified_sentence, epsilon=1e-10):
        """
        Return KL(original || modified) between the lower-cased word distributions.

        Args:
            modified_sentence: Sentence to compare against ``self.original_sentence``.
            epsilon: Probability floor applied to the modified distribution. Without it, any
                word present in the original but absent from the modified sentence makes the
                divergence infinite, which later turns into NaN during normalization.

        Returns:
            A finite, non-negative float; 0.0 when the original sentence is empty.
        """
        original_counts = Counter(self.original_sentence.lower().split())
        modified_counts = Counter(modified_sentence.lower().split())
        # Sorted for a deterministic summation order across runs.
        vocabulary = sorted(set(original_counts) | set(modified_counts))
        original_probs = np.array([original_counts[word] for word in vocabulary], dtype=float)
        modified_probs = np.array([modified_counts[word] for word in vocabulary], dtype=float)
        original_probs /= original_probs.sum() + 1e-10  # Avoid division by zero
        modified_probs /= modified_probs.sum() + 1e-10
        # Only entries that are exactly zero are affected, so previously finite results are unchanged.
        modified_probs = np.maximum(modified_probs, epsilon)
        return float(np.sum(rel_entr(original_probs, modified_probs)))

    def _calculate_perplexity(self, sentence):
        """
        Return the GPT-2 perplexity of ``sentence``.

        Long inputs are split into non-overlapping windows of ``n_positions`` tokens. Each
        window's mean loss is weighted by the number of tokens it actually predicts (window
        length - 1), so a short trailing window does not skew the average. A window with a
        single token predicts nothing and is skipped instead of contributing a NaN loss.

        Returns:
            The perplexity as a float, or NaN when the sentence has fewer than two tokens
            (perplexity is undefined there).
        """
        input_ids = self.tokenizer(sentence, return_tensors='pt').input_ids
        stride = self.model.config.n_positions
        total_nll = 0.0
        total_predicted = 0
        for start in range(0, input_ids.size(1), stride):
            chunk = input_ids[:, start:start + stride]
            # With labels == inputs the model scores every token except the first of the window.
            n_predicted = chunk.size(1) - 1
            if n_predicted < 1:
                continue
            with torch.no_grad():
                loss = self.model(chunk, labels=chunk).loss
            total_nll += loss.item() * n_predicted
            total_predicted += n_predicted
        if total_predicted == 0:
            return float('nan')
        return float(np.exp(total_nll / total_predicted))

    def _normalize_dict(self, metric_dict):
        """
        Min-max normalize the values of ``metric_dict`` to the range [0, 1].

        Returns a new dict with the same keys. If all values are equal, every value maps to
        0.0. An empty dict returns an empty dict.
        """
        if not metric_dict:
            return {}
        values = np.array(list(metric_dict.values()), dtype=float)
        min_val, max_val = values.min(), values.max()
        if max_val > min_val:
            normalized_values = (values - min_val) / (max_val - min_val)
        else:
            normalized_values = np.zeros_like(values)
        return dict(zip(metric_dict.keys(), normalized_values))

    def get_normalized_metrics(self):
        """Return a normalized copy of every metric without modifying ``self.metrics``."""
        return {metric: self._normalize_dict(values) for metric, values in self.metrics.items()}

    def get_combined_distortions(self):
        """Return the dictionary of combined distortion values (sentence key -> RMS)."""
        return self.combined_distortions