lsacy
commited on
Commit
·
1317801
1
Parent(s):
0a5800f
minor update
Browse files- .gitignore +5 -0
- streamlit_app.py +2 -1
- test.ipynb +57 -57
.gitignore
CHANGED
@@ -1,2 +1,7 @@
|
|
1 |
data/
|
2 |
.env
|
|
|
|
|
|
|
|
|
|
|
|
1 |
data/
|
2 |
.env
|
3 |
+
openai_api_key.txt
|
4 |
+
Dockerfile
|
5 |
+
test.ipynb
|
6 |
+
stream_app_minimum.py
|
7 |
+
joy.py
|
streamlit_app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import openai
|
2 |
-
openai.api_key =
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
from streamlit_chat import message
|
|
|
1 |
import openai
|
2 |
+
openai.api_key = st.secrets["openai_api_key"]
|
3 |
+
|
4 |
|
5 |
import streamlit as st
|
6 |
from streamlit_chat import message
|
test.ipynb
CHANGED
@@ -54,13 +54,31 @@
|
|
54 |
},
|
55 |
{
|
56 |
"cell_type": "code",
|
57 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
"metadata": {},
|
59 |
"outputs": [
|
60 |
{
|
61 |
"data": {
|
62 |
"text/plain": [
|
63 |
-
"[
|
|
|
64 |
" ' Patient:no not really Joy:No problem. Is there anything on your mind that you would like to share? I am here to listen and offer support.',\n",
|
65 |
" \" Patient:maybe actually. I am afraid my spouse cheats on me Joy:I'm sorry to hear that. What has been making you feel this way?\",\n",
|
66 |
" ' Patient:she is writing text messages constantly with people i dont know Joy:That can certainly be concerning. Have you tried talking to your spouse about this?',\n",
|
@@ -76,7 +94,7 @@
|
|
76 |
" \" Patient:yes please Joy:It can help to start by letting your mom know why you decided to talk to her about this. Sharing how you are feeling and asking for her advice might be a good way to open the conversation. I'm here if you need anything else before or after speaking with your mom.\"]"
|
77 |
]
|
78 |
},
|
79 |
-
"execution_count":
|
80 |
"metadata": {},
|
81 |
"output_type": "execute_result"
|
82 |
}
|
@@ -86,27 +104,34 @@
|
|
86 |
" chat_log = [i.replace('\\n', ' ') for i in chat_log]\n",
|
87 |
" return chat_log\n",
|
88 |
"\n",
|
89 |
-
"
|
90 |
]
|
91 |
},
|
92 |
{
|
93 |
"cell_type": "code",
|
94 |
-
"execution_count":
|
95 |
"metadata": {},
|
96 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
{
|
98 |
"data": {
|
99 |
"text/plain": [
|
100 |
-
"
|
101 |
]
|
102 |
},
|
103 |
-
"execution_count":
|
104 |
"metadata": {},
|
105 |
"output_type": "execute_result"
|
106 |
}
|
107 |
],
|
108 |
"source": [
|
109 |
-
"
|
110 |
]
|
111 |
},
|
112 |
{
|
@@ -143,85 +168,60 @@
|
|
143 |
},
|
144 |
{
|
145 |
"cell_type": "code",
|
146 |
-
"execution_count":
|
147 |
"metadata": {},
|
148 |
"outputs": [],
|
149 |
"source": [
|
150 |
-
"\n",
|
151 |
-
"
|
152 |
-
"
|
153 |
-
"
|
154 |
-
"
|
155 |
-
" chat_log = chat_log[first_newline:]\n",
|
156 |
-
" # remove all \\n\n",
|
157 |
-
" chat_log = chat_log.replace('\\n', ' ')\n",
|
158 |
-
" return chat_log"
|
159 |
]
|
160 |
},
|
161 |
{
|
162 |
"cell_type": "code",
|
163 |
-
"execution_count":
|
164 |
"metadata": {},
|
165 |
"outputs": [
|
166 |
-
{
|
167 |
-
"name": "stderr",
|
168 |
-
"output_type": "stream",
|
169 |
-
"text": [
|
170 |
-
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
|
171 |
-
]
|
172 |
-
},
|
173 |
{
|
174 |
"data": {
|
175 |
"text/plain": [
|
176 |
-
"
|
|
|
177 |
]
|
178 |
},
|
179 |
-
"execution_count":
|
180 |
"metadata": {},
|
181 |
"output_type": "execute_result"
|
182 |
}
|
183 |
],
|
184 |
"source": [
|
185 |
-
"
|
186 |
]
|
187 |
},
|
188 |
{
|
189 |
"cell_type": "code",
|
190 |
-
"execution_count":
|
191 |
-
"metadata": {},
|
192 |
-
"outputs": [],
|
193 |
-
"source": [
|
194 |
-
"test3_string = clean_chat_log(test3)"
|
195 |
-
]
|
196 |
-
},
|
197 |
-
{
|
198 |
-
"cell_type": "code",
|
199 |
-
"execution_count": 26,
|
200 |
"metadata": {},
|
201 |
"outputs": [
|
202 |
{
|
203 |
-
"
|
204 |
-
"
|
205 |
-
"
|
206 |
-
|
207 |
-
"\
|
208 |
-
"
|
209 |
-
"Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m sentiment_task(test3_string)\n",
|
210 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/text_classification.py:155\u001b[0m, in \u001b[0;36mTextClassificationPipeline.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 122\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[39m Classify the text(s) given as inputs.\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[39m If `top_k` is used, one such dictionary is returned per label.\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 155\u001b[0m result \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__call__\u001b[39;49m(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 156\u001b[0m \u001b[39m# TODO try and retrieve it in a nicer way from _sanitize_parameters.\u001b[39;00m\n\u001b[1;32m 157\u001b[0m _legacy \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtop_k\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m kwargs\n",
|
211 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:1074\u001b[0m, in \u001b[0;36mPipeline.__call__\u001b[0;34m(self, inputs, num_workers, batch_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1072\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39miterate(inputs, preprocess_params, forward_params, postprocess_params)\n\u001b[1;32m 1073\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1074\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrun_single(inputs, preprocess_params, forward_params, postprocess_params)\n",
|
212 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:1081\u001b[0m, in \u001b[0;36mPipeline.run_single\u001b[0;34m(self, inputs, preprocess_params, forward_params, postprocess_params)\u001b[0m\n\u001b[1;32m 1079\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun_single\u001b[39m(\u001b[39mself\u001b[39m, inputs, preprocess_params, forward_params, postprocess_params):\n\u001b[1;32m 1080\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(inputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpreprocess_params)\n\u001b[0;32m-> 1081\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward(model_inputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mforward_params)\n\u001b[1;32m 1082\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpostprocess(model_outputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpostprocess_params)\n\u001b[1;32m 1083\u001b[0m \u001b[39mreturn\u001b[39;00m outputs\n",
|
213 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:990\u001b[0m, in \u001b[0;36mPipeline.forward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[39mwith\u001b[39;00m inference_context():\n\u001b[1;32m 989\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_ensure_tensor_on_device(model_inputs, device\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 990\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_forward(model_inputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mforward_params)\n\u001b[1;32m 991\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_ensure_tensor_on_device(model_outputs, device\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mdevice(\u001b[39m\"\u001b[39m\u001b[39mcpu\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 992\u001b[0m \u001b[39melse\u001b[39;00m:\n",
|
214 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/text_classification.py:182\u001b[0m, in \u001b[0;36mTextClassificationPipeline._forward\u001b[0;34m(self, model_inputs)\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_forward\u001b[39m(\u001b[39mself\u001b[39m, model_inputs):\n\u001b[0;32m--> 182\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_inputs)\n",
|
215 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
216 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py:1215\u001b[0m, in \u001b[0;36mRobertaForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1207\u001b[0m \u001b[39m\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1208\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1209\u001b[0m \u001b[39m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1210\u001b[0m \u001b[39m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1211\u001b[0m \u001b[39m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1212\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1213\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1215\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mroberta(\n\u001b[1;32m 1216\u001b[0m input_ids,\n\u001b[1;32m 1217\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 1218\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 1219\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 1220\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 1221\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m 1222\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1223\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1224\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 1225\u001b[0m )\n\u001b[1;32m 1226\u001b[0m sequence_output \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 1227\u001b[0m logits \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclassifier(sequence_output)\n",
|
217 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
218 |
-
"File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py:819\u001b[0m, in \u001b[0;36mRobertaModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 817\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings, \u001b[39m\"\u001b[39m\u001b[39mtoken_type_ids\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 818\u001b[0m buffered_token_type_ids \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings\u001b[39m.\u001b[39mtoken_type_ids[:, :seq_length]\n\u001b[0;32m--> 819\u001b[0m buffered_token_type_ids_expanded \u001b[39m=\u001b[39m buffered_token_type_ids\u001b[39m.\u001b[39;49mexpand(batch_size, seq_length)\n\u001b[1;32m 820\u001b[0m token_type_ids \u001b[39m=\u001b[39m buffered_token_type_ids_expanded\n\u001b[1;32m 821\u001b[0m \u001b[39melse\u001b[39;00m:\n",
|
219 |
-
"\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (634) must match the existing size (514) at non-singleton dimension 1. Target sizes: [1, 634]. Tensor sizes: [1, 514]"
|
220 |
]
|
221 |
}
|
222 |
],
|
223 |
"source": [
|
224 |
-
"
|
|
|
|
|
|
|
|
|
|
|
225 |
]
|
226 |
},
|
227 |
{
|
|
|
54 |
},
|
55 |
{
|
56 |
"cell_type": "code",
|
57 |
+
"execution_count": 20,
|
58 |
+
"metadata": {},
|
59 |
+
"outputs": [],
|
60 |
+
"source": [
|
61 |
+
"\n",
|
62 |
+
"def clean_chat_log(chat_log):\n",
|
63 |
+
" chat_log = ' '.join(chat_log)\n",
|
64 |
+
" # find the first /n\n",
|
65 |
+
" first_newline = chat_log.find('\\n')\n",
|
66 |
+
" chat_log = chat_log[first_newline:]\n",
|
67 |
+
" # remove all \\n\n",
|
68 |
+
" chat_log = chat_log.replace('\\n', ' ')\n",
|
69 |
+
" return chat_log"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 32,
|
75 |
"metadata": {},
|
76 |
"outputs": [
|
77 |
{
|
78 |
"data": {
|
79 |
"text/plain": [
|
80 |
+
"['[Instruction] Act as a friendly, compasionate, insightful, and empathetic AI therapist named Joy. Joy listens and offers advices. End the conversation when the patient wishes to. Joy:I am Joy, your AI therapist. How are you feeling today?',\n",
|
81 |
+
" \" Patient:im doing okay Joy:That's good to hear. Is there anything in particular that you would like to talk about today?\",\n",
|
82 |
" ' Patient:no not really Joy:No problem. Is there anything on your mind that you would like to share? I am here to listen and offer support.',\n",
|
83 |
" \" Patient:maybe actually. I am afraid my spouse cheats on me Joy:I'm sorry to hear that. What has been making you feel this way?\",\n",
|
84 |
" ' Patient:she is writing text messages constantly with people i dont know Joy:That can certainly be concerning. Have you tried talking to your spouse about this?',\n",
|
|
|
94 |
" \" Patient:yes please Joy:It can help to start by letting your mom know why you decided to talk to her about this. Sharing how you are feeling and asking for her advice might be a good way to open the conversation. I'm here if you need anything else before or after speaking with your mom.\"]"
|
95 |
]
|
96 |
},
|
97 |
+
"execution_count": 32,
|
98 |
"metadata": {},
|
99 |
"output_type": "execute_result"
|
100 |
}
|
|
|
104 |
" chat_log = [i.replace('\\n', ' ') for i in chat_log]\n",
|
105 |
" return chat_log\n",
|
106 |
"\n",
|
107 |
+
"remove_backslashN(test3)"
|
108 |
]
|
109 |
},
|
110 |
{
|
111 |
"cell_type": "code",
|
112 |
+
"execution_count": 21,
|
113 |
"metadata": {},
|
114 |
"outputs": [
|
115 |
+
{
|
116 |
+
"name": "stderr",
|
117 |
+
"output_type": "stream",
|
118 |
+
"text": [
|
119 |
+
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
|
120 |
+
]
|
121 |
+
},
|
122 |
{
|
123 |
"data": {
|
124 |
"text/plain": [
|
125 |
+
"682"
|
126 |
]
|
127 |
},
|
128 |
+
"execution_count": 21,
|
129 |
"metadata": {},
|
130 |
"output_type": "execute_result"
|
131 |
}
|
132 |
],
|
133 |
"source": [
|
134 |
+
"len(tokenizer(test, padding=True, truncation=True, return_tensors=\"pt\")[0])"
|
135 |
]
|
136 |
},
|
137 |
{
|
|
|
168 |
},
|
169 |
{
|
170 |
"cell_type": "code",
|
171 |
+
"execution_count": 56,
|
172 |
"metadata": {},
|
173 |
"outputs": [],
|
174 |
"source": [
|
175 |
+
"import numpy as np\n",
|
176 |
+
"tokenized = tokenizer(['this is great', 'this sucks'], return_tensors='pt', truncation=True, padding=True)\n",
|
177 |
+
"output = model(**tokenized)\n",
|
178 |
+
"scores = output[0][0].detach().numpy()\n",
|
179 |
+
"scores = np.exp(scores) / np.sum(np.exp(scores), axis=0)"
|
|
|
|
|
|
|
|
|
180 |
]
|
181 |
},
|
182 |
{
|
183 |
"cell_type": "code",
|
184 |
+
"execution_count": 58,
|
185 |
"metadata": {},
|
186 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
{
|
188 |
"data": {
|
189 |
"text/plain": [
|
190 |
+
"SequenceClassifierOutput(loss=None, logits=tensor([[-2.1332, -0.7378, 2.8549],\n",
|
191 |
+
" [ 1.6230, -0.1552, -1.7155]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
|
192 |
]
|
193 |
},
|
194 |
+
"execution_count": 58,
|
195 |
"metadata": {},
|
196 |
"output_type": "execute_result"
|
197 |
}
|
198 |
],
|
199 |
"source": [
|
200 |
+
"output"
|
201 |
]
|
202 |
},
|
203 |
{
|
204 |
"cell_type": "code",
|
205 |
+
"execution_count": 57,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
"metadata": {},
|
207 |
"outputs": [
|
208 |
{
|
209 |
+
"name": "stdout",
|
210 |
+
"output_type": "stream",
|
211 |
+
"text": [
|
212 |
+
"1) positive 0.9668\n",
|
213 |
+
"2) neutral 0.0266\n",
|
214 |
+
"3) negative 0.0066\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
]
|
216 |
}
|
217 |
],
|
218 |
"source": [
|
219 |
+
"ranking = np.argsort(scores)\n",
|
220 |
+
"ranking = ranking[::-1]\n",
|
221 |
+
"for i in range(scores.shape[0]):\n",
|
222 |
+
" l = config.id2label[ranking[i]]\n",
|
223 |
+
" s = scores[ranking[i]]\n",
|
224 |
+
" print(f\"{i+1}) {l} {np.round(float(s), 4)}\")"
|
225 |
]
|
226 |
},
|
227 |
{
|