|
import os |
|
import torch |
|
import json |
|
import time |
|
import random |
|
import streamlit as st |
|
import firebase_admin |
|
import logging |
|
from firebase_admin import credentials, firestore |
|
from dotenv import load_dotenv |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from transformers import pipeline |
|
import plotly.graph_objects as go |
|
|
|
# Root logger: INFO and above, timestamped, tagged with logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Pull variables from a .env file into os.environ (python-dotenv) before any
# credentials or model paths are read from the environment below.
load_dotenv()
|
|
|
def load_credentials(path='public_creds.json'):
    """Build the Firebase service-account credentials dictionary.

    Non-secret fields are read from the JSON file at ``path``; the secret
    ``private_key_id`` and ``private_key`` fields are taken from environment
    variables of the same names and merged in.

    Args:
        path: JSON file holding the public part of the credentials.

    Returns:
        The merged credentials dict, or ``None`` (after logging an error) when
        the ``private_key`` variable is missing or the file cannot be
        read/parsed.
    """
    private_key = os.environ.get('private_key')
    if private_key is None:
        # Previously this surfaced as an opaque "'NoneType' object has no
        # attribute 'replace'" error; report the actual cause instead.
        logging.error(
            "Error while loading credentials: environment variable "
            "'private_key' is not set"
        )
        return None

    try:
        with open(path, encoding='utf-8') as f:
            credentials_dict = json.load(f)
        credentials_dict.update({
            'private_key_id': os.environ.get('private_key_id'),
            # The key is presumably stored on a single line with literal
            # backslash-n sequences (e.g. in .env); restore real newlines.
            'private_key': private_key.replace(r'\n', '\n'),
        })
    except Exception as e:
        logging.error(f'Error while loading credentials: {e}')
        return None

    return credentials_dict
|
|
|
def connect_to_db(credentials_dict):
    """Return a Firestore client, initializing the default Firebase app once.

    Args:
        credentials_dict: service-account credentials as produced by
            ``load_credentials``; only needed if the default app has not
            been initialized yet in this process.

    Returns:
        A ``firestore.client()`` instance, or ``None`` (after logging an
        error) if initialization or client creation fails.
    """
    try:
        try:
            # Public API: raises ValueError when the default app does not exist.
            firebase_admin.get_app()
        except ValueError:
            if credentials_dict is None:
                logging.error('Error while connecting to db: no credentials available')
                return None
            firebase_admin.initialize_app(credentials.Certificate(credentials_dict))
        db = firestore.client()
    except Exception as e:
        logging.error(f'Error while connecting to db: {e}')
        return None

    logging.info('Established connection to db!')
    return db
|
|
|
def get_statements_from_db(db):
    """Fetch the list of statements to rate from Firestore.

    Reads the ``statements`` field of the ``ItemDesirability/Items`` document.

    Args:
        db: Firestore client (may be ``None``; any failure is logged).

    Returns:
        The stored ``statements`` value, or ``None`` if the lookup fails for
        any reason (no client, missing document, missing field, ...).
    """
    try:
        items_doc = db.collection('ItemDesirability').document('Items')
        fields = items_doc.get().to_dict()
        statements = fields['statements']
        logging.info(f'Retrieved {len(statements)} statements from db!')
    except Exception as err:
        logging.error(f'Error while retrieving items from db: {err}')
        return None
    return statements
|
|
|
def update_db(db, payload):
    """Append one response record to ``ItemDesirability/Responses``.

    Args:
        db: Firestore client.
        payload: dict describing a single rating submission.

    Returns:
        ``True`` if the write succeeded, ``False`` (after logging) otherwise.

    Note:
        ``ArrayUnion`` only adds elements not already present, so a payload
        identical to an existing entry is not duplicated (same as before).
    """
    try:
        doc_ref = db.collection('ItemDesirability').document('Responses')
        # One merge-write both creates the document when missing and unions
        # the payload into an existing ``Data`` array. The previous
        # read-then-write (check ``exists``, then ``set`` or ``update``) let two
        # concurrent first writers both ``set``, dropping one response.
        doc_ref.set({'Data': firestore.ArrayUnion([payload])}, merge=True)
    except Exception as e:
        logging.error(f'Error while sending payload to db: {e}')
        return False

    logging.info('Sent payload to db!')
    return True
|
|
|
def pick_random(input_list):
    """Return a uniformly random element of ``input_list``.

    Any failure (e.g. ``None`` or an empty sequence) is logged and turned
    into a ``None`` result instead of raising.
    """
    try:
        choice = random.choice(input_list)
    except Exception as err:
        logging.error(f'Error while picking random statement: {err}')
        choice = None
    return choice
|
|
|
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardize ``y`` as ``(y - mean) / sd``.

    The default ``mean``/``sd`` are fixed constants; presumably the location
    and scale of the desirability model's raw outputs on reference data
    (TODO confirm source).
    """
    deviation = y - mean
    return deviation / sd
|
|
|
def score_text(input_text):
    """Score a statement for sentiment and social desirability.

    Uses the sentiment pipeline, tokenizer and model stored in
    ``st.session_state`` (loaded by the page before scoring).

    Args:
        input_text: the item text to evaluate.

    Returns:
        ``(sentiment, desirability)`` where ``sentiment`` is the classifier's
        ``positive`` score minus its ``negative`` score, and ``desirability``
        is the desirability model's single output passed through ``z_score``.
    """
    classifier_output = st.session_state.classifier(input_text)
    # Assumes the pipeline returns a one-element list wrapping the per-label
    # score dicts for a single string input (top_k pipeline) — TODO confirm
    # against the installed transformers version.
    label_scores = {x['label']: x['score'] for x in classifier_output[0]}
    sentiment = label_scores['positive'] - label_scores['negative']

    # truncation=True: inputs longer than the model's max length would
    # otherwise make the forward pass fail.
    inputs = st.session_state.tokenizer(
        text=input_text,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        score = st.session_state.model(**inputs).logits.squeeze().tolist()
    desirability = z_score(score)

    return sentiment, desirability
|
|
|
def indicator_plot(value, title, value_range, domain):
    """Build a plotly gauge indicator for a score centred on zero.

    Args:
        value: number shown by the gauge bar, the threshold marker and the
            delta (relative to 0).
        title: trace title.
        value_range: ``[low, high]`` axis range; the gauge background is
            split into two steps at 0.
        domain: plotly domain dict placing the trace inside the figure.

    Returns:
        A ``go.Indicator`` trace.
    """
    low = value_range[0]
    high = value_range[1]
    accent = "#4361ee"
    neutral_grey = '#efefef'

    delta_settings = {
        'reference': 0,
        'decreasing': {'color': "#ec4899"},
        'increasing': {'color': "#36def1"},
    }

    gauge_settings = {
        'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
        'bar': {'color': accent},
        'bgcolor': "white",
        'borderwidth': 2,
        'bordercolor': "#efefef",
        'steps': [
            {'range': [low, 0], 'color': neutral_grey},
            {'range': [0, high], 'color': neutral_grey},
        ],
        'threshold': {
            'line': {'color': accent, 'width': 8},
            'thickness': 0.75,
            'value': value,
        },
    }

    return go.Indicator(
        mode="gauge+delta",
        value=value,
        domain=domain,
        title=title,
        delta=delta_settings,
        gauge=gauge_settings,
    )
|
|
|
def show_scores(sentiment, desirability, input_text):
    """Render the sentiment and desirability gauges plus a short legend.

    Args:
        sentiment: signed sentiment score, plotted on [-1, 1] in the left half.
        desirability: z-scored desirability, plotted on [-4, 4] in the right half.
        input_text: the evaluated statement, used as the figure title.
    """
    sentiment_gauge = indicator_plot(
        value=sentiment,
        title='Item Sentiment',
        value_range=[-1, 1],
        domain={'x': [0, .45], 'y': [0, 1]},
    )

    desirability_gauge = indicator_plot(
        value=desirability,
        title='Item Desirability',
        value_range=[-4, 4],
        domain={'x': [.55, 1], 'y': [0, 1]},
    )

    fig = go.Figure()
    fig.add_trace(sentiment_gauge)
    fig.add_trace(desirability_gauge)

    fig.update_layout(
        title=dict(text=f'"{input_text}"', font=dict(size=36), yref='paper'),
        paper_bgcolor="white",
        font={'color': "black", 'family': "Arial"},
    )

    st.plotly_chart(fig, theme=None, use_container_width=True)

    # Sentiment is the signed difference positive - negative (see score_text),
    # not an absolute difference. The blank line keeps the two legend entries
    # as separate markdown paragraphs.
    st.markdown(
        "Item sentiment: Difference between positive and negative sentiment "
        "scores (from -1 to 1).\n\n"
        'Item desirability: z-transformed values, 0 indicates "neutral".'
    )
|
|
|
def update_statement_placeholder(placeholder):
    """Show the statement currently up for rating inside ``placeholder``.

    Reads ``st.session_state.current_statement``. If there is no statement
    (e.g. statements could not be loaded), the placeholder is cleared
    instead of failing on ``None``.

    Args:
        placeholder: a Streamlit ``st.empty()`` container.
    """
    statement = st.session_state.current_statement
    if not statement:
        logging.warning('No statement available to display.')
        placeholder.empty()
        return

    # str.capitalize() would also lowercase the rest of the statement
    # (e.g. "I" -> "i"); only upper-case the first character.
    display_text = statement[:1].upper() + statement[1:]

    # NOTE(review): the statement is rendered with unsafe_allow_html=True; it
    # comes from the project's own db, so it is not escaped here.
    placeholder.markdown(
        body=f"""
        Is it socially desirable or undesirable to endorse the following statement?
        ### <center>"{display_text}"</center>
        """,
        unsafe_allow_html=True,
    )
|
|
|
def _ensure_db_connection(max_attempts=3):
    """Return the session's Firestore client, connecting with retries if needed."""
    if 'db' not in st.session_state:
        st.session_state.db = None
    if st.session_state.db is not None:
        return st.session_state.db

    # Credentials are only needed (and only read) when not yet connected.
    credentials_dict = load_credentials()
    if credentials_dict is None:
        # Every attempt would fail the same way; skip the retry delays.
        return None

    for attempt in range(1, max_attempts + 1):
        st.session_state.db = connect_to_db(credentials_dict)
        if st.session_state.db is not None:
            break
        if attempt < max_attempts:
            logging.info('Retrying to connect to db...')
            time.sleep(1)
    return st.session_state.db


def _ensure_statements(db, max_attempts=3):
    """Load the rating statements into session state once and pick one to show."""
    if 'statements' not in st.session_state:
        st.session_state.statements = None
    if 'current_statement' not in st.session_state:
        st.session_state.current_statement = None
    if db is None:
        # Without a connection, retrying would only log errors and sleep.
        return

    for attempt in range(1, max_attempts + 1):
        if st.session_state.statements is not None:
            break
        st.session_state.statements = get_statements_from_db(db)
        if st.session_state.statements is not None:
            st.session_state.current_statement = pick_random(st.session_state.statements)
        elif attempt < max_attempts:
            logging.info('Retrying to retrieve statements from db...')
            time.sleep(1)


def _load_models():
    """Load the desirability tokenizer/model and the sentiment pipeline into session state once."""
    if os.environ.get('item-desirability'):
        model_path = 'magnolia-psychometrics/item-desirability'
    else:
        model_path = os.getenv('model_path')

    # The 'item-desirability' variable doubles as the Hugging Face access
    # token; True tells transformers to use the locally stored token.
    auth_token = os.environ.get('item-desirability') or True

    if 'tokenizer' not in st.session_state:
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_path,
            use_fast=True,
            use_auth_token=auth_token,
        )

    if 'model' not in st.session_state:
        st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_path,
            # One output logit, used as the raw desirability score in score_text.
            num_labels=1,
            ignore_mismatched_sizes=True,
            use_auth_token=auth_token,
        )

    if 'classifier' not in st.session_state:
        st.session_state.sentiment_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
        st.session_state.classifier = pipeline(
            task='sentiment-analysis',
            model=st.session_state.sentiment_model,
            tokenizer=st.session_state.sentiment_model,
            use_fast=False,
            top_k=3,
        )


def show():
    """Render the interactive 'Try it yourself!' page.

    Connects to Firestore and loads the statements to rate (once per
    session), optionally collects a rating of a random statement, loads the
    models, and on button press scores the user's item text, stores the
    rating (when data collection is on) and plots both scores.
    """
    db = _ensure_db_connection()
    _ensure_statements(db)

    # In show_scores the sentiment gauge occupies the left half of the figure
    # and the desirability gauge the right half.
    st.markdown("""
    ## Try it yourself!
    Use the text field below to enter a statement that might be part of a psychological questionnaire (e.g., "I love a good fight.").
    The right dial indicates how socially desirable it might be to endorse this item.
    The left dial indicates sentiment (i.e., valence) as estimated by regular sentiment analysis (using the `cardiffnlp/twitter-xlm-roberta-base-sentiment` model).
    """)

    # Data collection needs both a db connection and a statement to rate.
    if db is not None and st.session_state.current_statement is not None:
        collect_data = st.checkbox(
            label='I want to support and help improve this research.',
            value=True,
        )
    else:
        collect_data = False

    if collect_data:
        statement_placeholder = st.empty()
        update_statement_placeholder(statement_placeholder)

        rating_options = ['[Please select]', 'Very undesirable', 'Undesirable', 'Neutral', 'Desirable', 'Very desirable']
        selected_rating = st.selectbox(
            label='Rate the statement above according to whether it is socially desirable or undesirable.',
            options=rating_options,
            index=0,
        )

        suitability_options = ["No, I'm just playing around", 'Yes, my input can help improve this research']
        research_suitability = st.radio(
            label='Is your input suitable for research purposes?',
            options=suitability_options,
            horizontal=True,
        )

    with st.spinner('Loading the model might take a couple of seconds...'):
        st.markdown("### Estimate item desirability")
        _load_models()

    input_text = st.text_input(
        label='Item text/statement:',
        value='I love a good fight.',
        placeholder='Enter item text',
    )

    if not st.button(label='Evaluate Item Text', type="primary"):
        return

    if not collect_data:
        sentiment, desirability = score_text(input_text)
        show_scores(sentiment, desirability, input_text)
        return

    if selected_rating == rating_options[0]:
        st.error('Please rate the statement presented above!')
        return

    sentiment, desirability = score_text(input_text)

    payload = {
        # user_id is not set in this module's visible code (presumably by
        # another page); .get avoids crashing the submit when it is absent.
        'user_id': st.session_state.get('user_id'),
        'statement': st.session_state.current_statement,
        # Position in rating_options: 1 = 'Very undesirable' ... 5 = 'Very desirable'.
        'rating': rating_options.index(selected_rating),
        # Position in suitability_options: 0 = playing around, 1 = suitable.
        'suitability': suitability_options.index(research_suitability),
        'input_text': input_text,
        'sentiment': sentiment,
        'desirability': desirability,
    }

    if update_db(db=db, payload=payload):
        # Move on to a fresh statement for the next rating.
        st.session_state.current_statement = pick_random(st.session_state.statements)
        update_statement_placeholder(statement_placeholder)
    else:
        st.warning('Your rating could not be saved. Please try again later.')

    show_scores(sentiment, desirability, input_text)