Spaces:
Runtime error
Runtime error
first try
Browse files- w2v_ovr_svc.sav → models/w2v_ovr_svc.sav +0 -0
- requirements.txt +5 -0
- text_class_app.py +33 -0
- utils.py +87 -0
w2v_ovr_svc.sav → models/w2v_ovr_svc.sav
RENAMED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
streamlit==1.4.0
gensim==4.1.2
transformers==4.16.1
torch
pandas
numpy
# NOTE: `re` and `pickle` are Python standard-library modules, not pip
# packages. Listing them (re==2.2.1, pickle) makes `pip install` fail,
# which is the likely cause of the Space's "Runtime error".
# pandas, numpy and torch were added because utils.py uses pd/np and the
# transformers pipeline runs with framework='pt'.
|
text_class_app.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st

import utils

########## Title for the Web App ##########
st.title("Text Classification for Service Feedback")

########## Create Input field ##########
feedback = st.text_input('Type your text here', 'The staff were extremely polite and helpful!')

if st.button('Click for predictions!'):
    with st.spinner('Generating predictions...'):
        # BUG FIX: only `import utils` is in scope, so these helpers must be
        # qualified — the bare names raised NameError at click time.
        result = utils.get_single_prediction(feedback)

    # result = [topic_label, ..., sentiment]; the last element is sentiment.
    st.success(f'Your text has been predicted to fall under the following labels: {result[:-1]}. This text is {result[-1]}.')

st.text('Or... Upload a csv file if you have many texts')

uploaded_file = st.file_uploader("Please upload a csv file with only 1 column of texts.")

if uploaded_file is not None:

    with st.spinner('Generating predictions...'):
        results = utils.get_multiple_predictions(uploaded_file)

    st.download_button(
        label="Download results as CSV",
        # BUG FIX: download_button needs str/bytes, not a DataFrame —
        # serialise the results to CSV before handing them over.
        data=results.to_csv(index=False).encode('utf-8'),
        file_name='results.csv',
        mime='text/csv',
    )
utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle
import re

import numpy as np
import pandas as pd
import torch
from gensim.models.keyedvectors import KeyedVectors
from transformers import pipeline
# Pre-trained word2vec embeddings used to vectorise the feedback text.
w2v = KeyedVectors.load('models/word2vec')
w2v_vocab = set(sorted(w2v.index_to_key))

# One-vs-rest SVC trained on averaged word2vec sentence vectors.
# NOTE(review): pickle.load on a repo-shipped file is acceptable, but never
# load pickles from untrusted sources.
model = pickle.load(open('models/w2v_ovr_svc.sav', 'rb'))

# Zero-shot sentiment classifier.
# BUG FIX: device=0 hard-requires a CUDA GPU and crashes on CPU-only hosts
# (e.g. a free Spaces instance); fall back to CPU (-1) when CUDA is absent.
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli",
                      device=0 if torch.cuda.is_available() else -1,
                      framework='pt')

# Topic labels emitted by `model`, in the training column order.
labels = [
    'communication', 'waiting time',
    'information', 'user interface',
    'facilities', 'location', 'price'
]
def get_sentiment_label_facebook(list_of_sent_dicts):
    """Collapse a zero-shot classifier result to 'negative' or 'positive'.

    The pipeline returns a dict whose 'labels' list is sorted by score;
    anything other than a top-ranked 'negative' counts as 'positive'.
    """
    top_label = list_of_sent_dicts['labels'][0]
    return 'negative' if top_label == 'negative' else 'positive'
def get_single_prediction(text):
    """Predict topic labels and sentiment for one feedback string.

    Returns a list: zero or more entries from `labels`, with the sentiment
    ('positive'/'negative') appended as the last element.
    """
    # Normalise: lower-case and strip anything that is not alphanumeric/space.
    text = text.lower()
    text = re.sub(r'[^0-9a-zA-Z\s]', '', text)

    # Remove OOV words the word2vec model has never seen.
    text = ' '.join([i for i in text.split() if i in w2v_vocab])

    # Sentence vector = average of word vectors.
    # BUG FIX: sklearn estimators expect a 2-D (n_samples, n_features) array;
    # the original passed the raw 1-D mean vector straight to predict().
    text_vectors = np.mean([w2v[i] for i in text.split()], axis=0).reshape(1, -1)

    # Multi-label prediction: one 0/1 flag per entry in `labels`.
    # BUG FIX: take row [0] so we enumerate per-label flags, not sample rows.
    results = model.predict(text_vectors)[0]

    # Zero-shot sentiment via facebook/bart-large-mnli.
    sentiment = get_sentiment_label_facebook(classifier(text,
                                             candidate_labels=['positive', 'negative'],
                                             hypothesis_template='The sentiment of this is {}'))

    # Consolidate results: active topic labels, then the sentiment.
    pred_labels = [labels[idx] for idx, tag in enumerate(results) if tag == 1]
    pred_labels.append(sentiment)

    return pred_labels
def get_multiple_predictions(csv):
    """Predict topic labels and sentiment for every row of an uploaded CSV.

    The file must contain a single column of texts. Returns a DataFrame with
    the cleaned text, one 0/1 column per topic label, and a 'sentiment'
    column; rows that end up blank after cleaning are appended unpredicted.
    """
    df = pd.read_csv(csv)
    df.columns = ['sequence']

    # BUG FIX: blank CSV cells arrive as NaN, and the OOV .apply below would
    # crash on float.split before the original pd.isna check ever ran.
    df['sequence'] = df['sequence'].fillna('')

    # Normalise: lower-case, strip special chars/punctuation.
    df['sequence'] = df['sequence'].str.lower()
    # regex=True made explicit — the implicit default is deprecated/changed
    # in newer pandas and would silently stop stripping punctuation.
    df['sequence'] = df['sequence'].str.replace(r'[^0-9a-zA-Z\s]', '', regex=True)

    # Remove OOV words.
    df['sequence'] = df['sequence'].apply(lambda x: ' '.join([i for i in x.split() if i in w2v_vocab]))

    # Rows with a blank string cannot be vectorised; keep them aside.
    invalid = df[df['sequence'] == '']
    df = df[df['sequence'] != ''].reset_index(drop=True)

    # Sentence vector = average of word vectors, one row per text.
    series_text_vectors = pd.DataFrame(df['sequence'].apply(lambda x: np.mean([w2v[i] for i in x.split()], axis=0)).values.tolist())

    # Multi-label predictions, one 0/1 column per topic label.
    pred_results = pd.DataFrame(model.predict(series_text_vectors), columns=labels)

    # BUG FIX: join the predictions back to the text — the original joined
    # `series_text_vectors`, so the output held raw embeddings, no labels.
    final_results = df.join(pred_results)

    # Zero-shot sentiment per row.
    final_results['sentiment'] = final_results['sequence'].apply(lambda x: get_sentiment_label_facebook(classifier(x,
                                                                 candidate_labels=['positive', 'negative'],
                                                                 hypothesis_template='The sentiment of this is {}')))

    # Append invalid rows (their label/sentiment columns stay empty).
    if len(invalid) == 0:
        return final_results
    return pd.concat([final_results, invalid]).reset_index(drop=True)