Spaces:
Runtime error
Runtime error
import streamlit as st | |
import transformers | |
import matplotlib.pyplot as plt | |
# Streamlit re-executes this whole script on every widget interaction, so the
# pipeline is cached to avoid reloading the model and tokenizer each time.
# `st.cache_resource` is the current API; older Streamlit releases only have
# the legacy `st.cache`, which needs allow_output_mutation for unhashable
# return values such as a pipeline.
_cache_pipeline = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_pipeline
def get_pipe():
    """Build the text-classification pipeline used to score affect.

    Loads the ``joeddav/distilbert-base-uncased-go-emotions-student`` model
    and its tokenizer and wraps them in a ``transformers`` text-classification
    pipeline that returns a score for every label and truncates long inputs.

    Returns:
        The configured ``transformers`` pipeline.
    """
    model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): `return_all_scores` is deprecated in recent transformers
    # releases in favour of `top_k=None`, which also changes the nesting of
    # the result for a single string. Callers index `pipe(text)[0]`, so
    # confirm the output shape before switching.
    return transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return *predictions* ordered from highest to lowest ``'score'``.

    Args:
        predictions: Iterable of mappings, each with a numeric ``'score'``
            key (such as the ``{'label': ..., 'score': ...}`` dicts this app
            reads from the pipeline output).

    Returns:
        A new list; the input is left unmodified. Entries with equal scores
        keep their original relative order because ``sorted`` is stable,
        even with ``reverse=True``.
    """
    return sorted(predictions, key=lambda entry: entry['score'], reverse=True)
# Page configuration must be the first Streamlit command the script runs.
st.set_page_config(page_title="Affect scores")
st.title("Affect scores")
# The instructions name the same button label that st.button renders below.
st.write("Type text into the text box and then press 'Generate' to generate affect scores.")

default_text = "It's about a startup taking on a big yet creative challenge, with ups and downs along the way."
text = st.text_area('Enter text here:', value=default_text)
submit = st.button('Generate')

with st.spinner("Loading model..."):
    pipe = get_pipe()
# Scores are rendered for any non-empty input, including the default text on
# first load. The Generate button only triggers a rerun of the script, so it
# does not need to gate this branch.
if text.strip():
    # NOTE(review): `[0]` assumes the pipeline returns one list of
    # {'label', 'score'} dicts per input string. Re-check if the pipeline
    # configuration in get_pipe changes.
    prediction = pipe(text)[0]
    ranked = sort_predictions(prediction)
    # barh puts scores on the x-axis; pad past the top score so the longest
    # bar does not touch the edge of the plot.
    max_xlim = ranked[0]['score'] + 0.1

    fig, ax = plt.subplots()
    ax.barh([p['label'] for p in ranked], [p['score'] for p in ranked])
    # barh draws the first entry at the bottom; flip so the top score is on top.
    ax.invert_yaxis()
    ax.set_xlim(0, max_xlim)
    st.header('Result:')
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and this script
    # reruns on each interaction, so release the figure once it is rendered.
    plt.close(fig)

    scores = {p['label']: p['score'] for p in prediction}
    st.header('Raw values:')
    st.json(scores)