Spaces:
Running
Running
update milestone-3 notebook
Browse files
CS670_milestone_3_AyeThuzar.ipynb
CHANGED
@@ -922,18 +922,18 @@
|
|
922 |
"\n",
|
923 |
"from transformers import pipeline, Trainer, TrainingArguments\n",
|
924 |
"\n",
|
925 |
-
"\n",
|
926 |
"import torch\n",
|
927 |
"import torch.nn.functional as F\n",
|
928 |
"\n",
|
929 |
"from transformers import logging\n",
|
930 |
"\n",
|
931 |
-
"logging.set_verbosity_warning()"
|
932 |
],
|
933 |
"metadata": {
|
934 |
"id": "FxZeFFTlFvz1"
|
935 |
},
|
936 |
-
"execution_count":
|
937 |
"outputs": []
|
938 |
},
|
939 |
{
|
@@ -1163,7 +1163,7 @@
|
|
1163 |
"metadata": {
|
1164 |
"colab": {
|
1165 |
"base_uri": "https://localhost:8080/",
|
1166 |
-
"height":
|
1167 |
},
|
1168 |
"id": "jDBvcgmP5Puh",
|
1169 |
"outputId": "f4f73693-11f7-4918-a86d-2912e863b151"
|
@@ -1193,7 +1193,7 @@
|
|
1193 |
"metadata": {
|
1194 |
"colab": {
|
1195 |
"base_uri": "https://localhost:8080/",
|
1196 |
-
"height":
|
1197 |
},
|
1198 |
"id": "sBhSPSV-5XKS",
|
1199 |
"outputId": "0057e051-3b36-4705-8636-19e7850fa0a9"
|
@@ -4118,7 +4118,7 @@
|
|
4118 |
"id": "h7bzRvkItdir",
|
4119 |
"outputId": "7495ec10-0ee5-4f1c-ffe9-50f4afe2cb83"
|
4120 |
},
|
4121 |
-
"execution_count":
|
4122 |
"outputs": [
|
4123 |
{
|
4124 |
"output_type": "stream",
|
@@ -4253,7 +4253,469 @@
|
|
4253 |
"batch_average_accuray: 0.5\n",
|
4254 |
"batch_average_accuray: 0.5\n",
|
4255 |
"batch_average_accuray: 0.625\n",
|
4256 |
-
"batch_average_accuray: 0.75\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4257 |
]
|
4258 |
}
|
4259 |
]
|
@@ -4277,7 +4739,7 @@
|
|
4277 |
"metadata": {
|
4278 |
"id": "KefqatP-YDSC"
|
4279 |
},
|
4280 |
-
"execution_count":
|
4281 |
"outputs": []
|
4282 |
},
|
4283 |
{
|
@@ -4289,9 +4751,70 @@
|
|
4289 |
"metadata": {
|
4290 |
"id": "Km8eScKJl4VP"
|
4291 |
},
|
4292 |
-
"execution_count":
|
4293 |
"outputs": []
|
4294 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4295 |
{
|
4296 |
"cell_type": "markdown",
|
4297 |
"source": [
|
|
|
922 |
"\n",
|
923 |
"from transformers import pipeline, Trainer, TrainingArguments\n",
|
924 |
"\n",
|
925 |
+
"import numpy as np\n",
|
926 |
"import torch\n",
|
927 |
"import torch.nn.functional as F\n",
|
928 |
"\n",
|
929 |
"from transformers import logging\n",
|
930 |
"\n",
|
931 |
+
"logging.set_verbosity_warning()\n"
|
932 |
],
|
933 |
"metadata": {
|
934 |
"id": "FxZeFFTlFvz1"
|
935 |
},
|
936 |
+
"execution_count": 92,
|
937 |
"outputs": []
|
938 |
},
|
939 |
{
|
|
|
1163 |
"metadata": {
|
1164 |
"colab": {
|
1165 |
"base_uri": "https://localhost:8080/",
|
1166 |
+
"height": 157
|
1167 |
},
|
1168 |
"id": "jDBvcgmP5Puh",
|
1169 |
"outputId": "f4f73693-11f7-4918-a86d-2912e863b151"
|
|
|
1193 |
"metadata": {
|
1194 |
"colab": {
|
1195 |
"base_uri": "https://localhost:8080/",
|
1196 |
+
"height": 105
|
1197 |
},
|
1198 |
"id": "sBhSPSV-5XKS",
|
1199 |
"outputId": "0057e051-3b36-4705-8636-19e7850fa0a9"
|
|
|
4118 |
"id": "h7bzRvkItdir",
|
4119 |
"outputId": "7495ec10-0ee5-4f1c-ffe9-50f4afe2cb83"
|
4120 |
},
|
4121 |
+
"execution_count": 88,
|
4122 |
"outputs": [
|
4123 |
{
|
4124 |
"output_type": "stream",
|
|
|
4253 |
"batch_average_accuray: 0.5\n",
|
4254 |
"batch_average_accuray: 0.5\n",
|
4255 |
"batch_average_accuray: 0.625\n",
|
4256 |
+
"batch_average_accuray: 0.75\n",
|
4257 |
+
"batch_average_accuray: 0.375\n",
|
4258 |
+
"batch_average_accuray: 0.5\n",
|
4259 |
+
"batch_average_accuray: 0.25\n",
|
4260 |
+
"batch_average_accuray: 0.5625\n",
|
4261 |
+
"batch_average_accuray: 0.4375\n",
|
4262 |
+
"batch_average_accuray: 0.75\n",
|
4263 |
+
"batch_average_accuray: 0.375\n",
|
4264 |
+
"batch_average_accuray: 0.5625\n",
|
4265 |
+
"batch_average_accuray: 0.8125\n",
|
4266 |
+
"batch_average_accuray: 0.5625\n",
|
4267 |
+
"batch_average_accuray: 0.5625\n",
|
4268 |
+
"batch_average_accuray: 0.5\n",
|
4269 |
+
"batch_average_accuray: 0.625\n",
|
4270 |
+
"batch_average_accuray: 0.6875\n",
|
4271 |
+
"batch_average_accuray: 0.4375\n",
|
4272 |
+
"batch_average_accuray: 0.625\n",
|
4273 |
+
"batch_average_accuray: 0.625\n",
|
4274 |
+
"batch_average_accuray: 0.5625\n",
|
4275 |
+
"batch_average_accuray: 0.5\n",
|
4276 |
+
"batch_average_accuray: 0.5\n",
|
4277 |
+
"batch_average_accuray: 0.75\n",
|
4278 |
+
"batch_average_accuray: 0.5625\n",
|
4279 |
+
"batch_average_accuray: 0.5625\n",
|
4280 |
+
"batch_average_accuray: 0.5625\n",
|
4281 |
+
"batch_average_accuray: 0.375\n",
|
4282 |
+
"batch_average_accuray: 0.5625\n",
|
4283 |
+
"batch_average_accuray: 0.625\n",
|
4284 |
+
"batch_average_accuray: 0.375\n",
|
4285 |
+
"batch_average_accuray: 0.6875\n",
|
4286 |
+
"batch_average_accuray: 0.5\n",
|
4287 |
+
"batch_average_accuray: 0.625\n",
|
4288 |
+
"batch_average_accuray: 0.4375\n",
|
4289 |
+
"batch_average_accuray: 0.375\n",
|
4290 |
+
"batch_average_accuray: 0.375\n",
|
4291 |
+
"batch_average_accuray: 0.5625\n",
|
4292 |
+
"batch_average_accuray: 0.5625\n",
|
4293 |
+
"batch_average_accuray: 0.375\n",
|
4294 |
+
"batch_average_accuray: 0.4375\n",
|
4295 |
+
"batch_average_accuray: 0.75\n",
|
4296 |
+
"batch_average_accuray: 0.4375\n",
|
4297 |
+
"batch_average_accuray: 0.4375\n",
|
4298 |
+
"batch_average_accuray: 0.5625\n",
|
4299 |
+
"batch_average_accuray: 0.4375\n",
|
4300 |
+
"batch_average_accuray: 0.6875\n",
|
4301 |
+
"batch_average_accuray: 0.625\n",
|
4302 |
+
"batch_average_accuray: 0.6875\n",
|
4303 |
+
"batch_average_accuray: 0.625\n",
|
4304 |
+
"batch_average_accuray: 0.5\n",
|
4305 |
+
"batch_average_accuray: 0.4375\n",
|
4306 |
+
"batch_average_accuray: 0.375\n",
|
4307 |
+
"batch_average_accuray: 0.4375\n",
|
4308 |
+
"batch_average_accuray: 0.625\n",
|
4309 |
+
"batch_average_accuray: 0.625\n",
|
4310 |
+
"batch_average_accuray: 0.625\n",
|
4311 |
+
"batch_average_accuray: 0.75\n",
|
4312 |
+
"batch_average_accuray: 0.6875\n",
|
4313 |
+
"batch_average_accuray: 0.5625\n",
|
4314 |
+
"batch_average_accuray: 0.5\n",
|
4315 |
+
"batch_average_accuray: 0.4375\n",
|
4316 |
+
"batch_average_accuray: 0.5625\n",
|
4317 |
+
"batch_average_accuray: 0.6875\n",
|
4318 |
+
"batch_average_accuray: 0.625\n",
|
4319 |
+
"batch_average_accuray: 0.75\n",
|
4320 |
+
"batch_average_accuray: 0.4375\n",
|
4321 |
+
"batch_average_accuray: 0.4375\n",
|
4322 |
+
"batch_average_accuray: 0.6875\n",
|
4323 |
+
"batch_average_accuray: 0.4375\n",
|
4324 |
+
"batch_average_accuray: 0.5625\n",
|
4325 |
+
"batch_average_accuray: 0.6875\n",
|
4326 |
+
"batch_average_accuray: 0.375\n",
|
4327 |
+
"batch_average_accuray: 0.3125\n",
|
4328 |
+
"batch_average_accuray: 0.5625\n",
|
4329 |
+
"batch_average_accuray: 0.5625\n",
|
4330 |
+
"batch_average_accuray: 0.625\n",
|
4331 |
+
"batch_average_accuray: 0.5\n",
|
4332 |
+
"batch_average_accuray: 0.4375\n",
|
4333 |
+
"batch_average_accuray: 0.5625\n",
|
4334 |
+
"batch_average_accuray: 0.625\n",
|
4335 |
+
"batch_average_accuray: 0.5\n",
|
4336 |
+
"batch_average_accuray: 0.6875\n",
|
4337 |
+
"batch_average_accuray: 0.5625\n",
|
4338 |
+
"batch_average_accuray: 0.375\n",
|
4339 |
+
"batch_average_accuray: 0.5\n",
|
4340 |
+
"batch_average_accuray: 0.4375\n",
|
4341 |
+
"batch_average_accuray: 0.5\n",
|
4342 |
+
"batch_average_accuray: 0.625\n",
|
4343 |
+
"batch_average_accuray: 0.5625\n",
|
4344 |
+
"batch_average_accuray: 0.4375\n",
|
4345 |
+
"batch_average_accuray: 0.5\n",
|
4346 |
+
"batch_average_accuray: 0.5\n",
|
4347 |
+
"batch_average_accuray: 0.4375\n",
|
4348 |
+
"batch_average_accuray: 0.8125\n",
|
4349 |
+
"batch_average_accuray: 0.625\n",
|
4350 |
+
"batch_average_accuray: 0.5\n",
|
4351 |
+
"batch_average_accuray: 0.6875\n",
|
4352 |
+
"batch_average_accuray: 0.5625\n",
|
4353 |
+
"batch_average_accuray: 0.4375\n",
|
4354 |
+
"batch_average_accuray: 0.5\n",
|
4355 |
+
"batch_average_accuray: 0.625\n",
|
4356 |
+
"batch_average_accuray: 0.6875\n",
|
4357 |
+
"batch_average_accuray: 0.25\n",
|
4358 |
+
"batch_average_accuray: 0.625\n",
|
4359 |
+
"batch_average_accuray: 0.5625\n",
|
4360 |
+
"batch_average_accuray: 0.25\n",
|
4361 |
+
"batch_average_accuray: 0.375\n",
|
4362 |
+
"batch_average_accuray: 0.75\n",
|
4363 |
+
"batch_average_accuray: 0.625\n",
|
4364 |
+
"batch_average_accuray: 0.75\n",
|
4365 |
+
"batch_average_accuray: 0.375\n",
|
4366 |
+
"batch_average_accuray: 0.4375\n",
|
4367 |
+
"batch_average_accuray: 0.625\n",
|
4368 |
+
"batch_average_accuray: 0.4375\n",
|
4369 |
+
"batch_average_accuray: 0.5\n",
|
4370 |
+
"batch_average_accuray: 0.75\n",
|
4371 |
+
"batch_average_accuray: 0.3125\n",
|
4372 |
+
"batch_average_accuray: 0.5625\n",
|
4373 |
+
"batch_average_accuray: 0.75\n",
|
4374 |
+
"batch_average_accuray: 0.5625\n",
|
4375 |
+
"batch_average_accuray: 0.75\n",
|
4376 |
+
"batch_average_accuray: 0.5625\n",
|
4377 |
+
"batch_average_accuray: 0.625\n",
|
4378 |
+
"batch_average_accuray: 0.75\n",
|
4379 |
+
"batch_average_accuray: 0.6875\n",
|
4380 |
+
"batch_average_accuray: 0.5625\n",
|
4381 |
+
"batch_average_accuray: 0.6875\n",
|
4382 |
+
"batch_average_accuray: 0.3125\n",
|
4383 |
+
"batch_average_accuray: 0.5\n",
|
4384 |
+
"batch_average_accuray: 0.5\n",
|
4385 |
+
"batch_average_accuray: 0.25\n",
|
4386 |
+
"batch_average_accuray: 0.4375\n",
|
4387 |
+
"batch_average_accuray: 0.375\n",
|
4388 |
+
"batch_average_accuray: 0.375\n",
|
4389 |
+
"batch_average_accuray: 0.5625\n",
|
4390 |
+
"batch_average_accuray: 0.5\n",
|
4391 |
+
"batch_average_accuray: 0.625\n",
|
4392 |
+
"batch_average_accuray: 0.4375\n",
|
4393 |
+
"batch_average_accuray: 0.5\n",
|
4394 |
+
"batch_average_accuray: 0.6875\n",
|
4395 |
+
"batch_average_accuray: 0.5625\n",
|
4396 |
+
"batch_average_accuray: 0.625\n",
|
4397 |
+
"batch_average_accuray: 0.5625\n",
|
4398 |
+
"batch_average_accuray: 0.5625\n",
|
4399 |
+
"batch_average_accuray: 0.6875\n",
|
4400 |
+
"batch_average_accuray: 0.625\n",
|
4401 |
+
"batch_average_accuray: 0.5625\n",
|
4402 |
+
"batch_average_accuray: 0.5\n",
|
4403 |
+
"batch_average_accuray: 0.5625\n",
|
4404 |
+
"batch_average_accuray: 0.6875\n",
|
4405 |
+
"batch_average_accuray: 0.6875\n",
|
4406 |
+
"batch_average_accuray: 0.75\n",
|
4407 |
+
"batch_average_accuray: 0.25\n",
|
4408 |
+
"batch_average_accuray: 0.5\n",
|
4409 |
+
"batch_average_accuray: 0.625\n",
|
4410 |
+
"batch_average_accuray: 0.625\n",
|
4411 |
+
"batch_average_accuray: 0.5625\n",
|
4412 |
+
"batch_average_accuray: 0.5\n",
|
4413 |
+
"batch_average_accuray: 0.375\n",
|
4414 |
+
"batch_average_accuray: 0.6875\n",
|
4415 |
+
"batch_average_accuray: 0.75\n",
|
4416 |
+
"batch_average_accuray: 0.375\n",
|
4417 |
+
"batch_average_accuray: 0.625\n",
|
4418 |
+
"batch_average_accuray: 0.5625\n",
|
4419 |
+
"batch_average_accuray: 0.5\n",
|
4420 |
+
"batch_average_accuray: 0.5\n",
|
4421 |
+
"batch_average_accuray: 0.5\n",
|
4422 |
+
"batch_average_accuray: 0.5625\n",
|
4423 |
+
"batch_average_accuray: 0.375\n",
|
4424 |
+
"batch_average_accuray: 0.625\n",
|
4425 |
+
"batch_average_accuray: 0.5625\n",
|
4426 |
+
"batch_average_accuray: 0.75\n",
|
4427 |
+
"batch_average_accuray: 0.6875\n",
|
4428 |
+
"batch_average_accuray: 0.375\n",
|
4429 |
+
"batch_average_accuray: 0.5625\n",
|
4430 |
+
"batch_average_accuray: 0.5625\n",
|
4431 |
+
"batch_average_accuray: 0.5\n",
|
4432 |
+
"batch_average_accuray: 0.625\n",
|
4433 |
+
"batch_average_accuray: 0.5625\n",
|
4434 |
+
"batch_average_accuray: 0.625\n",
|
4435 |
+
"batch_average_accuray: 0.625\n",
|
4436 |
+
"batch_average_accuray: 0.25\n",
|
4437 |
+
"batch_average_accuray: 0.3125\n",
|
4438 |
+
"batch_average_accuray: 0.5625\n",
|
4439 |
+
"batch_average_accuray: 0.375\n",
|
4440 |
+
"batch_average_accuray: 0.4375\n",
|
4441 |
+
"batch_average_accuray: 0.4375\n",
|
4442 |
+
"batch_average_accuray: 0.375\n",
|
4443 |
+
"batch_average_accuray: 0.8125\n",
|
4444 |
+
"batch_average_accuray: 0.6875\n",
|
4445 |
+
"batch_average_accuray: 0.4375\n",
|
4446 |
+
"batch_average_accuray: 0.5625\n",
|
4447 |
+
"batch_average_accuray: 0.6875\n",
|
4448 |
+
"batch_average_accuray: 0.5\n",
|
4449 |
+
"batch_average_accuray: 0.4375\n",
|
4450 |
+
"batch_average_accuray: 0.375\n",
|
4451 |
+
"batch_average_accuray: 0.5\n",
|
4452 |
+
"batch_average_accuray: 0.4375\n",
|
4453 |
+
"batch_average_accuray: 0.4375\n",
|
4454 |
+
"batch_average_accuray: 0.375\n",
|
4455 |
+
"batch_average_accuray: 0.5\n",
|
4456 |
+
"batch_average_accuray: 0.4375\n",
|
4457 |
+
"batch_average_accuray: 0.5\n",
|
4458 |
+
"batch_average_accuray: 0.4375\n",
|
4459 |
+
"batch_average_accuray: 0.5625\n",
|
4460 |
+
"batch_average_accuray: 0.6875\n",
|
4461 |
+
"batch_average_accuray: 0.5\n",
|
4462 |
+
"batch_average_accuray: 0.75\n",
|
4463 |
+
"batch_average_accuray: 0.625\n",
|
4464 |
+
"batch_average_accuray: 0.625\n",
|
4465 |
+
"batch_average_accuray: 0.5\n",
|
4466 |
+
"batch_average_accuray: 0.375\n",
|
4467 |
+
"batch_average_accuray: 0.5\n",
|
4468 |
+
"batch_average_accuray: 0.8125\n",
|
4469 |
+
"batch_average_accuray: 0.375\n",
|
4470 |
+
"batch_average_accuray: 0.6875\n",
|
4471 |
+
"batch_average_accuray: 0.6875\n",
|
4472 |
+
"batch_average_accuray: 0.5625\n",
|
4473 |
+
"batch_average_accuray: 0.5625\n",
|
4474 |
+
"batch_average_accuray: 0.5625\n",
|
4475 |
+
"batch_average_accuray: 0.5\n",
|
4476 |
+
"batch_average_accuray: 0.5625\n",
|
4477 |
+
"batch_average_accuray: 0.5625\n",
|
4478 |
+
"batch_average_accuray: 0.5\n",
|
4479 |
+
"batch_average_accuray: 0.5625\n",
|
4480 |
+
"batch_average_accuray: 0.4375\n",
|
4481 |
+
"batch_average_accuray: 0.375\n",
|
4482 |
+
"batch_average_accuray: 0.875\n",
|
4483 |
+
"batch_average_accuray: 0.5\n",
|
4484 |
+
"batch_average_accuray: 0.4375\n",
|
4485 |
+
"batch_average_accuray: 0.5\n",
|
4486 |
+
"batch_average_accuray: 0.625\n",
|
4487 |
+
"batch_average_accuray: 0.5\n",
|
4488 |
+
"batch_average_accuray: 0.4375\n",
|
4489 |
+
"batch_average_accuray: 0.6875\n",
|
4490 |
+
"batch_average_accuray: 0.625\n",
|
4491 |
+
"batch_average_accuray: 0.4375\n",
|
4492 |
+
"batch_average_accuray: 0.4375\n",
|
4493 |
+
"batch_average_accuray: 0.4375\n",
|
4494 |
+
"batch_average_accuray: 0.625\n",
|
4495 |
+
"batch_average_accuray: 0.4375\n",
|
4496 |
+
"batch_average_accuray: 0.6875\n",
|
4497 |
+
"batch_average_accuray: 0.625\n",
|
4498 |
+
"batch_average_accuray: 0.5625\n",
|
4499 |
+
"batch_average_accuray: 0.5\n",
|
4500 |
+
"batch_average_accuray: 0.4375\n",
|
4501 |
+
"batch_average_accuray: 0.375\n",
|
4502 |
+
"batch_average_accuray: 0.75\n",
|
4503 |
+
"batch_average_accuray: 0.625\n",
|
4504 |
+
"batch_average_accuray: 0.75\n",
|
4505 |
+
"batch_average_accuray: 0.4375\n",
|
4506 |
+
"batch_average_accuray: 0.4375\n",
|
4507 |
+
"batch_average_accuray: 0.3125\n",
|
4508 |
+
"batch_average_accuray: 0.5\n",
|
4509 |
+
"batch_average_accuray: 0.375\n",
|
4510 |
+
"batch_average_accuray: 0.5\n",
|
4511 |
+
"batch_average_accuray: 0.8125\n",
|
4512 |
+
"batch_average_accuray: 0.4375\n",
|
4513 |
+
"batch_average_accuray: 0.8125\n",
|
4514 |
+
"batch_average_accuray: 0.4375\n",
|
4515 |
+
"batch_average_accuray: 0.75\n",
|
4516 |
+
"batch_average_accuray: 0.625\n",
|
4517 |
+
"batch_average_accuray: 0.6875\n",
|
4518 |
+
"batch_average_accuray: 0.75\n",
|
4519 |
+
"batch_average_accuray: 0.5625\n",
|
4520 |
+
"batch_average_accuray: 0.5625\n",
|
4521 |
+
"batch_average_accuray: 0.6875\n",
|
4522 |
+
"batch_average_accuray: 0.4375\n",
|
4523 |
+
"batch_average_accuray: 0.375\n",
|
4524 |
+
"batch_average_accuray: 0.5\n",
|
4525 |
+
"batch_average_accuray: 0.75\n",
|
4526 |
+
"batch_average_accuray: 0.5\n",
|
4527 |
+
"batch_average_accuray: 0.625\n",
|
4528 |
+
"batch_average_accuray: 0.5\n",
|
4529 |
+
"batch_average_accuray: 0.5625\n",
|
4530 |
+
"batch_average_accuray: 0.25\n",
|
4531 |
+
"batch_average_accuray: 0.6875\n",
|
4532 |
+
"batch_average_accuray: 0.5625\n",
|
4533 |
+
"batch_average_accuray: 0.5\n",
|
4534 |
+
"batch_average_accuray: 0.5\n",
|
4535 |
+
"batch_average_accuray: 0.4375\n",
|
4536 |
+
"batch_average_accuray: 0.375\n",
|
4537 |
+
"batch_average_accuray: 0.625\n",
|
4538 |
+
"batch_average_accuray: 0.6875\n",
|
4539 |
+
"batch_average_accuray: 0.5625\n",
|
4540 |
+
"batch_average_accuray: 0.5\n",
|
4541 |
+
"batch_average_accuray: 0.5\n",
|
4542 |
+
"batch_average_accuray: 0.6875\n",
|
4543 |
+
"batch_average_accuray: 0.5\n",
|
4544 |
+
"batch_average_accuray: 0.5\n",
|
4545 |
+
"batch_average_accuray: 0.5625\n",
|
4546 |
+
"batch_average_accuray: 0.5\n",
|
4547 |
+
"batch_average_accuray: 0.5\n",
|
4548 |
+
"batch_average_accuray: 0.75\n",
|
4549 |
+
"batch_average_accuray: 0.625\n",
|
4550 |
+
"batch_average_accuray: 0.4375\n",
|
4551 |
+
"batch_average_accuray: 0.5625\n",
|
4552 |
+
"batch_average_accuray: 0.625\n",
|
4553 |
+
"batch_average_accuray: 0.625\n",
|
4554 |
+
"batch_average_accuray: 0.4375\n",
|
4555 |
+
"batch_average_accuray: 0.5\n",
|
4556 |
+
"batch_average_accuray: 0.25\n",
|
4557 |
+
"batch_average_accuray: 0.5\n",
|
4558 |
+
"batch_average_accuray: 0.4375\n",
|
4559 |
+
"batch_average_accuray: 0.8125\n",
|
4560 |
+
"batch_average_accuray: 0.75\n",
|
4561 |
+
"batch_average_accuray: 0.6875\n",
|
4562 |
+
"batch_average_accuray: 0.625\n",
|
4563 |
+
"batch_average_accuray: 0.5625\n",
|
4564 |
+
"batch_average_accuray: 0.6875\n",
|
4565 |
+
"batch_average_accuray: 0.625\n",
|
4566 |
+
"batch_average_accuray: 0.5625\n",
|
4567 |
+
"batch_average_accuray: 0.625\n",
|
4568 |
+
"batch_average_accuray: 0.4375\n",
|
4569 |
+
"batch_average_accuray: 0.6875\n",
|
4570 |
+
"batch_average_accuray: 0.3125\n",
|
4571 |
+
"batch_average_accuray: 0.75\n",
|
4572 |
+
"batch_average_accuray: 0.4375\n",
|
4573 |
+
"batch_average_accuray: 0.5625\n",
|
4574 |
+
"batch_average_accuray: 0.5\n",
|
4575 |
+
"batch_average_accuray: 0.6875\n",
|
4576 |
+
"batch_average_accuray: 0.5625\n",
|
4577 |
+
"batch_average_accuray: 0.4375\n",
|
4578 |
+
"batch_average_accuray: 0.75\n",
|
4579 |
+
"batch_average_accuray: 0.5625\n",
|
4580 |
+
"batch_average_accuray: 0.4375\n",
|
4581 |
+
"batch_average_accuray: 0.625\n",
|
4582 |
+
"batch_average_accuray: 0.5625\n",
|
4583 |
+
"batch_average_accuray: 0.5\n",
|
4584 |
+
"batch_average_accuray: 0.4375\n",
|
4585 |
+
"batch_average_accuray: 0.625\n",
|
4586 |
+
"batch_average_accuray: 0.8125\n",
|
4587 |
+
"batch_average_accuray: 0.8125\n",
|
4588 |
+
"batch_average_accuray: 0.5625\n",
|
4589 |
+
"batch_average_accuray: 0.5625\n",
|
4590 |
+
"batch_average_accuray: 0.5625\n",
|
4591 |
+
"batch_average_accuray: 0.6875\n",
|
4592 |
+
"batch_average_accuray: 0.375\n",
|
4593 |
+
"batch_average_accuray: 0.5625\n",
|
4594 |
+
"batch_average_accuray: 0.5625\n",
|
4595 |
+
"batch_average_accuray: 0.375\n",
|
4596 |
+
"batch_average_accuray: 0.625\n",
|
4597 |
+
"batch_average_accuray: 0.4375\n",
|
4598 |
+
"batch_average_accuray: 0.375\n",
|
4599 |
+
"batch_average_accuray: 0.5625\n",
|
4600 |
+
"batch_average_accuray: 0.6875\n",
|
4601 |
+
"batch_average_accuray: 0.625\n",
|
4602 |
+
"batch_average_accuray: 0.375\n",
|
4603 |
+
"batch_average_accuray: 0.625\n",
|
4604 |
+
"batch_average_accuray: 0.5625\n",
|
4605 |
+
"batch_average_accuray: 0.5\n",
|
4606 |
+
"batch_average_accuray: 0.625\n",
|
4607 |
+
"batch_average_accuray: 0.4375\n",
|
4608 |
+
"batch_average_accuray: 0.5\n",
|
4609 |
+
"batch_average_accuray: 0.5625\n",
|
4610 |
+
"batch_average_accuray: 0.5\n",
|
4611 |
+
"batch_average_accuray: 0.4375\n",
|
4612 |
+
"batch_average_accuray: 0.4375\n",
|
4613 |
+
"batch_average_accuray: 0.3125\n",
|
4614 |
+
"batch_average_accuray: 0.75\n",
|
4615 |
+
"batch_average_accuray: 0.75\n",
|
4616 |
+
"batch_average_accuray: 0.625\n",
|
4617 |
+
"batch_average_accuray: 0.5\n",
|
4618 |
+
"batch_average_accuray: 0.25\n",
|
4619 |
+
"batch_average_accuray: 0.5625\n",
|
4620 |
+
"batch_average_accuray: 0.75\n",
|
4621 |
+
"batch_average_accuray: 0.625\n",
|
4622 |
+
"batch_average_accuray: 0.375\n",
|
4623 |
+
"batch_average_accuray: 0.625\n",
|
4624 |
+
"batch_average_accuray: 0.625\n",
|
4625 |
+
"batch_average_accuray: 0.5625\n",
|
4626 |
+
"batch_average_accuray: 0.625\n",
|
4627 |
+
"batch_average_accuray: 0.625\n",
|
4628 |
+
"batch_average_accuray: 0.4375\n",
|
4629 |
+
"batch_average_accuray: 0.5\n",
|
4630 |
+
"batch_average_accuray: 0.75\n",
|
4631 |
+
"batch_average_accuray: 0.4375\n",
|
4632 |
+
"batch_average_accuray: 0.625\n",
|
4633 |
+
"batch_average_accuray: 0.375\n",
|
4634 |
+
"batch_average_accuray: 0.625\n",
|
4635 |
+
"batch_average_accuray: 0.625\n",
|
4636 |
+
"batch_average_accuray: 0.4375\n",
|
4637 |
+
"batch_average_accuray: 0.5625\n",
|
4638 |
+
"batch_average_accuray: 0.3125\n",
|
4639 |
+
"batch_average_accuray: 0.5625\n",
|
4640 |
+
"batch_average_accuray: 0.75\n",
|
4641 |
+
"batch_average_accuray: 0.6875\n",
|
4642 |
+
"batch_average_accuray: 0.375\n",
|
4643 |
+
"batch_average_accuray: 0.5625\n",
|
4644 |
+
"batch_average_accuray: 0.6875\n",
|
4645 |
+
"batch_average_accuray: 0.625\n",
|
4646 |
+
"batch_average_accuray: 0.625\n",
|
4647 |
+
"batch_average_accuray: 0.5625\n",
|
4648 |
+
"batch_average_accuray: 0.375\n",
|
4649 |
+
"batch_average_accuray: 0.5\n",
|
4650 |
+
"batch_average_accuray: 0.5\n",
|
4651 |
+
"batch_average_accuray: 0.5625\n",
|
4652 |
+
"batch_average_accuray: 0.5625\n",
|
4653 |
+
"batch_average_accuray: 0.5625\n",
|
4654 |
+
"batch_average_accuray: 0.4375\n",
|
4655 |
+
"batch_average_accuray: 0.5625\n",
|
4656 |
+
"batch_average_accuray: 0.5\n",
|
4657 |
+
"batch_average_accuray: 0.6875\n",
|
4658 |
+
"batch_average_accuray: 0.375\n",
|
4659 |
+
"batch_average_accuray: 0.4375\n",
|
4660 |
+
"batch_average_accuray: 0.5625\n",
|
4661 |
+
"batch_average_accuray: 0.4375\n",
|
4662 |
+
"batch_average_accuray: 0.6875\n",
|
4663 |
+
"batch_average_accuray: 0.5\n",
|
4664 |
+
"batch_average_accuray: 0.5625\n",
|
4665 |
+
"batch_average_accuray: 0.875\n",
|
4666 |
+
"batch_average_accuray: 0.75\n",
|
4667 |
+
"batch_average_accuray: 0.25\n",
|
4668 |
+
"batch_average_accuray: 0.5\n",
|
4669 |
+
"batch_average_accuray: 0.625\n",
|
4670 |
+
"batch_average_accuray: 0.375\n",
|
4671 |
+
"batch_average_accuray: 0.5625\n",
|
4672 |
+
"batch_average_accuray: 0.5625\n",
|
4673 |
+
"batch_average_accuray: 0.5625\n",
|
4674 |
+
"batch_average_accuray: 0.4375\n",
|
4675 |
+
"batch_average_accuray: 0.5625\n",
|
4676 |
+
"batch_average_accuray: 0.625\n",
|
4677 |
+
"batch_average_accuray: 0.4375\n",
|
4678 |
+
"batch_average_accuray: 0.5625\n",
|
4679 |
+
"batch_average_accuray: 0.375\n",
|
4680 |
+
"batch_average_accuray: 0.625\n",
|
4681 |
+
"batch_average_accuray: 0.4375\n",
|
4682 |
+
"batch_average_accuray: 0.625\n",
|
4683 |
+
"batch_average_accuray: 0.6875\n",
|
4684 |
+
"batch_average_accuray: 0.375\n",
|
4685 |
+
"batch_average_accuray: 0.6875\n",
|
4686 |
+
"batch_average_accuray: 0.5625\n",
|
4687 |
+
"batch_average_accuray: 0.6875\n",
|
4688 |
+
"batch_average_accuray: 0.6875\n",
|
4689 |
+
"batch_average_accuray: 0.4375\n",
|
4690 |
+
"batch_average_accuray: 0.5\n",
|
4691 |
+
"batch_average_accuray: 0.625\n",
|
4692 |
+
"batch_average_accuray: 0.5625\n",
|
4693 |
+
"batch_average_accuray: 0.5625\n",
|
4694 |
+
"batch_average_accuray: 0.5625\n",
|
4695 |
+
"batch_average_accuray: 0.125\n"
|
4696 |
+
]
|
4697 |
+
}
|
4698 |
+
]
|
4699 |
+
},
|
4700 |
+
{
|
4701 |
+
"cell_type": "code",
|
4702 |
+
"source": [
|
4703 |
+
"print(f\"average accuracy: {np.mean(accuracy)}\")"
|
4704 |
+
],
|
4705 |
+
"metadata": {
|
4706 |
+
"colab": {
|
4707 |
+
"base_uri": "https://localhost:8080/"
|
4708 |
+
},
|
4709 |
+
"id": "-Ow1N7MnEc98",
|
4710 |
+
"outputId": "01fddc67-f273-4659-ecfa-fd89e6c78935"
|
4711 |
+
},
|
4712 |
+
"execution_count": 93,
|
4713 |
+
"outputs": [
|
4714 |
+
{
|
4715 |
+
"output_type": "stream",
|
4716 |
+
"name": "stdout",
|
4717 |
+
"text": [
|
4718 |
+
"average accuracy: 0.5421792618629174\n"
|
4719 |
]
|
4720 |
}
|
4721 |
]
|
|
|
4739 |
"metadata": {
|
4740 |
"id": "KefqatP-YDSC"
|
4741 |
},
|
4742 |
+
"execution_count": 94,
|
4743 |
"outputs": []
|
4744 |
},
|
4745 |
{
|
|
|
4751 |
"metadata": {
|
4752 |
"id": "Km8eScKJl4VP"
|
4753 |
},
|
4754 |
+
"execution_count": 95,
|
4755 |
"outputs": []
|
4756 |
},
|
4757 |
+
{
|
4758 |
+
"cell_type": "markdown",
|
4759 |
+
"source": [
|
4760 |
+
"## Testing the saved model"
|
4761 |
+
],
|
4762 |
+
"metadata": {
|
4763 |
+
"id": "dCZQwr_ZE-cB"
|
4764 |
+
}
|
4765 |
+
},
|
4766 |
+
{
|
4767 |
+
"cell_type": "code",
|
4768 |
+
"source": [
|
4769 |
+
"with torch.no_grad():\n",
|
4770 |
+
" outputs = model_saved(batch['input_ids']).logits\n",
|
4771 |
+
" print(outputs)\n",
|
4772 |
+
" predictions = F.softmax(outputs, dim = 1)\n",
|
4773 |
+
" print(predictions)\n",
|
4774 |
+
" labels = torch.argmax(predictions, dim = 1)\n",
|
4775 |
+
" print(labels)\n",
|
4776 |
+
" print(\"--------\")\n",
|
4777 |
+
" print(batch['decision'])\n",
|
4778 |
+
" print(\"--------\")\n",
|
4779 |
+
" res = labels == batch['decision']\n",
|
4780 |
+
" print(res)\n",
|
4781 |
+
" print(res.sum() / batch_size)"
|
4782 |
+
],
|
4783 |
+
"metadata": {
|
4784 |
+
"colab": {
|
4785 |
+
"base_uri": "https://localhost:8080/"
|
4786 |
+
},
|
4787 |
+
"id": "u_iN3BSHFB27",
|
4788 |
+
"outputId": "d73153a7-f156-413c-9e3c-6f2930e8905d"
|
4789 |
+
},
|
4790 |
+
"execution_count": 96,
|
4791 |
+
"outputs": [
|
4792 |
+
{
|
4793 |
+
"output_type": "stream",
|
4794 |
+
"name": "stdout",
|
4795 |
+
"text": [
|
4796 |
+
"tensor([[-0.2934, 0.9680, 4.0130, -8.2634, -8.1291, -8.6447],\n",
|
4797 |
+
" [ 0.5176, 3.2941, 1.8334, -8.3832, -8.6352, -8.5553],\n",
|
4798 |
+
" [-0.4728, 0.9731, 4.1658, -8.1353, -7.9516, -8.5336],\n",
|
4799 |
+
" [-0.4363, 1.1413, 4.1972, -8.3214, -8.2106, -8.7486],\n",
|
4800 |
+
" [-0.3831, 1.4167, 4.0593, -8.5625, -8.5613, -9.0239],\n",
|
4801 |
+
" [ 0.3174, 3.2739, 2.2290, -8.6113, -8.8512, -8.8537]])\n",
|
4802 |
+
"tensor([[1.2706e-02, 4.4856e-02, 9.4243e-01, 4.3923e-06, 5.0237e-06, 2.9996e-06],\n",
|
4803 |
+
" [4.8101e-02, 7.7258e-01, 1.7930e-01, 6.5550e-06, 5.0946e-06, 5.5186e-06],\n",
|
4804 |
+
" [9.2039e-03, 3.9077e-02, 9.5171e-01, 4.3269e-06, 5.1996e-06, 2.9054e-06],\n",
|
4805 |
+
" [9.1980e-03, 4.4548e-02, 9.4624e-01, 3.4612e-06, 3.8667e-06, 2.2579e-06],\n",
|
4806 |
+
" [1.0866e-02, 6.5728e-02, 9.2340e-01, 3.0465e-06, 3.0504e-06, 1.9206e-06],\n",
|
4807 |
+
" [3.7043e-02, 7.1237e-01, 2.5057e-01, 4.9094e-06, 3.8624e-06, 3.8528e-06]])\n",
|
4808 |
+
"tensor([2, 1, 2, 2, 2, 1])\n",
|
4809 |
+
"--------\n",
|
4810 |
+
"tensor([2, 2, 0, 1, 1, 1])\n",
|
4811 |
+
"--------\n",
|
4812 |
+
"tensor([ True, False, False, False, False, True])\n",
|
4813 |
+
"tensor(0.1250)\n"
|
4814 |
+
]
|
4815 |
+
}
|
4816 |
+
]
|
4817 |
+
},
|
4818 |
{
|
4819 |
"cell_type": "markdown",
|
4820 |
"source": [
|