lsacy commited on
Commit
1317801
·
1 Parent(s): 0a5800f

minor update

Browse files
Files changed (3) hide show
  1. .gitignore +5 -0
  2. streamlit_app.py +2 -1
  3. test.ipynb +57 -57
.gitignore CHANGED
@@ -1,2 +1,7 @@
1
  data/
2
  .env
 
 
 
 
 
 
1
  data/
2
  .env
3
+ openai_api_key.txt
4
+ Dockerfile
5
+ test.ipynb
6
+ stream_app_minimum.py
7
+ joy.py
streamlit_app.py CHANGED
@@ -1,5 +1,6 @@
1
  import openai
2
- openai.api_key = os.getenv('OPENAI_API_KEY')
 
3
 
4
  import streamlit as st
5
  from streamlit_chat import message
 
1
  import openai
2
+ openai.api_key = st.secrets["openai_api_key"]
3
+
4
 
5
  import streamlit as st
6
  from streamlit_chat import message
test.ipynb CHANGED
@@ -54,13 +54,31 @@
54
  },
55
  {
56
  "cell_type": "code",
57
- "execution_count": 31,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  "metadata": {},
59
  "outputs": [
60
  {
61
  "data": {
62
  "text/plain": [
63
- "[\" Patient:im doing okay Joy:That's good to hear. Is there anything in particular that you would like to talk about today?\",\n",
 
64
  " ' Patient:no not really Joy:No problem. Is there anything on your mind that you would like to share? I am here to listen and offer support.',\n",
65
  " \" Patient:maybe actually. I am afraid my spouse cheats on me Joy:I'm sorry to hear that. What has been making you feel this way?\",\n",
66
  " ' Patient:she is writing text messages constantly with people i dont know Joy:That can certainly be concerning. Have you tried talking to your spouse about this?',\n",
@@ -76,7 +94,7 @@
76
  " \" Patient:yes please Joy:It can help to start by letting your mom know why you decided to talk to her about this. Sharing how you are feeling and asking for her advice might be a good way to open the conversation. I'm here if you need anything else before or after speaking with your mom.\"]"
77
  ]
78
  },
79
- "execution_count": 31,
80
  "metadata": {},
81
  "output_type": "execute_result"
82
  }
@@ -86,27 +104,34 @@
86
  " chat_log = [i.replace('\\n', ' ') for i in chat_log]\n",
87
  " return chat_log\n",
88
  "\n",
89
- "clean_chatlog(test3)"
90
  ]
91
  },
92
  {
93
  "cell_type": "code",
94
- "execution_count": 30,
95
  "metadata": {},
96
  "outputs": [
 
 
 
 
 
 
 
97
  {
98
  "data": {
99
  "text/plain": [
100
- "\"\\n\\nPatient:im doing okay\\nJoy:That's good to hear. Is there anything in particular that you would like to talk about today?\""
101
  ]
102
  },
103
- "execution_count": 30,
104
  "metadata": {},
105
  "output_type": "execute_result"
106
  }
107
  ],
108
  "source": [
109
- "test3[1]"
110
  ]
111
  },
112
  {
@@ -143,85 +168,60 @@
143
  },
144
  {
145
  "cell_type": "code",
146
- "execution_count": 20,
147
  "metadata": {},
148
  "outputs": [],
149
  "source": [
150
- "\n",
151
- "def clean_chat_log(chat_log):\n",
152
- " chat_log = ' '.join(chat_log)\n",
153
- " # find the first /n\n",
154
- " first_newline = chat_log.find('\\n')\n",
155
- " chat_log = chat_log[first_newline:]\n",
156
- " # remove all \\n\n",
157
- " chat_log = chat_log.replace('\\n', ' ')\n",
158
- " return chat_log"
159
  ]
160
  },
161
  {
162
  "cell_type": "code",
163
- "execution_count": 21,
164
  "metadata": {},
165
  "outputs": [
166
- {
167
- "name": "stderr",
168
- "output_type": "stream",
169
- "text": [
170
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
171
- ]
172
- },
173
  {
174
  "data": {
175
  "text/plain": [
176
- "682"
 
177
  ]
178
  },
179
- "execution_count": 21,
180
  "metadata": {},
181
  "output_type": "execute_result"
182
  }
183
  ],
184
  "source": [
185
- "len(tokenizer(test, padding=True, truncation=True, return_tensors=\"pt\")[0])"
186
  ]
187
  },
188
  {
189
  "cell_type": "code",
190
- "execution_count": 25,
191
- "metadata": {},
192
- "outputs": [],
193
- "source": [
194
- "test3_string = clean_chat_log(test3)"
195
- ]
196
- },
197
- {
198
- "cell_type": "code",
199
- "execution_count": 26,
200
  "metadata": {},
201
  "outputs": [
202
  {
203
- "ename": "RuntimeError",
204
- "evalue": "The expanded size of the tensor (634) must match the existing size (514) at non-singleton dimension 1. Target sizes: [1, 634]. Tensor sizes: [1, 514]",
205
- "output_type": "error",
206
- "traceback": [
207
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
208
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
209
- "Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m sentiment_task(test3_string)\n",
210
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/text_classification.py:155\u001b[0m, in \u001b[0;36mTextClassificationPipeline.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 122\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[39m Classify the text(s) given as inputs.\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[39m If `top_k` is used, one such dictionary is returned per label.\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 155\u001b[0m result \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__call__\u001b[39;49m(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 156\u001b[0m \u001b[39m# TODO try and retrieve it in a nicer way from _sanitize_parameters.\u001b[39;00m\n\u001b[1;32m 157\u001b[0m _legacy \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtop_k\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m kwargs\n",
211
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:1074\u001b[0m, in \u001b[0;36mPipeline.__call__\u001b[0;34m(self, inputs, num_workers, batch_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1072\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39miterate(inputs, preprocess_params, forward_params, postprocess_params)\n\u001b[1;32m 1073\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1074\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrun_single(inputs, preprocess_params, forward_params, postprocess_params)\n",
212
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:1081\u001b[0m, in \u001b[0;36mPipeline.run_single\u001b[0;34m(self, inputs, preprocess_params, forward_params, postprocess_params)\u001b[0m\n\u001b[1;32m 1079\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun_single\u001b[39m(\u001b[39mself\u001b[39m, inputs, preprocess_params, forward_params, postprocess_params):\n\u001b[1;32m 1080\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(inputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpreprocess_params)\n\u001b[0;32m-> 1081\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward(model_inputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mforward_params)\n\u001b[1;32m 1082\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpostprocess(model_outputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpostprocess_params)\n\u001b[1;32m 1083\u001b[0m \u001b[39mreturn\u001b[39;00m outputs\n",
213
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/base.py:990\u001b[0m, in \u001b[0;36mPipeline.forward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[39mwith\u001b[39;00m inference_context():\n\u001b[1;32m 989\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_ensure_tensor_on_device(model_inputs, device\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 990\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_forward(model_inputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mforward_params)\n\u001b[1;32m 991\u001b[0m model_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_ensure_tensor_on_device(model_outputs, device\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mdevice(\u001b[39m\"\u001b[39m\u001b[39mcpu\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 992\u001b[0m \u001b[39melse\u001b[39;00m:\n",
214
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/pipelines/text_classification.py:182\u001b[0m, in \u001b[0;36mTextClassificationPipeline._forward\u001b[0;34m(self, model_inputs)\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_forward\u001b[39m(\u001b[39mself\u001b[39m, model_inputs):\n\u001b[0;32m--> 182\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_inputs)\n",
215
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
216
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py:1215\u001b[0m, in \u001b[0;36mRobertaForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1207\u001b[0m \u001b[39m\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1208\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1209\u001b[0m \u001b[39m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1210\u001b[0m \u001b[39m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1211\u001b[0m \u001b[39m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1212\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 1213\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1215\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mroberta(\n\u001b[1;32m 1216\u001b[0m input_ids,\n\u001b[1;32m 1217\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 1218\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 1219\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 1220\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 1221\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m 1222\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1223\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1224\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 1225\u001b[0m )\n\u001b[1;32m 1226\u001b[0m sequence_output \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 1227\u001b[0m logits \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclassifier(sequence_output)\n",
217
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
218
- "File \u001b[0;32m~/opt/miniconda3/envs/chatbot/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py:819\u001b[0m, in \u001b[0;36mRobertaModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 817\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings, \u001b[39m\"\u001b[39m\u001b[39mtoken_type_ids\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 818\u001b[0m buffered_token_type_ids \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings\u001b[39m.\u001b[39mtoken_type_ids[:, :seq_length]\n\u001b[0;32m--> 819\u001b[0m buffered_token_type_ids_expanded \u001b[39m=\u001b[39m buffered_token_type_ids\u001b[39m.\u001b[39;49mexpand(batch_size, seq_length)\n\u001b[1;32m 820\u001b[0m token_type_ids \u001b[39m=\u001b[39m buffered_token_type_ids_expanded\n\u001b[1;32m 821\u001b[0m \u001b[39melse\u001b[39;00m:\n",
219
- "\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (634) must match the existing size (514) at non-singleton dimension 1. Target sizes: [1, 634]. Tensor sizes: [1, 514]"
220
  ]
221
  }
222
  ],
223
  "source": [
224
- "sentiment_task(test3_string)"
 
 
 
 
 
225
  ]
226
  },
227
  {
 
54
  },
55
  {
56
  "cell_type": "code",
57
+ "execution_count": 20,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "\n",
62
+ "def clean_chat_log(chat_log):\n",
63
+ " chat_log = ' '.join(chat_log)\n",
64
+ " # find the first /n\n",
65
+ " first_newline = chat_log.find('\\n')\n",
66
+ " chat_log = chat_log[first_newline:]\n",
67
+ " # remove all \\n\n",
68
+ " chat_log = chat_log.replace('\\n', ' ')\n",
69
+ " return chat_log"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 32,
75
  "metadata": {},
76
  "outputs": [
77
  {
78
  "data": {
79
  "text/plain": [
80
+ "['[Instruction] Act as a friendly, compasionate, insightful, and empathetic AI therapist named Joy. Joy listens and offers advices. End the conversation when the patient wishes to. Joy:I am Joy, your AI therapist. How are you feeling today?',\n",
81
+ " \" Patient:im doing okay Joy:That's good to hear. Is there anything in particular that you would like to talk about today?\",\n",
82
  " ' Patient:no not really Joy:No problem. Is there anything on your mind that you would like to share? I am here to listen and offer support.',\n",
83
  " \" Patient:maybe actually. I am afraid my spouse cheats on me Joy:I'm sorry to hear that. What has been making you feel this way?\",\n",
84
  " ' Patient:she is writing text messages constantly with people i dont know Joy:That can certainly be concerning. Have you tried talking to your spouse about this?',\n",
 
94
  " \" Patient:yes please Joy:It can help to start by letting your mom know why you decided to talk to her about this. Sharing how you are feeling and asking for her advice might be a good way to open the conversation. I'm here if you need anything else before or after speaking with your mom.\"]"
95
  ]
96
  },
97
+ "execution_count": 32,
98
  "metadata": {},
99
  "output_type": "execute_result"
100
  }
 
104
  " chat_log = [i.replace('\\n', ' ') for i in chat_log]\n",
105
  " return chat_log\n",
106
  "\n",
107
+ "remove_backslashN(test3)"
108
  ]
109
  },
110
  {
111
  "cell_type": "code",
112
+ "execution_count": 21,
113
  "metadata": {},
114
  "outputs": [
115
+ {
116
+ "name": "stderr",
117
+ "output_type": "stream",
118
+ "text": [
119
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
120
+ ]
121
+ },
122
  {
123
  "data": {
124
  "text/plain": [
125
+ "682"
126
  ]
127
  },
128
+ "execution_count": 21,
129
  "metadata": {},
130
  "output_type": "execute_result"
131
  }
132
  ],
133
  "source": [
134
+ "len(tokenizer(test, padding=True, truncation=True, return_tensors=\"pt\")[0])"
135
  ]
136
  },
137
  {
 
168
  },
169
  {
170
  "cell_type": "code",
171
+ "execution_count": 56,
172
  "metadata": {},
173
  "outputs": [],
174
  "source": [
175
+ "import numpy as np\n",
176
+ "tokenized = tokenizer(['this is great', 'this sucks'], return_tensors='pt', truncation=True, padding=True)\n",
177
+ "output = model(**tokenized)\n",
178
+ "scores = output[0][0].detach().numpy()\n",
179
+ "scores = np.exp(scores) / np.sum(np.exp(scores), axis=0)"
 
 
 
 
180
  ]
181
  },
182
  {
183
  "cell_type": "code",
184
+ "execution_count": 58,
185
  "metadata": {},
186
  "outputs": [
 
 
 
 
 
 
 
187
  {
188
  "data": {
189
  "text/plain": [
190
+ "SequenceClassifierOutput(loss=None, logits=tensor([[-2.1332, -0.7378, 2.8549],\n",
191
+ " [ 1.6230, -0.1552, -1.7155]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
192
  ]
193
  },
194
+ "execution_count": 58,
195
  "metadata": {},
196
  "output_type": "execute_result"
197
  }
198
  ],
199
  "source": [
200
+ "output"
201
  ]
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": 57,
 
 
 
 
 
 
 
 
 
206
  "metadata": {},
207
  "outputs": [
208
  {
209
+ "name": "stdout",
210
+ "output_type": "stream",
211
+ "text": [
212
+ "1) positive 0.9668\n",
213
+ "2) neutral 0.0266\n",
214
+ "3) negative 0.0066\n"
 
 
 
 
 
 
 
 
 
 
 
215
  ]
216
  }
217
  ],
218
  "source": [
219
+ "ranking = np.argsort(scores)\n",
220
+ "ranking = ranking[::-1]\n",
221
+ "for i in range(scores.shape[0]):\n",
222
+ " l = config.id2label[ranking[i]]\n",
223
+ " s = scores[ranking[i]]\n",
224
+ " print(f\"{i+1}) {l} {np.round(float(s), 4)}\")"
225
  ]
226
  },
227
  {