Spaces:
Runtime error
Runtime error
File size: 2,297 Bytes
917d2f9 514343b 4fd42f1 917d2f9 4fd42f1 8f36cc4 843aeb0 917d2f9 44264ed 917d2f9 418bd7c 8f36cc4 44264ed 3367c7d e838b9b 8f36cc4 1716434 8f36cc4 418bd7c 8f36cc4 20efea7 8f36cc4 20efea7 8f36cc4 ea052a5 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 917d2f9 8f36cc4 5daf8df 8f36cc4 5daf8df 8f36cc4 5daf8df 8f36cc4 5daf8df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import streamlit as st
# Library for Entailment
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load model.
# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer/model must be cached or the large roberta-large-mnli checkpoint is
# re-instantiated on each rerun. `st.cache_resource` (newer Streamlit releases)
# keeps a single shared instance; on versions without it we fall back to an
# uncached call so the app still starts.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_entailment_model(model_name="roberta-large-mnli"):
    """Return ``(tokenizer, model)`` for the given Hugging Face checkpoint.

    Args:
        model_name: Hub identifier of a sequence-classification checkpoint.
            Defaults to the MNLI-finetuned RoBERTa-large used by this app.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )


tokenizer, text_classification_model = _load_entailment_model()
### Streamlit interface ###
st.title("Text Classification")
st.subheader("Entailment, neutral or contradiction?")

# Group the inputs in a form so their values are submitted together when the
# "Compare Sentences" button is pressed.
submission_form = st.form("submission_form", clear_on_submit=False)
with submission_form:
    # Decision threshold for the entailment score, on a 0-1 scale.
    threshold = st.slider(
        "Threshold",
        min_value=0.0,
        max_value=1.0,
        step=0.1,
        value=0.7,
    )
    sentence_1 = st.text_input("Sentence 1 input")
    sentence_2 = st.text_input("Sentence 2 input")
    submit_button_compare = st.form_submit_button("Compare Sentences")
# If submit_button_compare clicked: run the NLI model on the sentence pair.
if submit_button_compare:
    print("Comparing sentences...")
    if not sentence_1.strip() or not sentence_2.strip():
        # Classifying an empty pair produces a meaningless score.
        st.warning("Please enter both sentences before comparing.")
    else:
        ### Text classification - entailment, neutral or contradiction ###
        # Encode the sentences as a (premise, hypothesis) pair; the tokenizer
        # inserts the model's own separator tokens rather than relying on a
        # hand-written "</s></s>" string.
        inputs = tokenizer(
            sentence_1, sentence_2, truncation=True, return_tensors="pt"
        )
        # Inference only: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            logits = text_classification_model(**inputs).logits
        # One input pair -> take row 0 of the softmax over the label axis.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        id2label = text_classification_model.config.id2label
        st.subheader("Text classification for both sentences:")
        # Report every class in label-id order, both to the console and the page.
        for label_id, label in sorted(id2label.items()):
            percent = round(probabilities[label_id].item() * 100, 2)
            print(label, ":", percent, "%")
            st.write(label, ":", percent, "%")

        # Find the entailment class by name instead of trusting a fixed index;
        # fall back to index 2, which is what this code previously assumed.
        entailment_id = next(
            (i for i, name in id2label.items() if name.lower() == "entailment"),
            2,
        )
        entailment_prob = probabilities[entailment_id].item()
        # The slider threshold lies in [0, 1], so compare it with the
        # probability itself, not with the 0-100 percentage shown above.
        if entailment_prob >= threshold:
            st.subheader("The statements are very similar!")
        else:
            st.subheader("The statements are not close enough")
|