# Spaces:
# Runtime error
# Runtime error
import functools

from transformers import pipeline
DEFAULT_NLI_MODEL = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"


# Building the pipeline loads the full model, so do it once per model and
# reuse it across calls. Only the most recently used model is kept resident
# to bound memory use.
@functools.lru_cache(maxsize=1)
def _load_entailment_pipeline(model_name):
    """Return a text-classification pipeline for *model_name*, built once and cached."""
    return pipeline("text-classification", model=model_name)


def analyze_entailment(
    original_sentence,
    paraphrased_sentences,
    threshold,
    *,
    model_name=DEFAULT_NLI_MODEL,
):
    """Score paraphrases by how strongly the original sentence entails them.

    Each paraphrase is sent to an NLI text-classification model as a
    (premise, hypothesis) sentence pair, with *original_sentence* as the
    premise. A paraphrase's score is the score the pipeline reports for the
    ``entailment`` label.

    Args:
        original_sentence: Premise text.
        paraphrased_sentences: Iterable of hypothesis texts to score. It is
            consumed exactly once, so generators are accepted.
        threshold: Minimum entailment score (inclusive) for a paraphrase to
            be selected.
        model_name: Hugging Face model id of the NLI classifier to use.

    Returns:
        A tuple ``(all_sentences, selected_sentences, discarded_sentences)``
        of dicts mapping paraphrase -> entailment score.
        ``selected_sentences`` holds scores ``>= threshold`` and
        ``discarded_sentences`` holds the rest. Duplicate paraphrases collapse
        into a single key.
    """
    all_sentences = {}
    selected_sentences = {}
    discarded_sentences = {}

    # Materialize once. The input is iterated twice below (to build the model
    # inputs and to pair results back up), and a one-shot iterator would
    # otherwise be exhausted before the second pass, silently yielding nothing.
    paraphrases = list(paraphrased_sentences)
    if not paraphrases:
        # Nothing to score, so skip loading the model entirely.
        return all_sentences, selected_sentences, discarded_sentences

    entailment_pipe = _load_entailment_pipeline(model_name)

    # Send each (premise, hypothesis) as a real sentence pair so the tokenizer
    # inserts the model's own separator tokens. RoBERTa tokenizers do not treat
    # a literal "[SEP]" string as a special token.
    inputs = [
        {"text": original_sentence, "text_pair": paraphrase}
        for paraphrase in paraphrases
    ]
    # top_k=None returns a score for every label of every input. It is the
    # documented replacement for the deprecated return_all_scores=True.
    entailment_results = entailment_pipe(inputs, top_k=None)

    for paraphrase, label_scores in zip(paraphrases, entailment_results):
        # Label names come from the model config, so compare case-insensitively.
        # NOTE(review): if the model exposes no "entailment" label at all,
        # every score falls back to 0 and every paraphrase is discarded.
        entailment_score = next(
            (
                item["score"]
                for item in label_scores
                if item["label"].lower() == "entailment"
            ),
            0,
        )
        all_sentences[paraphrase] = entailment_score
        if entailment_score >= threshold:
            selected_sentences[paraphrase] = entailment_score
        else:
            discarded_sentences[paraphrase] = entailment_score

    return all_sentences, selected_sentences, discarded_sentences
# Example usage
# print(analyze_entailment("I love you", ["I adore you", "I hate you"], 0.7))