# Captured from a repository file page: author digitalWestie,
# commit 60929e9 ("don sort"), file size 1.63 kB.
import streamlit as st
import transformers
import matplotlib.pyplot as plt
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_pipe():
    """Build the go-emotions text-classification pipeline (cached by Streamlit).

    Loads the ``joeddav/distilbert-base-uncased-go-emotions-student`` model
    and its tokenizer. It then wraps them in a ``text-classification``
    pipeline configured to:

    * return a score for every label (``return_all_scores=True``), and
    * truncate over-long inputs (``truncation=True``).

    The ``st.cache`` decorator makes Streamlit build the pipeline once and
    reuse it across script reruns. ``show_spinner=False`` is used because the
    caller wraps this call in its own ``st.spinner``.

    NOTE(review): recent Streamlit releases deprecate ``st.cache`` in favour
    of ``st.cache_resource``. Recent transformers releases deprecate
    ``return_all_scores`` in favour of ``top_k=None``, which returns a flat,
    score-sorted list, so the caller's ``pipe(text)[0]`` would need adjusting
    before migrating. Confirm against the pinned versions.
    """
    checkpoint = "joeddav/distilbert-base-uncased-go-emotions-student"
    classifier = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tok = transformers.AutoTokenizer.from_pretrained(checkpoint)
    return transformers.pipeline(
        'text-classification',
        model=classifier,
        tokenizer=tok,
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return the prediction dicts ordered from highest to lowest ``'score'``.

    A new list is returned and the input is left untouched. Entries with
    equal scores keep their original relative order, because Python's sort
    is stable.
    """
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked
st.set_page_config(page_title="Affect scores")
st.title("Affect scores")
st.write("Type text into the text box and then press 'Generate' to compute affect scores.")

default_text = (
    "It's about a startup taking on a big yet creative challenge, "
    "with ups and downs along the way."
)
text = st.text_area('Enter text here:', value=default_text)
# Pressing the button just triggers a script rerun. Results are drawn for any
# non-blank text (including the default text on first load), so the value of
# `submit` is not part of the condition below.
submit = st.button('Generate')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # The pipeline is built with return_all_scores=True, so the result holds
    # one list of {'label', 'score'} dicts per input. Take the list for our
    # single input string.
    prediction = pipe(text)[0]

    # Scores are plotted in the pipeline's label order (not sorted). The axis
    # limit must therefore come from the largest score, not from the first
    # entry, or longer bars are clipped. The 0.1 leaves headroom past the
    # tallest bar.
    x_max = max(p['score'] for p in prediction) + 0.1

    fig, ax = plt.subplots()
    ax.barh([p['label'] for p in prediction], [p['score'] for p in prediction])
    ax.set_xlim(0, x_max)
    st.header('Result:')
    st.pyplot(fig)
    # Streamlit re-executes this script on every interaction. Close the figure
    # so pyplot does not keep accumulating open figures across reruns.
    plt.close(fig)

    scores = {p['label']: p['score'] for p in prediction}
    st.header('Raw values:')
    st.json(scores)