# aiisc-watermarking-model / entailment.py
# Uploaded by jgyasu using huggingface_hub (commit 63b3783, verified; 1.29 kB).
from functools import lru_cache

import numpy as np
from transformers import pipeline
# Default NLI checkpoint. The code below expects its classifier to expose a
# label named "entailment" (matched case-insensitively).
ENTAILMENT_MODEL = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"


# Building a pipeline loads the full checkpoint, so keep the most recently used
# one alive instead of reloading it on every analyze_entailment() call.
# maxsize=1 bounds memory to a single loaded model.
@lru_cache(maxsize=1)
def _get_entailment_pipeline(model_name):
    """Return a (cached) text-classification pipeline for ``model_name``."""
    return pipeline("text-classification", model=model_name)


def _entailment_score(entailment_pipe, premise, hypothesis):
    """Return the score the pipeline assigns to the "entailment" label.

    The premise and hypothesis are passed as a real sentence pair
    (``text`` / ``text_pair``) so the tokenizer inserts the model's own
    separator tokens; a literal "[SEP]" string is not a special token for
    RoBERTa tokenizers and would just be tokenized as ordinary text.

    Raises:
        ValueError: if no label named "entailment" is present in the output.
    """
    results = entailment_pipe({"text": premise, "text_pair": hypothesis}, top_k=None)
    # Some transformers versions wrap a single input's scores in an extra
    # list; unwrap so we always iterate over {"label", "score"} dicts.
    if results and isinstance(results[0], list):
        results = results[0]
    for result in results:
        if result["label"].lower() == "entailment":
            return result["score"]
    labels = [result["label"] for result in results]
    raise ValueError(f"model output has no 'entailment' label; got labels {labels}")


def analyze_entailment(
    original_sentence,
    paraphrased_sentences,
    threshold,
    *,
    model=ENTAILMENT_MODEL,
):
    """Score each paraphrase by how strongly ``original_sentence`` entails it.

    Args:
        original_sentence: The premise sentence.
        paraphrased_sentences: Iterable of candidate paraphrases (hypotheses).
        threshold: Minimum entailment score (inclusive) for a paraphrase to
            be selected.
        model: Hugging Face model id of the NLI classifier to use.

    Returns:
        A tuple ``(all_sentences, selected_sentences, discarded_sentences)``;
        each is a dict mapping paraphrase -> entailment score. Duplicate
        paraphrases collapse into a single key (the last score wins).

    Raises:
        ValueError: if the model's output has no "entailment" label.
    """
    entailment_pipe = _get_entailment_pipeline(model)

    all_sentences = {}
    selected_sentences = {}
    discarded_sentences = {}
    for paraphrased_sentence in paraphrased_sentences:
        entailment_score = _entailment_score(
            entailment_pipe, original_sentence, paraphrased_sentence
        )
        all_sentences[paraphrased_sentence] = entailment_score
        if entailment_score >= threshold:
            selected_sentences[paraphrased_sentence] = entailment_score
        else:
            discarded_sentences[paraphrased_sentence] = entailment_score
    return all_sentences, selected_sentences, discarded_sentences
# print(analyze_entailment("I love you", [""], 0.7))