magnolia-pm commited on
Commit
164cb45
1 Parent(s): 9adf82d

init commit

Browse files
Files changed (2) hide show
  1. app.py +67 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import plotly.graph_objects as go
5
+
6
+ input_text = st.text_input(
7
+ label='Estimate item desirability:',
8
+ value='I love a good fight.',
9
+ placeholder='Enter item'
10
+
11
+ )
12
+
13
+ #model_path = '/nlp/nlp/models/finetuned/twitter-xlm-roberta-base-regressive-desirability-ft-4'
14
+ model_path = 'magnolia-psychometrics/item-desirability'
15
+
16
+ tokenizer = AutoTokenizer.from_pretrained(
17
+ pretrained_model_name_or_path=model_path,
18
+ use_fast=True
19
+ )
20
+
21
+ model = AutoModelForSequenceClassification.from_pretrained(
22
+ pretrained_model_name_or_path=model_path,
23
+ num_labels=1,
24
+ ignore_mismatched_sizes=True,
25
+ )
26
+
27
+ def z_score(y, mean=.04853076, sd=.9409466):
28
+ return (y - mean) / sd
29
+
30
+ if input_text:
31
+
32
+ inputs = tokenizer(input_text, padding=True, return_tensors='pt')
33
+
34
+ with torch.no_grad():
35
+ score = model(**inputs).logits.squeeze().tolist()
36
+ z = z_score(score)
37
+
38
+ fig = go.Figure(go.Indicator(
39
+ mode = "gauge+delta",
40
+ value = z,
41
+ domain = {'x': [0, 1], 'y': [0, 1]},
42
+ title = f"Item Desirability <br><sup>\"{input_text}\"</sup>",
43
+ delta = {
44
+ 'reference': 0,
45
+ 'decreasing': {'color': "#ec4899"},
46
+ 'increasing': {'color': "#36def1"}
47
+ },
48
+ gauge = {
49
+ 'axis': {'range': [-4, 4], 'tickwidth': 1, 'tickcolor': "black"},
50
+ 'bar': {'color': "#4361ee"},
51
+ 'bgcolor': "white",
52
+ 'borderwidth': 2,
53
+ 'bordercolor': "#efefef",
54
+ 'steps': [
55
+ {'range': [-4, 0], 'color': '#efefef'},
56
+ {'range': [0, 4], 'color': '#efefef'}],
57
+ 'threshold': {
58
+ 'line': {'color': "#4361ee", 'width': 8},
59
+ 'thickness': 0.75,
60
+ 'value': z}
61
+ }))
62
+
63
+ fig.update_layout(
64
+ paper_bgcolor = "white",
65
+ font = {'color': "black", 'family': "Arial"})
66
+
67
+ st.plotly_chart(fig, theme=None, use_container_width=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.13.1+cu116
2
+ transformers==4.25.1
3
+ plotly==5.11.0