"""Streamlit app that shows per-label affect scores for user-entered text."""
import streamlit as st
import transformers
import matplotlib.pyplot as plt

# Must be the first Streamlit command executed by the script.
st.set_page_config(page_title="Affect scores")

MODEL_NAME = "joeddav/distilbert-base-uncased-go-emotions-student"
# Extra room to the right of the longest bar so it does not touch the axis edge.
XLIM_PADDING = 0.1


# NOTE(review): st.cache is deprecated in newer Streamlit releases; when
# upgrading, st.cache_resource is the documented replacement for models.
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_pipe():
    """Load the model and tokenizer and build a text-classification pipeline.

    Cached by Streamlit so the model is built once instead of on every rerun.
    ``allow_output_mutation=True`` disables Streamlit's mutation check on the
    returned pipeline object.

    Returns:
        A ``transformers`` text-classification pipeline configured to return a
        score for every label (``return_all_scores=True``) and to truncate
        over-long input (``truncation=True``).
    """
    model = transformers.AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    # NOTE(review): newer transformers versions deprecate return_all_scores in
    # favour of top_k=None, which also changes the output nesting; the
    # `pipe(text)[0]` indexing below relies on the current shape.
    return transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        truncation=True,
    )


def sort_predictions(predictions):
    """Return *predictions* ordered by descending ``'score'``.

    Args:
        predictions: Iterable of dicts, each containing a ``'score'`` key.

    Returns:
        A new list; the input is not modified.
    """
    return sorted(predictions, key=lambda x: x['score'], reverse=True)


st.title("Affect scores")
st.write("Type text into the text box and then press 'Generate' to compute affect scores.")

default_text = (
    "It's about a startup taking on a big yet creative challenge, "
    "with ups and downs along the way."
)
text = st.text_area('Enter text here:', value=default_text)
# Clicking the button triggers a rerun with the current text. Results are shown
# for any non-empty text (including the default on first load), so the click
# itself does not need to be checked below.
st.button('Generate')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # For a single input string the pipeline returns a one-element outer list
    # whose item is the list of {'label', 'score'} dicts for that text.
    prediction = pipe(text)[0]
    labels = [p['label'] for p in prediction]
    scores = [p['score'] for p in prediction]

    fig, ax = plt.subplots()
    ax.barh(labels, scores)
    # Predictions are not sorted, so use the true maximum score rather than the
    # first entry; otherwise the highest-scoring bar could be clipped.
    ax.set_xlim(0, max(scores) + XLIM_PADDING)

    st.header('Result:')
    st.pyplot(fig)
    # pyplot keeps figures alive until they are closed, and this script reruns
    # on every interaction, so release the figure once it has been rendered.
    plt.close(fig)

    st.header('Raw values:')
    st.json(dict(zip(labels, scores)))