wendru18
commited on
Commit
·
89abf96
1
Parent(s):
afecc92
downgraded streamlit as per huggingface requirements
Browse files- app.py +12 -11
- requirements.txt +1 -1
app.py
CHANGED
@@ -9,7 +9,7 @@ traits = ["Openness to Experience", "Conscientiousness", "Extraversion", "Agreea
|
|
9 |
|
10 |
short_traits = ["o", "c", "e", "a", "n"]
|
11 |
|
12 |
-
@st.
|
13 |
def load_explainer():
|
14 |
|
15 |
print("Loading model...")
|
@@ -30,7 +30,7 @@ def load_explainer():
|
|
30 |
|
31 |
return explainer
|
32 |
|
33 |
-
@st.
|
34 |
def explain(text, _explainer):
|
35 |
|
36 |
attributions = _explainer(text)
|
@@ -45,16 +45,17 @@ def attributions_to_html(attributions):
|
|
45 |
html = ""
|
46 |
for word, attr in attributions:
|
47 |
|
48 |
-
|
49 |
continue
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
-
|
58 |
|
59 |
return html
|
60 |
|
@@ -73,7 +74,7 @@ def main():
|
|
73 |
|
74 |
explainer = load_explainer()
|
75 |
|
76 |
-
text = st.text_area(label="Input text here...", value="I
|
77 |
|
78 |
show_prediction = st.button("Predict Traits")
|
79 |
|
@@ -84,7 +85,7 @@ def main():
|
|
84 |
explanation = explain(text, explainer)
|
85 |
st.session_state.explanation = explanation
|
86 |
|
87 |
-
if st.session_state.
|
88 |
|
89 |
st.write("## Predicted Traits")
|
90 |
|
|
|
9 |
|
10 |
short_traits = ["o", "c", "e", "a", "n"]
|
11 |
|
12 |
+
@st.experimental_memo
|
13 |
def load_explainer():
|
14 |
|
15 |
print("Loading model...")
|
|
|
30 |
|
31 |
return explainer
|
32 |
|
33 |
+
@st.experimental_memo
|
34 |
def explain(text, _explainer):
|
35 |
|
36 |
attributions = _explainer(text)
|
|
|
45 |
html = ""
|
46 |
for word, attr in attributions:
|
47 |
|
48 |
+
if word in ["<s>", "</s>"]:
|
49 |
continue
|
50 |
+
|
51 |
+
attr = round(attr, 2)
|
52 |
+
abs_attr = abs(attr)
|
53 |
|
54 |
+
color = "rgba(255,255,255,0)"
|
55 |
+
if attr > 0: color = f"rgba(0,255,0,{abs_attr})"
|
56 |
+
elif attr < 0: color = f"rgba(255,0,0,{abs_attr})"
|
57 |
|
58 |
+
html += f'<span style="background-color: {color}" title="{str(attr)}">{word}</span> '
|
59 |
|
60 |
return html
|
61 |
|
|
|
74 |
|
75 |
explainer = load_explainer()
|
76 |
|
77 |
+
text = st.text_area(label="Input text here...", value="I enjoy meeting people and working hard!")
|
78 |
|
79 |
show_prediction = st.button("Predict Traits")
|
80 |
|
|
|
85 |
explanation = explain(text, explainer)
|
86 |
st.session_state.explanation = explanation
|
87 |
|
88 |
+
if len(st.session_state.explanation["preds"]) > 0:
|
89 |
|
90 |
st.write("## Predicted Traits")
|
91 |
|
requirements.txt
CHANGED
@@ -3,4 +3,4 @@ pandas==1.4.2
|
|
3 |
simpletransformers==0.63.9
|
4 |
transformers==4.27.3
|
5 |
transformers_interpret==0.10.0
|
6 |
-
streamlit
|
|
|
3 |
simpletransformers==0.63.9
|
4 |
transformers==4.27.3
|
5 |
transformers_interpret==0.10.0
|
6 |
+
streamlit==1.17.0
|