Spaces:
Runtime error
Runtime error
Commit
·
514c12c
1
Parent(s):
b6612ec
add training script
Browse files- train.ipynb +1463 -0
train.ipynb
ADDED
@@ -0,0 +1,1463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "429b26f3-8c61-46cc-b5fc-284add4d018f",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import json\n",
|
11 |
+
"from tqdm.auto import tqdm\n",
|
12 |
+
"from datasets import load_dataset\n",
|
13 |
+
"import pandas as pd\n",
|
14 |
+
"import numpy as np\n",
|
15 |
+
"import torch\n",
|
16 |
+
"import os\n",
|
17 |
+
"\n",
|
18 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
+
"id": "2a927511-78a0-42d5-861d-9e7af50ff000",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"import requests\n",
|
29 |
+
"from bs4 import BeautifulSoup\n",
|
30 |
+
"\n",
|
31 |
+
"page = requests.get('https://arxiv.org/category_taxonomy')\n",
|
32 |
+
"soup = BeautifulSoup(page.content)\n",
|
33 |
+
"tag_to_name = {}\n",
|
34 |
+
"for tag_html in soup.find_all('h4')[1:]:\n",
|
35 |
+
" tag, name = tag_html.text.split(maxsplit=1)\n",
|
36 |
+
" tag_to_name[tag] = name[1:-1]\n",
|
37 |
+
"with open('tag_to_name.json', 'w') as fout:\n",
|
38 |
+
" json.dump(tag_to_name, fout)"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": 3,
|
44 |
+
"id": "19b75e52-15c0-472e-b737-72c5eea896ec",
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"tag_to_label = dict(zip(tag_to_name, range(len(tag_to_name))))"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 4,
|
54 |
+
"id": "fec2865f-2992-4b3e-9202-8e9b8c5a7da1",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"def add_labels(row):\n",
|
59 |
+
" tag_list = eval(row['tag'])\n",
|
60 |
+
" label_ids, label_tags = [], []\n",
|
61 |
+
" for tag_dict in tag_list:\n",
|
62 |
+
" if tag_dict['term'] in tag_to_label:\n",
|
63 |
+
" label_tags.append(tag_dict['term'])\n",
|
64 |
+
" label_ids.append(tag_to_label[tag_dict['term']])\n",
|
65 |
+
" return {'label_ids': label_ids, 'label_tags': label_tags}"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"execution_count": 5,
|
71 |
+
"id": "81dff335-093f-4a59-93b5-27d7c57aac9a",
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [
|
74 |
+
{
|
75 |
+
"name": "stderr",
|
76 |
+
"output_type": "stream",
|
77 |
+
"text": [
|
78 |
+
"Using custom data configuration default-60d1f0f90275ae1e\n",
|
79 |
+
"Found cached dataset json (/root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "stdout",
|
84 |
+
"output_type": "stream",
|
85 |
+
"text": [
|
86 |
+
" "
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"name": "stderr",
|
91 |
+
"output_type": "stream",
|
92 |
+
"text": [
|
93 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-66945521f8e38136.arrow\n"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"name": "stdout",
|
98 |
+
"output_type": "stream",
|
99 |
+
"text": [
|
100 |
+
" "
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"name": "stderr",
|
105 |
+
"output_type": "stream",
|
106 |
+
"text": [
|
107 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-5298549794823409.arrow\n"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"name": "stdout",
|
112 |
+
"output_type": "stream",
|
113 |
+
"text": [
|
114 |
+
" "
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"name": "stderr",
|
119 |
+
"output_type": "stream",
|
120 |
+
"text": [
|
121 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-6c93a706327f5678.arrow\n"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"name": "stdout",
|
126 |
+
"output_type": "stream",
|
127 |
+
"text": [
|
128 |
+
" "
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"name": "stderr",
|
133 |
+
"output_type": "stream",
|
134 |
+
"text": [
|
135 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-ff58b61d0d461ac4.arrow\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"name": "stdout",
|
140 |
+
"output_type": "stream",
|
141 |
+
"text": [
|
142 |
+
" "
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"name": "stderr",
|
147 |
+
"output_type": "stream",
|
148 |
+
"text": [
|
149 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-259b966b550351dc.arrow\n"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"name": "stdout",
|
154 |
+
"output_type": "stream",
|
155 |
+
"text": [
|
156 |
+
" "
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"name": "stderr",
|
161 |
+
"output_type": "stream",
|
162 |
+
"text": [
|
163 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8f0ed2baf297a3db.arrow\n"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"name": "stdout",
|
168 |
+
"output_type": "stream",
|
169 |
+
"text": [
|
170 |
+
" "
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"name": "stderr",
|
175 |
+
"output_type": "stream",
|
176 |
+
"text": [
|
177 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-845944d2885d6a34.arrow\n"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"name": "stdout",
|
182 |
+
"output_type": "stream",
|
183 |
+
"text": [
|
184 |
+
" "
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"name": "stderr",
|
189 |
+
"output_type": "stream",
|
190 |
+
"text": [
|
191 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8ec43ba6cf3d3eba.arrow\n"
|
192 |
+
]
|
193 |
+
}
|
194 |
+
],
|
195 |
+
"source": [
|
196 |
+
"dataset = load_dataset(\"json\", data_files=\"arxivData.json\", split=\"train\")\n",
|
197 |
+
"dataset = dataset.map(add_labels, num_proc=8)\n",
|
198 |
+
"dataset = dataset.remove_columns(['author', 'day', 'id', 'link', 'month', 'tag', 'year'])"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 6,
|
204 |
+
"id": "c9a6ab6a-6a47-4377-a9d9-044c3a395ef3",
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [
|
207 |
+
{
|
208 |
+
"data": {
|
209 |
+
"text/html": [
|
210 |
+
"<div>\n",
|
211 |
+
"<style scoped>\n",
|
212 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
213 |
+
" vertical-align: middle;\n",
|
214 |
+
" }\n",
|
215 |
+
"\n",
|
216 |
+
" .dataframe tbody tr th {\n",
|
217 |
+
" vertical-align: top;\n",
|
218 |
+
" }\n",
|
219 |
+
"\n",
|
220 |
+
" .dataframe thead th {\n",
|
221 |
+
" text-align: right;\n",
|
222 |
+
" }\n",
|
223 |
+
"</style>\n",
|
224 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
225 |
+
" <thead>\n",
|
226 |
+
" <tr style=\"text-align: right;\">\n",
|
227 |
+
" <th></th>\n",
|
228 |
+
" <th>summary</th>\n",
|
229 |
+
" <th>title</th>\n",
|
230 |
+
" <th>label_ids</th>\n",
|
231 |
+
" <th>label_tags</th>\n",
|
232 |
+
" </tr>\n",
|
233 |
+
" </thead>\n",
|
234 |
+
" <tbody>\n",
|
235 |
+
" <tr>\n",
|
236 |
+
" <th>0</th>\n",
|
237 |
+
" <td>We propose an architecture for VQA which utili...</td>\n",
|
238 |
+
" <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
|
239 |
+
" <td>[0, 5, 7, 28, 152]</td>\n",
|
240 |
+
" <td>[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]</td>\n",
|
241 |
+
" </tr>\n",
|
242 |
+
" <tr>\n",
|
243 |
+
" <th>1</th>\n",
|
244 |
+
" <td>In a physical neural system, where storage and...</td>\n",
|
245 |
+
" <td>A Theory of Local Learning, the Learning Chann...</td>\n",
|
246 |
+
" <td>[22, 28, 152]</td>\n",
|
247 |
+
" <td>[cs.LG, cs.NE, stat.ML]</td>\n",
|
248 |
+
" </tr>\n",
|
249 |
+
" <tr>\n",
|
250 |
+
" <th>2</th>\n",
|
251 |
+
" <td>One way to approach end-to-end autonomous driv...</td>\n",
|
252 |
+
" <td>Query-Efficient Imitation Learning for End-to-...</td>\n",
|
253 |
+
" <td>[22, 0, 34]</td>\n",
|
254 |
+
" <td>[cs.LG, cs.AI, cs.RO]</td>\n",
|
255 |
+
" </tr>\n",
|
256 |
+
" </tbody>\n",
|
257 |
+
"</table>\n",
|
258 |
+
"</div>"
|
259 |
+
],
|
260 |
+
"text/plain": [
|
261 |
+
" summary \\\n",
|
262 |
+
"0 We propose an architecture for VQA which utili... \n",
|
263 |
+
"1 In a physical neural system, where storage and... \n",
|
264 |
+
"2 One way to approach end-to-end autonomous driv... \n",
|
265 |
+
"\n",
|
266 |
+
" title label_ids \\\n",
|
267 |
+
"0 Dual Recurrent Attention Units for Visual Ques... [0, 5, 7, 28, 152] \n",
|
268 |
+
"1 A Theory of Local Learning, the Learning Chann... [22, 28, 152] \n",
|
269 |
+
"2 Query-Efficient Imitation Learning for End-to-... [22, 0, 34] \n",
|
270 |
+
"\n",
|
271 |
+
" label_tags \n",
|
272 |
+
"0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] \n",
|
273 |
+
"1 [cs.LG, cs.NE, stat.ML] \n",
|
274 |
+
"2 [cs.LG, cs.AI, cs.RO] "
|
275 |
+
]
|
276 |
+
},
|
277 |
+
"execution_count": 6,
|
278 |
+
"metadata": {},
|
279 |
+
"output_type": "execute_result"
|
280 |
+
}
|
281 |
+
],
|
282 |
+
"source": [
|
283 |
+
"pd.DataFrame(dataset.select([0, 1000, 10000]))"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": 7,
|
289 |
+
"id": "c193d04b-5def-443f-b723-1e3cf9df4d9e",
|
290 |
+
"metadata": {},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stderr",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"Loading cached split indices for dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-7ce5346705e1f437.arrow and /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-981e0a6e9da25ee7.arrow\n",
|
297 |
+
"Loading cached split indices for dataset at /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-1ab388509804381c.arrow and /root/.cache/huggingface/datasets/json/default-60d1f0f90275ae1e/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-eac731b57f161563.arrow\n"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"data": {
|
302 |
+
"text/plain": [
|
303 |
+
"DatasetDict({\n",
|
304 |
+
" train: Dataset({\n",
|
305 |
+
" features: ['summary', 'title', 'label_ids', 'label_tags'],\n",
|
306 |
+
" num_rows: 38952\n",
|
307 |
+
" })\n",
|
308 |
+
" val: Dataset({\n",
|
309 |
+
" features: ['summary', 'title', 'label_ids', 'label_tags'],\n",
|
310 |
+
" num_rows: 1024\n",
|
311 |
+
" })\n",
|
312 |
+
" test: Dataset({\n",
|
313 |
+
" features: ['summary', 'title', 'label_ids', 'label_tags'],\n",
|
314 |
+
" num_rows: 1024\n",
|
315 |
+
" })\n",
|
316 |
+
"})"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
"execution_count": 7,
|
320 |
+
"metadata": {},
|
321 |
+
"output_type": "execute_result"
|
322 |
+
}
|
323 |
+
],
|
324 |
+
"source": [
|
325 |
+
"from datasets import DatasetDict\n",
|
326 |
+
"\n",
|
327 |
+
"dataset = dataset.train_test_split(test_size=2048, seed=0)\n",
|
328 |
+
"dataset_val = dataset['test'].train_test_split(test_size=1024, seed=0)\n",
|
329 |
+
"\n",
|
330 |
+
"dataset = DatasetDict({\n",
|
331 |
+
" 'train': dataset['train'],\n",
|
332 |
+
" 'val': dataset_val['train'],\n",
|
333 |
+
" 'test': dataset_val['test'],\n",
|
334 |
+
"})\n",
|
335 |
+
"\n",
|
336 |
+
"dataset"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": 17,
|
342 |
+
"id": "2544c24b-d2ed-4fba-bb86-75469053db8c",
|
343 |
+
"metadata": {},
|
344 |
+
"outputs": [],
|
345 |
+
"source": [
|
346 |
+
"def get_collator(tokenizer, abstract_proba=0.5, num_labels=len(tag_to_label)):\n",
|
347 |
+
" def collate_fn(rows):\n",
|
348 |
+
" texts = []\n",
|
349 |
+
" take_abstracts = np.random.rand(len(rows)) < abstract_proba\n",
|
350 |
+
" for row, take_abstract in zip(rows, take_abstracts):\n",
|
351 |
+
" if take_abstract:\n",
|
352 |
+
" texts.append(row['title'] + '[SEP]' + row['summary'])\n",
|
353 |
+
" else:\n",
|
354 |
+
" texts.append(row['title'])\n",
|
355 |
+
" processed = tokenizer(texts, truncation=True, return_tensors='pt', padding=True, max_length=512)\n",
|
356 |
+
" labels = torch.zeros(size=(len(rows), num_labels), dtype=torch.float)\n",
|
357 |
+
" for i, row in enumerate(rows):\n",
|
358 |
+
" labels[i, row['label_ids']] = 1 / len(row['label_ids'])\n",
|
359 |
+
" processed['labels'] = labels\n",
|
360 |
+
" return processed\n",
|
361 |
+
" return collate_fn"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "code",
|
366 |
+
"execution_count": 9,
|
367 |
+
"id": "33934717-57ca-49e8-8354-3eafe503bcf0",
|
368 |
+
"metadata": {
|
369 |
+
"tags": []
|
370 |
+
},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"name": "stderr",
|
374 |
+
"output_type": "stream",
|
375 |
+
"text": [
|
376 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
377 |
+
"/usr/local/lib/python3.8/dist-packages/transformers/convert_slow_tokenizer.py:446: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
|
378 |
+
" warnings.warn(\n",
|
379 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
380 |
+
"Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2ForSequenceClassification: ['lm_predictions.lm_head.LayerNorm.bias', 'lm_predictions.lm_head.bias', 'mask_predictions.classifier.weight', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.bias', 'mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.bias', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.dense.bias', 'lm_predictions.lm_head.dense.weight']\n",
|
381 |
+
"- This IS expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
382 |
+
"- This IS NOT expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
383 |
+
"Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.weight', 'pooler.dense.bias']\n",
|
384 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
385 |
+
]
|
386 |
+
}
|
387 |
+
],
|
388 |
+
"source": [
|
389 |
+
"from transformers import AutoTokenizer\n",
|
390 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
391 |
+
"\n",
|
392 |
+
"tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')\n",
|
393 |
+
"\n",
|
394 |
+
"model = AutoModelForSequenceClassification.from_pretrained(\n",
|
395 |
+
" 'microsoft/deberta-v3-base',\n",
|
396 |
+
" problem_type=None, # https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L1349\n",
|
397 |
+
" num_labels=len(tag_to_label), id2label={v: k for k, v in tag_to_label.items()}, label2id=tag_to_label)\n"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 10,
|
403 |
+
"id": "588d769e-d44b-4367-a22b-7b9b87cb5319",
|
404 |
+
"metadata": {},
|
405 |
+
"outputs": [
|
406 |
+
{
|
407 |
+
"name": "stderr",
|
408 |
+
"output_type": "stream",
|
409 |
+
"text": [
|
410 |
+
"max_steps is given, it will override any value given in num_train_epochs\n"
|
411 |
+
]
|
412 |
+
}
|
413 |
+
],
|
414 |
+
"source": [
|
415 |
+
"from transformers import TrainingArguments, Trainer\n",
|
416 |
+
"\n",
|
417 |
+
"training_args = TrainingArguments(\n",
|
418 |
+
" output_dir='checkpoints',\n",
|
419 |
+
" learning_rate=2e-5,\n",
|
420 |
+
" per_device_train_batch_size=24,\n",
|
421 |
+
" per_device_eval_batch_size=24,\n",
|
422 |
+
" weight_decay=0.01,\n",
|
423 |
+
" warmup_ratio=0.02,\n",
|
424 |
+
" logging_steps=100,\n",
|
425 |
+
" overwrite_output_dir=True,\n",
|
426 |
+
" seed=0,\n",
|
427 |
+
" dataloader_num_workers=8,\n",
|
428 |
+
" do_train=True,\n",
|
429 |
+
" do_eval=True,\n",
|
430 |
+
" max_steps=5000,\n",
|
431 |
+
" save_strategy=\"steps\",\n",
|
432 |
+
" evaluation_strategy=\"steps\",\n",
|
433 |
+
" eval_steps=100,\n",
|
434 |
+
" save_steps=100,\n",
|
435 |
+
" save_total_limit=2,\n",
|
436 |
+
" lr_scheduler_type=\"linear\",\n",
|
437 |
+
" load_best_model_at_end=True,\n",
|
438 |
+
" report_to=\"tensorboard\",\n",
|
439 |
+
" remove_unused_columns=False,\n",
|
440 |
+
")\n",
|
441 |
+
"\n",
|
442 |
+
"trainer = Trainer(\n",
|
443 |
+
" model=model,\n",
|
444 |
+
" args=training_args,\n",
|
445 |
+
" train_dataset=dataset['train'],\n",
|
446 |
+
" eval_dataset=dataset['val'],\n",
|
447 |
+
" tokenizer=tokenizer,\n",
|
448 |
+
" data_collator=get_collator(tokenizer),\n",
|
449 |
+
")"
|
450 |
+
]
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"cell_type": "code",
|
454 |
+
"execution_count": 11,
|
455 |
+
"id": "04d3ccf6-193b-4ee2-ad4d-2cdbb3c7b737",
|
456 |
+
"metadata": {
|
457 |
+
"tags": []
|
458 |
+
},
|
459 |
+
"outputs": [
|
460 |
+
{
|
461 |
+
"name": "stderr",
|
462 |
+
"output_type": "stream",
|
463 |
+
"text": [
|
464 |
+
"/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
465 |
+
" warnings.warn(\n",
|
466 |
+
"***** Running training *****\n",
|
467 |
+
" Num examples = 38952\n",
|
468 |
+
" Num Epochs = 4\n",
|
469 |
+
" Instantaneous batch size per device = 24\n",
|
470 |
+
" Total train batch size (w. parallel, distributed & accumulation) = 24\n",
|
471 |
+
" Gradient Accumulation steps = 1\n",
|
472 |
+
" Total optimization steps = 5000\n",
|
473 |
+
" Number of trainable parameters = 184541339\n"
|
474 |
+
]
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"data": {
|
478 |
+
"text/html": [
|
479 |
+
"\n",
|
480 |
+
" <div>\n",
|
481 |
+
" \n",
|
482 |
+
" <progress value='5000' max='5000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
483 |
+
" [5000/5000 22:52, Epoch 3/4]\n",
|
484 |
+
" </div>\n",
|
485 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
486 |
+
" <thead>\n",
|
487 |
+
" <tr style=\"text-align: left;\">\n",
|
488 |
+
" <th>Step</th>\n",
|
489 |
+
" <th>Training Loss</th>\n",
|
490 |
+
" <th>Validation Loss</th>\n",
|
491 |
+
" </tr>\n",
|
492 |
+
" </thead>\n",
|
493 |
+
" <tbody>\n",
|
494 |
+
" <tr>\n",
|
495 |
+
" <td>100</td>\n",
|
496 |
+
" <td>4.286100</td>\n",
|
497 |
+
" <td>2.809958</td>\n",
|
498 |
+
" </tr>\n",
|
499 |
+
" <tr>\n",
|
500 |
+
" <td>200</td>\n",
|
501 |
+
" <td>2.365700</td>\n",
|
502 |
+
" <td>2.110714</td>\n",
|
503 |
+
" </tr>\n",
|
504 |
+
" <tr>\n",
|
505 |
+
" <td>300</td>\n",
|
506 |
+
" <td>2.023600</td>\n",
|
507 |
+
" <td>2.046348</td>\n",
|
508 |
+
" </tr>\n",
|
509 |
+
" <tr>\n",
|
510 |
+
" <td>400</td>\n",
|
511 |
+
" <td>2.020400</td>\n",
|
512 |
+
" <td>1.982979</td>\n",
|
513 |
+
" </tr>\n",
|
514 |
+
" <tr>\n",
|
515 |
+
" <td>500</td>\n",
|
516 |
+
" <td>1.927300</td>\n",
|
517 |
+
" <td>1.915667</td>\n",
|
518 |
+
" </tr>\n",
|
519 |
+
" <tr>\n",
|
520 |
+
" <td>600</td>\n",
|
521 |
+
" <td>1.919500</td>\n",
|
522 |
+
" <td>1.927610</td>\n",
|
523 |
+
" </tr>\n",
|
524 |
+
" <tr>\n",
|
525 |
+
" <td>700</td>\n",
|
526 |
+
" <td>1.834600</td>\n",
|
527 |
+
" <td>1.929402</td>\n",
|
528 |
+
" </tr>\n",
|
529 |
+
" <tr>\n",
|
530 |
+
" <td>800</td>\n",
|
531 |
+
" <td>1.840800</td>\n",
|
532 |
+
" <td>1.861055</td>\n",
|
533 |
+
" </tr>\n",
|
534 |
+
" <tr>\n",
|
535 |
+
" <td>900</td>\n",
|
536 |
+
" <td>1.823900</td>\n",
|
537 |
+
" <td>1.819358</td>\n",
|
538 |
+
" </tr>\n",
|
539 |
+
" <tr>\n",
|
540 |
+
" <td>1000</td>\n",
|
541 |
+
" <td>1.757100</td>\n",
|
542 |
+
" <td>1.798097</td>\n",
|
543 |
+
" </tr>\n",
|
544 |
+
" <tr>\n",
|
545 |
+
" <td>1100</td>\n",
|
546 |
+
" <td>1.746500</td>\n",
|
547 |
+
" <td>1.779167</td>\n",
|
548 |
+
" </tr>\n",
|
549 |
+
" <tr>\n",
|
550 |
+
" <td>1200</td>\n",
|
551 |
+
" <td>1.775000</td>\n",
|
552 |
+
" <td>1.774340</td>\n",
|
553 |
+
" </tr>\n",
|
554 |
+
" <tr>\n",
|
555 |
+
" <td>1300</td>\n",
|
556 |
+
" <td>1.698500</td>\n",
|
557 |
+
" <td>1.764457</td>\n",
|
558 |
+
" </tr>\n",
|
559 |
+
" <tr>\n",
|
560 |
+
" <td>1400</td>\n",
|
561 |
+
" <td>1.684200</td>\n",
|
562 |
+
" <td>1.741629</td>\n",
|
563 |
+
" </tr>\n",
|
564 |
+
" <tr>\n",
|
565 |
+
" <td>1500</td>\n",
|
566 |
+
" <td>1.763000</td>\n",
|
567 |
+
" <td>1.680664</td>\n",
|
568 |
+
" </tr>\n",
|
569 |
+
" <tr>\n",
|
570 |
+
" <td>1600</td>\n",
|
571 |
+
" <td>1.678400</td>\n",
|
572 |
+
" <td>1.712918</td>\n",
|
573 |
+
" </tr>\n",
|
574 |
+
" <tr>\n",
|
575 |
+
" <td>1700</td>\n",
|
576 |
+
" <td>1.669800</td>\n",
|
577 |
+
" <td>1.710484</td>\n",
|
578 |
+
" </tr>\n",
|
579 |
+
" <tr>\n",
|
580 |
+
" <td>1800</td>\n",
|
581 |
+
" <td>1.665000</td>\n",
|
582 |
+
" <td>1.698851</td>\n",
|
583 |
+
" </tr>\n",
|
584 |
+
" <tr>\n",
|
585 |
+
" <td>1900</td>\n",
|
586 |
+
" <td>1.645200</td>\n",
|
587 |
+
" <td>1.663767</td>\n",
|
588 |
+
" </tr>\n",
|
589 |
+
" <tr>\n",
|
590 |
+
" <td>2000</td>\n",
|
591 |
+
" <td>1.667600</td>\n",
|
592 |
+
" <td>1.674545</td>\n",
|
593 |
+
" </tr>\n",
|
594 |
+
" <tr>\n",
|
595 |
+
" <td>2100</td>\n",
|
596 |
+
" <td>1.602300</td>\n",
|
597 |
+
" <td>1.680639</td>\n",
|
598 |
+
" </tr>\n",
|
599 |
+
" <tr>\n",
|
600 |
+
" <td>2200</td>\n",
|
601 |
+
" <td>1.651800</td>\n",
|
602 |
+
" <td>1.667343</td>\n",
|
603 |
+
" </tr>\n",
|
604 |
+
" <tr>\n",
|
605 |
+
" <td>2300</td>\n",
|
606 |
+
" <td>1.622600</td>\n",
|
607 |
+
" <td>1.659117</td>\n",
|
608 |
+
" </tr>\n",
|
609 |
+
" <tr>\n",
|
610 |
+
" <td>2400</td>\n",
|
611 |
+
" <td>1.616900</td>\n",
|
612 |
+
" <td>1.645381</td>\n",
|
613 |
+
" </tr>\n",
|
614 |
+
" <tr>\n",
|
615 |
+
" <td>2500</td>\n",
|
616 |
+
" <td>1.600900</td>\n",
|
617 |
+
" <td>1.642603</td>\n",
|
618 |
+
" </tr>\n",
|
619 |
+
" <tr>\n",
|
620 |
+
" <td>2600</td>\n",
|
621 |
+
" <td>1.590200</td>\n",
|
622 |
+
" <td>1.657698</td>\n",
|
623 |
+
" </tr>\n",
|
624 |
+
" <tr>\n",
|
625 |
+
" <td>2700</td>\n",
|
626 |
+
" <td>1.646300</td>\n",
|
627 |
+
" <td>1.644075</td>\n",
|
628 |
+
" </tr>\n",
|
629 |
+
" <tr>\n",
|
630 |
+
" <td>2800</td>\n",
|
631 |
+
" <td>1.602600</td>\n",
|
632 |
+
" <td>1.626339</td>\n",
|
633 |
+
" </tr>\n",
|
634 |
+
" <tr>\n",
|
635 |
+
" <td>2900</td>\n",
|
636 |
+
" <td>1.596800</td>\n",
|
637 |
+
" <td>1.646950</td>\n",
|
638 |
+
" </tr>\n",
|
639 |
+
" <tr>\n",
|
640 |
+
" <td>3000</td>\n",
|
641 |
+
" <td>1.547200</td>\n",
|
642 |
+
" <td>1.622913</td>\n",
|
643 |
+
" </tr>\n",
|
644 |
+
" <tr>\n",
|
645 |
+
" <td>3100</td>\n",
|
646 |
+
" <td>1.563500</td>\n",
|
647 |
+
" <td>1.611651</td>\n",
|
648 |
+
" </tr>\n",
|
649 |
+
" <tr>\n",
|
650 |
+
" <td>3200</td>\n",
|
651 |
+
" <td>1.583500</td>\n",
|
652 |
+
" <td>1.608005</td>\n",
|
653 |
+
" </tr>\n",
|
654 |
+
" <tr>\n",
|
655 |
+
" <td>3300</td>\n",
|
656 |
+
" <td>1.565800</td>\n",
|
657 |
+
" <td>1.626086</td>\n",
|
658 |
+
" </tr>\n",
|
659 |
+
" <tr>\n",
|
660 |
+
" <td>3400</td>\n",
|
661 |
+
" <td>1.531000</td>\n",
|
662 |
+
" <td>1.626902</td>\n",
|
663 |
+
" </tr>\n",
|
664 |
+
" <tr>\n",
|
665 |
+
" <td>3500</td>\n",
|
666 |
+
" <td>1.566100</td>\n",
|
667 |
+
" <td>1.607745</td>\n",
|
668 |
+
" </tr>\n",
|
669 |
+
" <tr>\n",
|
670 |
+
" <td>3600</td>\n",
|
671 |
+
" <td>1.555100</td>\n",
|
672 |
+
" <td>1.594658</td>\n",
|
673 |
+
" </tr>\n",
|
674 |
+
" <tr>\n",
|
675 |
+
" <td>3700</td>\n",
|
676 |
+
" <td>1.597600</td>\n",
|
677 |
+
" <td>1.597994</td>\n",
|
678 |
+
" </tr>\n",
|
679 |
+
" <tr>\n",
|
680 |
+
" <td>3800</td>\n",
|
681 |
+
" <td>1.497600</td>\n",
|
682 |
+
" <td>1.590335</td>\n",
|
683 |
+
" </tr>\n",
|
684 |
+
" <tr>\n",
|
685 |
+
" <td>3900</td>\n",
|
686 |
+
" <td>1.522300</td>\n",
|
687 |
+
" <td>1.588875</td>\n",
|
688 |
+
" </tr>\n",
|
689 |
+
" <tr>\n",
|
690 |
+
" <td>4000</td>\n",
|
691 |
+
" <td>1.506600</td>\n",
|
692 |
+
" <td>1.572686</td>\n",
|
693 |
+
" </tr>\n",
|
694 |
+
" <tr>\n",
|
695 |
+
" <td>4100</td>\n",
|
696 |
+
" <td>1.497900</td>\n",
|
697 |
+
" <td>1.602122</td>\n",
|
698 |
+
" </tr>\n",
|
699 |
+
" <tr>\n",
|
700 |
+
" <td>4200</td>\n",
|
701 |
+
" <td>1.534100</td>\n",
|
702 |
+
" <td>1.576102</td>\n",
|
703 |
+
" </tr>\n",
|
704 |
+
" <tr>\n",
|
705 |
+
" <td>4300</td>\n",
|
706 |
+
" <td>1.517400</td>\n",
|
707 |
+
" <td>1.578320</td>\n",
|
708 |
+
" </tr>\n",
|
709 |
+
" <tr>\n",
|
710 |
+
" <td>4400</td>\n",
|
711 |
+
" <td>1.518500</td>\n",
|
712 |
+
" <td>1.588920</td>\n",
|
713 |
+
" </tr>\n",
|
714 |
+
" <tr>\n",
|
715 |
+
" <td>4500</td>\n",
|
716 |
+
" <td>1.510200</td>\n",
|
717 |
+
" <td>1.596100</td>\n",
|
718 |
+
" </tr>\n",
|
719 |
+
" <tr>\n",
|
720 |
+
" <td>4600</td>\n",
|
721 |
+
" <td>1.441100</td>\n",
|
722 |
+
" <td>1.576099</td>\n",
|
723 |
+
" </tr>\n",
|
724 |
+
" <tr>\n",
|
725 |
+
" <td>4700</td>\n",
|
726 |
+
" <td>1.511000</td>\n",
|
727 |
+
" <td>1.575001</td>\n",
|
728 |
+
" </tr>\n",
|
729 |
+
" <tr>\n",
|
730 |
+
" <td>4800</td>\n",
|
731 |
+
" <td>1.487700</td>\n",
|
732 |
+
" <td>1.579319</td>\n",
|
733 |
+
" </tr>\n",
|
734 |
+
" <tr>\n",
|
735 |
+
" <td>4900</td>\n",
|
736 |
+
" <td>1.491300</td>\n",
|
737 |
+
" <td>1.591276</td>\n",
|
738 |
+
" </tr>\n",
|
739 |
+
" <tr>\n",
|
740 |
+
" <td>5000</td>\n",
|
741 |
+
" <td>1.474700</td>\n",
|
742 |
+
" <td>1.572709</td>\n",
|
743 |
+
" </tr>\n",
|
744 |
+
" </tbody>\n",
|
745 |
+
"</table><p>"
|
746 |
+
],
|
747 |
+
"text/plain": [
|
748 |
+
"<IPython.core.display.HTML object>"
|
749 |
+
]
|
750 |
+
},
|
751 |
+
"metadata": {},
|
752 |
+
"output_type": "display_data"
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"name": "stderr",
|
756 |
+
"output_type": "stream",
|
757 |
+
"text": [
|
758 |
+
"***** Running Evaluation *****\n",
|
759 |
+
" Num examples = 1024\n",
|
760 |
+
" Batch size = 24\n",
|
761 |
+
"Saving model checkpoint to checkpoints/checkpoint-100\n",
|
762 |
+
"Configuration saved in checkpoints/checkpoint-100/config.json\n",
|
763 |
+
"Model weights saved in checkpoints/checkpoint-100/pytorch_model.bin\n",
|
764 |
+
"tokenizer config file saved in checkpoints/checkpoint-100/tokenizer_config.json\n",
|
765 |
+
"Special tokens file saved in checkpoints/checkpoint-100/special_tokens_map.json\n",
|
766 |
+
"***** Running Evaluation *****\n",
|
767 |
+
" Num examples = 1024\n",
|
768 |
+
" Batch size = 24\n",
|
769 |
+
"Saving model checkpoint to checkpoints/checkpoint-200\n",
|
770 |
+
"Configuration saved in checkpoints/checkpoint-200/config.json\n",
|
771 |
+
"Model weights saved in checkpoints/checkpoint-200/pytorch_model.bin\n",
|
772 |
+
"tokenizer config file saved in checkpoints/checkpoint-200/tokenizer_config.json\n",
|
773 |
+
"Special tokens file saved in checkpoints/checkpoint-200/special_tokens_map.json\n",
|
774 |
+
"***** Running Evaluation *****\n",
|
775 |
+
" Num examples = 1024\n",
|
776 |
+
" Batch size = 24\n",
|
777 |
+
"Saving model checkpoint to checkpoints/checkpoint-300\n",
|
778 |
+
"Configuration saved in checkpoints/checkpoint-300/config.json\n",
|
779 |
+
"Model weights saved in checkpoints/checkpoint-300/pytorch_model.bin\n",
|
780 |
+
"tokenizer config file saved in checkpoints/checkpoint-300/tokenizer_config.json\n",
|
781 |
+
"Special tokens file saved in checkpoints/checkpoint-300/special_tokens_map.json\n",
|
782 |
+
"Deleting older checkpoint [checkpoints/checkpoint-100] due to args.save_total_limit\n",
|
783 |
+
"***** Running Evaluation *****\n",
|
784 |
+
" Num examples = 1024\n",
|
785 |
+
" Batch size = 24\n",
|
786 |
+
"Saving model checkpoint to checkpoints/checkpoint-400\n",
|
787 |
+
"Configuration saved in checkpoints/checkpoint-400/config.json\n",
|
788 |
+
"Model weights saved in checkpoints/checkpoint-400/pytorch_model.bin\n",
|
789 |
+
"tokenizer config file saved in checkpoints/checkpoint-400/tokenizer_config.json\n",
|
790 |
+
"Special tokens file saved in checkpoints/checkpoint-400/special_tokens_map.json\n",
|
791 |
+
"Deleting older checkpoint [checkpoints/checkpoint-200] due to args.save_total_limit\n",
|
792 |
+
"***** Running Evaluation *****\n",
|
793 |
+
" Num examples = 1024\n",
|
794 |
+
" Batch size = 24\n",
|
795 |
+
"Saving model checkpoint to checkpoints/checkpoint-500\n",
|
796 |
+
"Configuration saved in checkpoints/checkpoint-500/config.json\n",
|
797 |
+
"Model weights saved in checkpoints/checkpoint-500/pytorch_model.bin\n",
|
798 |
+
"tokenizer config file saved in checkpoints/checkpoint-500/tokenizer_config.json\n",
|
799 |
+
"Special tokens file saved in checkpoints/checkpoint-500/special_tokens_map.json\n",
|
800 |
+
"Deleting older checkpoint [checkpoints/checkpoint-300] due to args.save_total_limit\n",
|
801 |
+
"***** Running Evaluation *****\n",
|
802 |
+
" Num examples = 1024\n",
|
803 |
+
" Batch size = 24\n",
|
804 |
+
"Saving model checkpoint to checkpoints/checkpoint-600\n",
|
805 |
+
"Configuration saved in checkpoints/checkpoint-600/config.json\n",
|
806 |
+
"Model weights saved in checkpoints/checkpoint-600/pytorch_model.bin\n",
|
807 |
+
"tokenizer config file saved in checkpoints/checkpoint-600/tokenizer_config.json\n",
|
808 |
+
"Special tokens file saved in checkpoints/checkpoint-600/special_tokens_map.json\n",
|
809 |
+
"Deleting older checkpoint [checkpoints/checkpoint-400] due to args.save_total_limit\n",
|
810 |
+
"***** Running Evaluation *****\n",
|
811 |
+
" Num examples = 1024\n",
|
812 |
+
" Batch size = 24\n",
|
813 |
+
"Saving model checkpoint to checkpoints/checkpoint-700\n",
|
814 |
+
"Configuration saved in checkpoints/checkpoint-700/config.json\n",
|
815 |
+
"Model weights saved in checkpoints/checkpoint-700/pytorch_model.bin\n",
|
816 |
+
"tokenizer config file saved in checkpoints/checkpoint-700/tokenizer_config.json\n",
|
817 |
+
"Special tokens file saved in checkpoints/checkpoint-700/special_tokens_map.json\n",
|
818 |
+
"Deleting older checkpoint [checkpoints/checkpoint-600] due to args.save_total_limit\n",
|
819 |
+
"***** Running Evaluation *****\n",
|
820 |
+
" Num examples = 1024\n",
|
821 |
+
" Batch size = 24\n",
|
822 |
+
"Saving model checkpoint to checkpoints/checkpoint-800\n",
|
823 |
+
"Configuration saved in checkpoints/checkpoint-800/config.json\n",
|
824 |
+
"Model weights saved in checkpoints/checkpoint-800/pytorch_model.bin\n",
|
825 |
+
"tokenizer config file saved in checkpoints/checkpoint-800/tokenizer_config.json\n",
|
826 |
+
"Special tokens file saved in checkpoints/checkpoint-800/special_tokens_map.json\n",
|
827 |
+
"Deleting older checkpoint [checkpoints/checkpoint-500] due to args.save_total_limit\n",
|
828 |
+
"***** Running Evaluation *****\n",
|
829 |
+
" Num examples = 1024\n",
|
830 |
+
" Batch size = 24\n",
|
831 |
+
"Saving model checkpoint to checkpoints/checkpoint-900\n",
|
832 |
+
"Configuration saved in checkpoints/checkpoint-900/config.json\n",
|
833 |
+
"Model weights saved in checkpoints/checkpoint-900/pytorch_model.bin\n",
|
834 |
+
"tokenizer config file saved in checkpoints/checkpoint-900/tokenizer_config.json\n",
|
835 |
+
"Special tokens file saved in checkpoints/checkpoint-900/special_tokens_map.json\n",
|
836 |
+
"Deleting older checkpoint [checkpoints/checkpoint-700] due to args.save_total_limit\n",
|
837 |
+
"***** Running Evaluation *****\n",
|
838 |
+
" Num examples = 1024\n",
|
839 |
+
" Batch size = 24\n",
|
840 |
+
"Saving model checkpoint to checkpoints/checkpoint-1000\n",
|
841 |
+
"Configuration saved in checkpoints/checkpoint-1000/config.json\n",
|
842 |
+
"Model weights saved in checkpoints/checkpoint-1000/pytorch_model.bin\n",
|
843 |
+
"tokenizer config file saved in checkpoints/checkpoint-1000/tokenizer_config.json\n",
|
844 |
+
"Special tokens file saved in checkpoints/checkpoint-1000/special_tokens_map.json\n",
|
845 |
+
"Deleting older checkpoint [checkpoints/checkpoint-800] due to args.save_total_limit\n",
|
846 |
+
"***** Running Evaluation *****\n",
|
847 |
+
" Num examples = 1024\n",
|
848 |
+
" Batch size = 24\n",
|
849 |
+
"Saving model checkpoint to checkpoints/checkpoint-1100\n",
|
850 |
+
"Configuration saved in checkpoints/checkpoint-1100/config.json\n",
|
851 |
+
"Model weights saved in checkpoints/checkpoint-1100/pytorch_model.bin\n",
|
852 |
+
"tokenizer config file saved in checkpoints/checkpoint-1100/tokenizer_config.json\n",
|
853 |
+
"Special tokens file saved in checkpoints/checkpoint-1100/special_tokens_map.json\n",
|
854 |
+
"Deleting older checkpoint [checkpoints/checkpoint-900] due to args.save_total_limit\n",
|
855 |
+
"***** Running Evaluation *****\n",
|
856 |
+
" Num examples = 1024\n",
|
857 |
+
" Batch size = 24\n",
|
858 |
+
"Saving model checkpoint to checkpoints/checkpoint-1200\n",
|
859 |
+
"Configuration saved in checkpoints/checkpoint-1200/config.json\n",
|
860 |
+
"Model weights saved in checkpoints/checkpoint-1200/pytorch_model.bin\n",
|
861 |
+
"tokenizer config file saved in checkpoints/checkpoint-1200/tokenizer_config.json\n",
|
862 |
+
"Special tokens file saved in checkpoints/checkpoint-1200/special_tokens_map.json\n",
|
863 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1000] due to args.save_total_limit\n",
|
864 |
+
"***** Running Evaluation *****\n",
|
865 |
+
" Num examples = 1024\n",
|
866 |
+
" Batch size = 24\n",
|
867 |
+
"Saving model checkpoint to checkpoints/checkpoint-1300\n",
|
868 |
+
"Configuration saved in checkpoints/checkpoint-1300/config.json\n",
|
869 |
+
"Model weights saved in checkpoints/checkpoint-1300/pytorch_model.bin\n",
|
870 |
+
"tokenizer config file saved in checkpoints/checkpoint-1300/tokenizer_config.json\n",
|
871 |
+
"Special tokens file saved in checkpoints/checkpoint-1300/special_tokens_map.json\n",
|
872 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1100] due to args.save_total_limit\n",
|
873 |
+
"***** Running Evaluation *****\n",
|
874 |
+
" Num examples = 1024\n",
|
875 |
+
" Batch size = 24\n",
|
876 |
+
"Saving model checkpoint to checkpoints/checkpoint-1400\n",
|
877 |
+
"Configuration saved in checkpoints/checkpoint-1400/config.json\n",
|
878 |
+
"Model weights saved in checkpoints/checkpoint-1400/pytorch_model.bin\n",
|
879 |
+
"tokenizer config file saved in checkpoints/checkpoint-1400/tokenizer_config.json\n",
|
880 |
+
"Special tokens file saved in checkpoints/checkpoint-1400/special_tokens_map.json\n",
|
881 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1200] due to args.save_total_limit\n",
|
882 |
+
"***** Running Evaluation *****\n",
|
883 |
+
" Num examples = 1024\n",
|
884 |
+
" Batch size = 24\n",
|
885 |
+
"Saving model checkpoint to checkpoints/checkpoint-1500\n",
|
886 |
+
"Configuration saved in checkpoints/checkpoint-1500/config.json\n",
|
887 |
+
"Model weights saved in checkpoints/checkpoint-1500/pytorch_model.bin\n",
|
888 |
+
"tokenizer config file saved in checkpoints/checkpoint-1500/tokenizer_config.json\n",
|
889 |
+
"Special tokens file saved in checkpoints/checkpoint-1500/special_tokens_map.json\n",
|
890 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1300] due to args.save_total_limit\n",
|
891 |
+
"***** Running Evaluation *****\n",
|
892 |
+
" Num examples = 1024\n",
|
893 |
+
" Batch size = 24\n",
|
894 |
+
"Saving model checkpoint to checkpoints/checkpoint-1600\n",
|
895 |
+
"Configuration saved in checkpoints/checkpoint-1600/config.json\n",
|
896 |
+
"Model weights saved in checkpoints/checkpoint-1600/pytorch_model.bin\n",
|
897 |
+
"tokenizer config file saved in checkpoints/checkpoint-1600/tokenizer_config.json\n",
|
898 |
+
"Special tokens file saved in checkpoints/checkpoint-1600/special_tokens_map.json\n",
|
899 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1400] due to args.save_total_limit\n",
|
900 |
+
"***** Running Evaluation *****\n",
|
901 |
+
" Num examples = 1024\n",
|
902 |
+
" Batch size = 24\n",
|
903 |
+
"Saving model checkpoint to checkpoints/checkpoint-1700\n",
|
904 |
+
"Configuration saved in checkpoints/checkpoint-1700/config.json\n",
|
905 |
+
"Model weights saved in checkpoints/checkpoint-1700/pytorch_model.bin\n",
|
906 |
+
"tokenizer config file saved in checkpoints/checkpoint-1700/tokenizer_config.json\n",
|
907 |
+
"Special tokens file saved in checkpoints/checkpoint-1700/special_tokens_map.json\n",
|
908 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1600] due to args.save_total_limit\n",
|
909 |
+
"***** Running Evaluation *****\n",
|
910 |
+
" Num examples = 1024\n",
|
911 |
+
" Batch size = 24\n",
|
912 |
+
"Saving model checkpoint to checkpoints/checkpoint-1800\n",
|
913 |
+
"Configuration saved in checkpoints/checkpoint-1800/config.json\n",
|
914 |
+
"Model weights saved in checkpoints/checkpoint-1800/pytorch_model.bin\n",
|
915 |
+
"tokenizer config file saved in checkpoints/checkpoint-1800/tokenizer_config.json\n",
|
916 |
+
"Special tokens file saved in checkpoints/checkpoint-1800/special_tokens_map.json\n",
|
917 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1700] due to args.save_total_limit\n",
|
918 |
+
"***** Running Evaluation *****\n",
|
919 |
+
" Num examples = 1024\n",
|
920 |
+
" Batch size = 24\n",
|
921 |
+
"Saving model checkpoint to checkpoints/checkpoint-2200\n",
|
922 |
+
"Configuration saved in checkpoints/checkpoint-2200/config.json\n",
|
923 |
+
"Model weights saved in checkpoints/checkpoint-2200/pytorch_model.bin\n",
|
924 |
+
"tokenizer config file saved in checkpoints/checkpoint-2200/tokenizer_config.json\n",
|
925 |
+
"Special tokens file saved in checkpoints/checkpoint-2200/special_tokens_map.json\n",
|
926 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2100] due to args.save_total_limit\n",
|
927 |
+
"***** Running Evaluation *****\n",
|
928 |
+
" Num examples = 1024\n",
|
929 |
+
" Batch size = 24\n",
|
930 |
+
"Saving model checkpoint to checkpoints/checkpoint-2300\n",
|
931 |
+
"Configuration saved in checkpoints/checkpoint-2300/config.json\n",
|
932 |
+
"Model weights saved in checkpoints/checkpoint-2300/pytorch_model.bin\n",
|
933 |
+
"tokenizer config file saved in checkpoints/checkpoint-2300/tokenizer_config.json\n",
|
934 |
+
"Special tokens file saved in checkpoints/checkpoint-2300/special_tokens_map.json\n",
|
935 |
+
"Deleting older checkpoint [checkpoints/checkpoint-1900] due to args.save_total_limit\n",
|
936 |
+
"***** Running Evaluation *****\n",
|
937 |
+
" Num examples = 1024\n",
|
938 |
+
" Batch size = 24\n",
|
939 |
+
"Saving model checkpoint to checkpoints/checkpoint-2400\n",
|
940 |
+
"Configuration saved in checkpoints/checkpoint-2400/config.json\n",
|
941 |
+
"Model weights saved in checkpoints/checkpoint-2400/pytorch_model.bin\n",
|
942 |
+
"tokenizer config file saved in checkpoints/checkpoint-2400/tokenizer_config.json\n",
|
943 |
+
"Special tokens file saved in checkpoints/checkpoint-2400/special_tokens_map.json\n",
|
944 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2200] due to args.save_total_limit\n",
|
945 |
+
"***** Running Evaluation *****\n",
|
946 |
+
" Num examples = 1024\n",
|
947 |
+
" Batch size = 24\n",
|
948 |
+
"Saving model checkpoint to checkpoints/checkpoint-2500\n",
|
949 |
+
"Configuration saved in checkpoints/checkpoint-2500/config.json\n",
|
950 |
+
"Model weights saved in checkpoints/checkpoint-2500/pytorch_model.bin\n",
|
951 |
+
"tokenizer config file saved in checkpoints/checkpoint-2500/tokenizer_config.json\n",
|
952 |
+
"Special tokens file saved in checkpoints/checkpoint-2500/special_tokens_map.json\n",
|
953 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2300] due to args.save_total_limit\n",
|
954 |
+
"***** Running Evaluation *****\n",
|
955 |
+
" Num examples = 1024\n",
|
956 |
+
" Batch size = 24\n",
|
957 |
+
"Saving model checkpoint to checkpoints/checkpoint-2600\n",
|
958 |
+
"Configuration saved in checkpoints/checkpoint-2600/config.json\n",
|
959 |
+
"Model weights saved in checkpoints/checkpoint-2600/pytorch_model.bin\n",
|
960 |
+
"tokenizer config file saved in checkpoints/checkpoint-2600/tokenizer_config.json\n",
|
961 |
+
"Special tokens file saved in checkpoints/checkpoint-2600/special_tokens_map.json\n",
|
962 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2400] due to args.save_total_limit\n",
|
963 |
+
"***** Running Evaluation *****\n",
|
964 |
+
" Num examples = 1024\n",
|
965 |
+
" Batch size = 24\n",
|
966 |
+
"Saving model checkpoint to checkpoints/checkpoint-2700\n",
|
967 |
+
"Configuration saved in checkpoints/checkpoint-2700/config.json\n",
|
968 |
+
"Model weights saved in checkpoints/checkpoint-2700/pytorch_model.bin\n",
|
969 |
+
"tokenizer config file saved in checkpoints/checkpoint-2700/tokenizer_config.json\n",
|
970 |
+
"Special tokens file saved in checkpoints/checkpoint-2700/special_tokens_map.json\n",
|
971 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2600] due to args.save_total_limit\n",
|
972 |
+
"***** Running Evaluation *****\n",
|
973 |
+
" Num examples = 1024\n",
|
974 |
+
" Batch size = 24\n",
|
975 |
+
"Saving model checkpoint to checkpoints/checkpoint-2800\n",
|
976 |
+
"Configuration saved in checkpoints/checkpoint-2800/config.json\n",
|
977 |
+
"Model weights saved in checkpoints/checkpoint-2800/pytorch_model.bin\n",
|
978 |
+
"tokenizer config file saved in checkpoints/checkpoint-2800/tokenizer_config.json\n",
|
979 |
+
"Special tokens file saved in checkpoints/checkpoint-2800/special_tokens_map.json\n",
|
980 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2500] due to args.save_total_limit\n",
|
981 |
+
"***** Running Evaluation *****\n",
|
982 |
+
" Num examples = 1024\n",
|
983 |
+
" Batch size = 24\n",
|
984 |
+
"Saving model checkpoint to checkpoints/checkpoint-2900\n",
|
985 |
+
"Configuration saved in checkpoints/checkpoint-2900/config.json\n",
|
986 |
+
"Model weights saved in checkpoints/checkpoint-2900/pytorch_model.bin\n",
|
987 |
+
"tokenizer config file saved in checkpoints/checkpoint-2900/tokenizer_config.json\n",
|
988 |
+
"Special tokens file saved in checkpoints/checkpoint-2900/special_tokens_map.json\n",
|
989 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2700] due to args.save_total_limit\n",
|
990 |
+
"***** Running Evaluation *****\n",
|
991 |
+
" Num examples = 1024\n",
|
992 |
+
" Batch size = 24\n",
|
993 |
+
"Saving model checkpoint to checkpoints/checkpoint-3000\n",
|
994 |
+
"Configuration saved in checkpoints/checkpoint-3000/config.json\n",
|
995 |
+
"Model weights saved in checkpoints/checkpoint-3000/pytorch_model.bin\n",
|
996 |
+
"tokenizer config file saved in checkpoints/checkpoint-3000/tokenizer_config.json\n",
|
997 |
+
"Special tokens file saved in checkpoints/checkpoint-3000/special_tokens_map.json\n",
|
998 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2800] due to args.save_total_limit\n",
|
999 |
+
"***** Running Evaluation *****\n",
|
1000 |
+
" Num examples = 1024\n",
|
1001 |
+
" Batch size = 24\n",
|
1002 |
+
"Saving model checkpoint to checkpoints/checkpoint-3100\n",
|
1003 |
+
"Configuration saved in checkpoints/checkpoint-3100/config.json\n",
|
1004 |
+
"Model weights saved in checkpoints/checkpoint-3100/pytorch_model.bin\n",
|
1005 |
+
"tokenizer config file saved in checkpoints/checkpoint-3100/tokenizer_config.json\n",
|
1006 |
+
"Special tokens file saved in checkpoints/checkpoint-3100/special_tokens_map.json\n",
|
1007 |
+
"Deleting older checkpoint [checkpoints/checkpoint-2900] due to args.save_total_limit\n",
|
1008 |
+
"***** Running Evaluation *****\n",
|
1009 |
+
" Num examples = 1024\n",
|
1010 |
+
" Batch size = 24\n",
|
1011 |
+
"Saving model checkpoint to checkpoints/checkpoint-3200\n",
|
1012 |
+
"Configuration saved in checkpoints/checkpoint-3200/config.json\n",
|
1013 |
+
"Model weights saved in checkpoints/checkpoint-3200/pytorch_model.bin\n",
|
1014 |
+
"tokenizer config file saved in checkpoints/checkpoint-3200/tokenizer_config.json\n",
|
1015 |
+
"Special tokens file saved in checkpoints/checkpoint-3200/special_tokens_map.json\n",
|
1016 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3000] due to args.save_total_limit\n",
|
1017 |
+
"***** Running Evaluation *****\n",
|
1018 |
+
" Num examples = 1024\n",
|
1019 |
+
" Batch size = 24\n",
|
1020 |
+
"Saving model checkpoint to checkpoints/checkpoint-3300\n",
|
1021 |
+
"Configuration saved in checkpoints/checkpoint-3300/config.json\n",
|
1022 |
+
"Model weights saved in checkpoints/checkpoint-3300/pytorch_model.bin\n",
|
1023 |
+
"tokenizer config file saved in checkpoints/checkpoint-3300/tokenizer_config.json\n",
|
1024 |
+
"Special tokens file saved in checkpoints/checkpoint-3300/special_tokens_map.json\n",
|
1025 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3100] due to args.save_total_limit\n",
|
1026 |
+
"***** Running Evaluation *****\n",
|
1027 |
+
" Num examples = 1024\n",
|
1028 |
+
" Batch size = 24\n",
|
1029 |
+
"Saving model checkpoint to checkpoints/checkpoint-3400\n",
|
1030 |
+
"Configuration saved in checkpoints/checkpoint-3400/config.json\n",
|
1031 |
+
"Model weights saved in checkpoints/checkpoint-3400/pytorch_model.bin\n",
|
1032 |
+
"tokenizer config file saved in checkpoints/checkpoint-3400/tokenizer_config.json\n",
|
1033 |
+
"Special tokens file saved in checkpoints/checkpoint-3400/special_tokens_map.json\n",
|
1034 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3300] due to args.save_total_limit\n",
|
1035 |
+
"***** Running Evaluation *****\n",
|
1036 |
+
" Num examples = 1024\n",
|
1037 |
+
" Batch size = 24\n",
|
1038 |
+
"Saving model checkpoint to checkpoints/checkpoint-3500\n",
|
1039 |
+
"Configuration saved in checkpoints/checkpoint-3500/config.json\n",
|
1040 |
+
"Model weights saved in checkpoints/checkpoint-3500/pytorch_model.bin\n",
|
1041 |
+
"tokenizer config file saved in checkpoints/checkpoint-3500/tokenizer_config.json\n",
|
1042 |
+
"Special tokens file saved in checkpoints/checkpoint-3500/special_tokens_map.json\n",
|
1043 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3200] due to args.save_total_limit\n",
|
1044 |
+
"***** Running Evaluation *****\n",
|
1045 |
+
" Num examples = 1024\n",
|
1046 |
+
" Batch size = 24\n",
|
1047 |
+
"Saving model checkpoint to checkpoints/checkpoint-3600\n",
|
1048 |
+
"Configuration saved in checkpoints/checkpoint-3600/config.json\n",
|
1049 |
+
"Model weights saved in checkpoints/checkpoint-3600/pytorch_model.bin\n",
|
1050 |
+
"tokenizer config file saved in checkpoints/checkpoint-3600/tokenizer_config.json\n",
|
1051 |
+
"Special tokens file saved in checkpoints/checkpoint-3600/special_tokens_map.json\n",
|
1052 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3400] due to args.save_total_limit\n",
|
1053 |
+
"***** Running Evaluation *****\n",
|
1054 |
+
" Num examples = 1024\n",
|
1055 |
+
" Batch size = 24\n",
|
1056 |
+
"Saving model checkpoint to checkpoints/checkpoint-3700\n",
|
1057 |
+
"Configuration saved in checkpoints/checkpoint-3700/config.json\n",
|
1058 |
+
"Model weights saved in checkpoints/checkpoint-3700/pytorch_model.bin\n",
|
1059 |
+
"tokenizer config file saved in checkpoints/checkpoint-3700/tokenizer_config.json\n",
|
1060 |
+
"Special tokens file saved in checkpoints/checkpoint-3700/special_tokens_map.json\n",
|
1061 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3500] due to args.save_total_limit\n",
|
1062 |
+
"***** Running Evaluation *****\n",
|
1063 |
+
" Num examples = 1024\n",
|
1064 |
+
" Batch size = 24\n",
|
1065 |
+
"Saving model checkpoint to checkpoints/checkpoint-3800\n",
|
1066 |
+
"Configuration saved in checkpoints/checkpoint-3800/config.json\n",
|
1067 |
+
"Model weights saved in checkpoints/checkpoint-3800/pytorch_model.bin\n",
|
1068 |
+
"tokenizer config file saved in checkpoints/checkpoint-3800/tokenizer_config.json\n",
|
1069 |
+
"Special tokens file saved in checkpoints/checkpoint-3800/special_tokens_map.json\n",
|
1070 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3600] due to args.save_total_limit\n",
|
1071 |
+
"***** Running Evaluation *****\n",
|
1072 |
+
" Num examples = 1024\n",
|
1073 |
+
" Batch size = 24\n",
|
1074 |
+
"Saving model checkpoint to checkpoints/checkpoint-3900\n",
|
1075 |
+
"Configuration saved in checkpoints/checkpoint-3900/config.json\n",
|
1076 |
+
"Model weights saved in checkpoints/checkpoint-3900/pytorch_model.bin\n",
|
1077 |
+
"tokenizer config file saved in checkpoints/checkpoint-3900/tokenizer_config.json\n",
|
1078 |
+
"Special tokens file saved in checkpoints/checkpoint-3900/special_tokens_map.json\n",
|
1079 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3700] due to args.save_total_limit\n",
|
1080 |
+
"***** Running Evaluation *****\n",
|
1081 |
+
" Num examples = 1024\n",
|
1082 |
+
" Batch size = 24\n",
|
1083 |
+
"Saving model checkpoint to checkpoints/checkpoint-4000\n",
|
1084 |
+
"Configuration saved in checkpoints/checkpoint-4000/config.json\n",
|
1085 |
+
"Model weights saved in checkpoints/checkpoint-4000/pytorch_model.bin\n",
|
1086 |
+
"tokenizer config file saved in checkpoints/checkpoint-4000/tokenizer_config.json\n",
|
1087 |
+
"Special tokens file saved in checkpoints/checkpoint-4000/special_tokens_map.json\n",
|
1088 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3800] due to args.save_total_limit\n",
|
1089 |
+
"***** Running Evaluation *****\n",
|
1090 |
+
" Num examples = 1024\n",
|
1091 |
+
" Batch size = 24\n",
|
1092 |
+
"Saving model checkpoint to checkpoints/checkpoint-4100\n",
|
1093 |
+
"Configuration saved in checkpoints/checkpoint-4100/config.json\n",
|
1094 |
+
"Model weights saved in checkpoints/checkpoint-4100/pytorch_model.bin\n",
|
1095 |
+
"tokenizer config file saved in checkpoints/checkpoint-4100/tokenizer_config.json\n",
|
1096 |
+
"Special tokens file saved in checkpoints/checkpoint-4100/special_tokens_map.json\n",
|
1097 |
+
"Deleting older checkpoint [checkpoints/checkpoint-3900] due to args.save_total_limit\n",
|
1098 |
+
"***** Running Evaluation *****\n",
|
1099 |
+
" Num examples = 1024\n",
|
1100 |
+
" Batch size = 24\n",
|
1101 |
+
"Saving model checkpoint to checkpoints/checkpoint-4200\n",
|
1102 |
+
"Configuration saved in checkpoints/checkpoint-4200/config.json\n",
|
1103 |
+
"Model weights saved in checkpoints/checkpoint-4200/pytorch_model.bin\n",
|
1104 |
+
"tokenizer config file saved in checkpoints/checkpoint-4200/tokenizer_config.json\n",
|
1105 |
+
"Special tokens file saved in checkpoints/checkpoint-4200/special_tokens_map.json\n",
|
1106 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4100] due to args.save_total_limit\n",
|
1107 |
+
"***** Running Evaluation *****\n",
|
1108 |
+
" Num examples = 1024\n",
|
1109 |
+
" Batch size = 24\n",
|
1110 |
+
"Saving model checkpoint to checkpoints/checkpoint-4300\n",
|
1111 |
+
"Configuration saved in checkpoints/checkpoint-4300/config.json\n",
|
1112 |
+
"Model weights saved in checkpoints/checkpoint-4300/pytorch_model.bin\n",
|
1113 |
+
"tokenizer config file saved in checkpoints/checkpoint-4300/tokenizer_config.json\n",
|
1114 |
+
"Special tokens file saved in checkpoints/checkpoint-4300/special_tokens_map.json\n",
|
1115 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4200] due to args.save_total_limit\n",
|
1116 |
+
"***** Running Evaluation *****\n",
|
1117 |
+
" Num examples = 1024\n",
|
1118 |
+
" Batch size = 24\n",
|
1119 |
+
"Saving model checkpoint to checkpoints/checkpoint-4400\n",
|
1120 |
+
"Configuration saved in checkpoints/checkpoint-4400/config.json\n",
|
1121 |
+
"Model weights saved in checkpoints/checkpoint-4400/pytorch_model.bin\n",
|
1122 |
+
"tokenizer config file saved in checkpoints/checkpoint-4400/tokenizer_config.json\n",
|
1123 |
+
"Special tokens file saved in checkpoints/checkpoint-4400/special_tokens_map.json\n",
|
1124 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4300] due to args.save_total_limit\n",
|
1125 |
+
"***** Running Evaluation *****\n",
|
1126 |
+
" Num examples = 1024\n",
|
1127 |
+
" Batch size = 24\n",
|
1128 |
+
"Saving model checkpoint to checkpoints/checkpoint-4500\n",
|
1129 |
+
"Configuration saved in checkpoints/checkpoint-4500/config.json\n",
|
1130 |
+
"Model weights saved in checkpoints/checkpoint-4500/pytorch_model.bin\n",
|
1131 |
+
"tokenizer config file saved in checkpoints/checkpoint-4500/tokenizer_config.json\n",
|
1132 |
+
"Special tokens file saved in checkpoints/checkpoint-4500/special_tokens_map.json\n",
|
1133 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4400] due to args.save_total_limit\n",
|
1134 |
+
"***** Running Evaluation *****\n",
|
1135 |
+
" Num examples = 1024\n",
|
1136 |
+
" Batch size = 24\n",
|
1137 |
+
"Saving model checkpoint to checkpoints/checkpoint-4600\n",
|
1138 |
+
"Configuration saved in checkpoints/checkpoint-4600/config.json\n",
|
1139 |
+
"Model weights saved in checkpoints/checkpoint-4600/pytorch_model.bin\n",
|
1140 |
+
"tokenizer config file saved in checkpoints/checkpoint-4600/tokenizer_config.json\n",
|
1141 |
+
"Special tokens file saved in checkpoints/checkpoint-4600/special_tokens_map.json\n",
|
1142 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4500] due to args.save_total_limit\n",
|
1143 |
+
"***** Running Evaluation *****\n",
|
1144 |
+
" Num examples = 1024\n",
|
1145 |
+
" Batch size = 24\n",
|
1146 |
+
"Saving model checkpoint to checkpoints/checkpoint-4700\n",
|
1147 |
+
"Configuration saved in checkpoints/checkpoint-4700/config.json\n",
|
1148 |
+
"Model weights saved in checkpoints/checkpoint-4700/pytorch_model.bin\n",
|
1149 |
+
"tokenizer config file saved in checkpoints/checkpoint-4700/tokenizer_config.json\n",
|
1150 |
+
"Special tokens file saved in checkpoints/checkpoint-4700/special_tokens_map.json\n",
|
1151 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4600] due to args.save_total_limit\n",
|
1152 |
+
"***** Running Evaluation *****\n",
|
1153 |
+
" Num examples = 1024\n",
|
1154 |
+
" Batch size = 24\n",
|
1155 |
+
"Saving model checkpoint to checkpoints/checkpoint-4800\n",
|
1156 |
+
"Configuration saved in checkpoints/checkpoint-4800/config.json\n",
|
1157 |
+
"Model weights saved in checkpoints/checkpoint-4800/pytorch_model.bin\n",
|
1158 |
+
"tokenizer config file saved in checkpoints/checkpoint-4800/tokenizer_config.json\n",
|
1159 |
+
"Special tokens file saved in checkpoints/checkpoint-4800/special_tokens_map.json\n",
|
1160 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4700] due to args.save_total_limit\n",
|
1161 |
+
"***** Running Evaluation *****\n",
|
1162 |
+
" Num examples = 1024\n",
|
1163 |
+
" Batch size = 24\n",
|
1164 |
+
"Saving model checkpoint to checkpoints/checkpoint-4900\n",
|
1165 |
+
"Configuration saved in checkpoints/checkpoint-4900/config.json\n",
|
1166 |
+
"Model weights saved in checkpoints/checkpoint-4900/pytorch_model.bin\n",
|
1167 |
+
"tokenizer config file saved in checkpoints/checkpoint-4900/tokenizer_config.json\n",
|
1168 |
+
"Special tokens file saved in checkpoints/checkpoint-4900/special_tokens_map.json\n",
|
1169 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4800] due to args.save_total_limit\n",
|
1170 |
+
"***** Running Evaluation *****\n",
|
1171 |
+
" Num examples = 1024\n",
|
1172 |
+
" Batch size = 24\n",
|
1173 |
+
"Saving model checkpoint to checkpoints/checkpoint-5000\n",
|
1174 |
+
"Configuration saved in checkpoints/checkpoint-5000/config.json\n",
|
1175 |
+
"Model weights saved in checkpoints/checkpoint-5000/pytorch_model.bin\n",
|
1176 |
+
"tokenizer config file saved in checkpoints/checkpoint-5000/tokenizer_config.json\n",
|
1177 |
+
"Special tokens file saved in checkpoints/checkpoint-5000/special_tokens_map.json\n",
|
1178 |
+
"Deleting older checkpoint [checkpoints/checkpoint-4900] due to args.save_total_limit\n",
|
1179 |
+
"\n",
|
1180 |
+
"\n",
|
1181 |
+
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
|
1182 |
+
"\n",
|
1183 |
+
"\n",
|
1184 |
+
"Loading best model from checkpoints/checkpoint-4000 (score: 1.5726864337921143).\n"
|
1185 |
+
]
|
1186 |
+
},
|
1187 |
+
{
|
1188 |
+
"data": {
|
1189 |
+
"text/plain": [
|
1190 |
+
"TrainOutput(global_step=5000, training_loss=1.7068539916992187, metrics={'train_runtime': 1373.8884, 'train_samples_per_second': 87.343, 'train_steps_per_second': 3.639, 'total_flos': 1.9803672136145664e+16, 'train_loss': 1.7068539916992187, 'epoch': 3.08})"
|
1191 |
+
]
|
1192 |
+
},
|
1193 |
+
"execution_count": 11,
|
1194 |
+
"metadata": {},
|
1195 |
+
"output_type": "execute_result"
|
1196 |
+
}
|
1197 |
+
],
|
1198 |
+
"source": [
|
1199 |
+
"trainer.train()"
|
1200 |
+
]
|
1201 |
+
},
|
1202 |
+
{
|
1203 |
+
"cell_type": "code",
|
1204 |
+
"execution_count": 12,
|
1205 |
+
"id": "86cbf6fb-8e38-4a54-bf1a-987e308ef97d",
|
1206 |
+
"metadata": {},
|
1207 |
+
"outputs": [
|
1208 |
+
{
|
1209 |
+
"name": "stdout",
|
1210 |
+
"output_type": "stream",
|
1211 |
+
"text": [
|
1212 |
+
"checkpoint-4000 checkpoint-5000 runs\n"
|
1213 |
+
]
|
1214 |
+
}
|
1215 |
+
],
|
1216 |
+
"source": [
|
1217 |
+
"!ls checkpoints/"
|
1218 |
+
]
|
1219 |
+
},
|
1220 |
+
{
|
1221 |
+
"cell_type": "code",
|
1222 |
+
"execution_count": 94,
|
1223 |
+
"id": "4a0a8e00-6b91-4c8e-af5c-8ab66c3e5648",
|
1224 |
+
"metadata": {},
|
1225 |
+
"outputs": [],
|
1226 |
+
"source": [
|
1227 |
+
"from torch.utils.data import DataLoader\n",
|
1228 |
+
"\n",
|
1229 |
+
"def calc_metrics(model, dataset, abstract_proba):\n",
|
1230 |
+
" dataloader = DataLoader(\n",
|
1231 |
+
" dataset, batch_size=16, shuffle=False,\n",
|
1232 |
+
" collate_fn=get_collator(tokenizer, abstract_proba=abstract_proba)\n",
|
1233 |
+
" )\n",
|
1234 |
+
" precisions, recalls, top1_accs = [], [], []\n",
|
1235 |
+
" with torch.no_grad():\n",
|
1236 |
+
" for batch in tqdm(dataloader):\n",
|
1237 |
+
" outputs = model(**batch.to('cuda'))\n",
|
1238 |
+
" for labels, preds in zip(batch['labels'], outputs.logits.softmax(-1)):\n",
|
1239 |
+
" top_probs, top_inds = preds.sort(descending=True)\n",
|
1240 |
+
" mask = top_probs.cumsum(0) <= 0.95\n",
|
1241 |
+
" mask[0] = True\n",
|
1242 |
+
" a = set(top_inds[mask].tolist())\n",
|
1243 |
+
" y = set(labels.nonzero().flatten().tolist())\n",
|
1244 |
+
" top1_accs.append(int(top_inds[0]) in y)\n",
|
1245 |
+
" recalls.append(len(y & a) / len(y))\n",
|
1246 |
+
" precisions.append(len(y & a) / len(a))\n",
|
1247 |
+
" return {'Recall@0.95': np.mean(recalls),\n",
|
1248 |
+
" 'Precision@0.95': np.mean(precisions),\n",
|
1249 |
+
" 'Top-1 Accuracy': np.mean(top1_accs)}"
|
1250 |
+
]
|
1251 |
+
},
|
1252 |
+
{
|
1253 |
+
"cell_type": "code",
|
1254 |
+
"execution_count": 97,
|
1255 |
+
"id": "a374a862-3d6c-4d58-8871-a1195dd75e1c",
|
1256 |
+
"metadata": {},
|
1257 |
+
"outputs": [],
|
1258 |
+
"source": [
|
1259 |
+
"from transformers import AutoTokenizer\n",
|
1260 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
1261 |
+
"\n",
|
1262 |
+
"tokenizer = AutoTokenizer.from_pretrained('checkpoints/checkpoint-4000/')\n",
|
1263 |
+
"model = AutoModelForSequenceClassification.from_pretrained('checkpoints/checkpoint-4000/')\n",
|
1264 |
+
"model.to('cuda')\n",
|
1265 |
+
"model.eval();"
|
1266 |
+
]
|
1267 |
+
},
|
1268 |
+
{
|
1269 |
+
"cell_type": "code",
|
1270 |
+
"execution_count": 98,
|
1271 |
+
"id": "1320d0fd-5302-4677-b9ae-4c9d5652db73",
|
1272 |
+
"metadata": {},
|
1273 |
+
"outputs": [
|
1274 |
+
{
|
1275 |
+
"data": {
|
1276 |
+
"application/vnd.jupyter.widget-view+json": {
|
1277 |
+
"model_id": "7da9cb0fa4154813a7db717c9556e4f4",
|
1278 |
+
"version_major": 2,
|
1279 |
+
"version_minor": 0
|
1280 |
+
},
|
1281 |
+
"text/plain": [
|
1282 |
+
" 0%| | 0/64 [00:00<?, ?it/s]"
|
1283 |
+
]
|
1284 |
+
},
|
1285 |
+
"metadata": {},
|
1286 |
+
"output_type": "display_data"
|
1287 |
+
},
|
1288 |
+
{
|
1289 |
+
"data": {
|
1290 |
+
"text/plain": [
|
1291 |
+
"{'Recall@0.95': 0.9289341517857141,\n",
|
1292 |
+
" 'Precision@0.95': 0.28929539856301567,\n",
|
1293 |
+
" 'Top-1 Accuracy': 0.791015625}"
|
1294 |
+
]
|
1295 |
+
},
|
1296 |
+
"execution_count": 98,
|
1297 |
+
"metadata": {},
|
1298 |
+
"output_type": "execute_result"
|
1299 |
+
}
|
1300 |
+
],
|
1301 |
+
"source": [
|
1302 |
+
"calc_metrics(model, dataset['test'], abstract_proba=0.0)"
|
1303 |
+
]
|
1304 |
+
},
|
1305 |
+
{
|
1306 |
+
"cell_type": "code",
|
1307 |
+
"execution_count": 99,
|
1308 |
+
"id": "4c05da39-e50c-4ed8-8c4d-7c2668378d75",
|
1309 |
+
"metadata": {},
|
1310 |
+
"outputs": [
|
1311 |
+
{
|
1312 |
+
"data": {
|
1313 |
+
"application/vnd.jupyter.widget-view+json": {
|
1314 |
+
"model_id": "246a76ecc5de451b858addc37b5c2bd8",
|
1315 |
+
"version_major": 2,
|
1316 |
+
"version_minor": 0
|
1317 |
+
},
|
1318 |
+
"text/plain": [
|
1319 |
+
" 0%| | 0/64 [00:00<?, ?it/s]"
|
1320 |
+
]
|
1321 |
+
},
|
1322 |
+
"metadata": {},
|
1323 |
+
"output_type": "display_data"
|
1324 |
+
},
|
1325 |
+
{
|
1326 |
+
"data": {
|
1327 |
+
"text/plain": [
|
1328 |
+
"{'Recall@0.95': 0.9357700892857144,\n",
|
1329 |
+
" 'Precision@0.95': 0.3308562290529852,\n",
|
1330 |
+
" 'Top-1 Accuracy': 0.87109375}"
|
1331 |
+
]
|
1332 |
+
},
|
1333 |
+
"execution_count": 99,
|
1334 |
+
"metadata": {},
|
1335 |
+
"output_type": "execute_result"
|
1336 |
+
}
|
1337 |
+
],
|
1338 |
+
"source": [
|
1339 |
+
"calc_metrics(model, dataset['test'], abstract_proba=1.0)"
|
1340 |
+
]
|
1341 |
+
},
|
1342 |
+
{
|
1343 |
+
"cell_type": "code",
|
1344 |
+
"execution_count": 100,
|
1345 |
+
"id": "729a58f4-826c-4bc5-9f9a-78c1c30d39d7",
|
1346 |
+
"metadata": {},
|
1347 |
+
"outputs": [],
|
1348 |
+
"source": [
|
1349 |
+
"from transformers import AutoTokenizer\n",
|
1350 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
1351 |
+
"\n",
|
1352 |
+
"tokenizer = AutoTokenizer.from_pretrained('checkpoints/checkpoint-5000/')\n",
|
1353 |
+
"model = AutoModelForSequenceClassification.from_pretrained('checkpoints/checkpoint-5000/')\n",
|
1354 |
+
"model.to('cuda')\n",
|
1355 |
+
"model.eval();"
|
1356 |
+
]
|
1357 |
+
},
|
1358 |
+
{
|
1359 |
+
"cell_type": "code",
|
1360 |
+
"execution_count": 101,
|
1361 |
+
"id": "c03bb517-fd9b-4c6f-9f2b-362f25985f21",
|
1362 |
+
"metadata": {},
|
1363 |
+
"outputs": [
|
1364 |
+
{
|
1365 |
+
"data": {
|
1366 |
+
"application/vnd.jupyter.widget-view+json": {
|
1367 |
+
"model_id": "c4ab0c6ffacd4ffc8be69580a38cca7a",
|
1368 |
+
"version_major": 2,
|
1369 |
+
"version_minor": 0
|
1370 |
+
},
|
1371 |
+
"text/plain": [
|
1372 |
+
" 0%| | 0/64 [00:00<?, ?it/s]"
|
1373 |
+
]
|
1374 |
+
},
|
1375 |
+
"metadata": {},
|
1376 |
+
"output_type": "display_data"
|
1377 |
+
},
|
1378 |
+
{
|
1379 |
+
"data": {
|
1380 |
+
"text/plain": [
|
1381 |
+
"{'Recall@0.95': 0.9222516741071428,\n",
|
1382 |
+
" 'Precision@0.95': 0.32720513773363014,\n",
|
1383 |
+
" 'Top-1 Accuracy': 0.796875}"
|
1384 |
+
]
|
1385 |
+
},
|
1386 |
+
"execution_count": 101,
|
1387 |
+
"metadata": {},
|
1388 |
+
"output_type": "execute_result"
|
1389 |
+
}
|
1390 |
+
],
|
1391 |
+
"source": [
|
1392 |
+
"calc_metrics(model, dataset['test'], abstract_proba=0.0)"
|
1393 |
+
]
|
1394 |
+
},
|
1395 |
+
{
|
1396 |
+
"cell_type": "code",
|
1397 |
+
"execution_count": 102,
|
1398 |
+
"id": "51a9ffa7-e303-4d49-91a9-53f193b3f9fb",
|
1399 |
+
"metadata": {},
|
1400 |
+
"outputs": [
|
1401 |
+
{
|
1402 |
+
"data": {
|
1403 |
+
"application/vnd.jupyter.widget-view+json": {
|
1404 |
+
"model_id": "87523b78c8264e60ac0558872c2c5030",
|
1405 |
+
"version_major": 2,
|
1406 |
+
"version_minor": 0
|
1407 |
+
},
|
1408 |
+
"text/plain": [
|
1409 |
+
" 0%| | 0/64 [00:00<?, ?it/s]"
|
1410 |
+
]
|
1411 |
+
},
|
1412 |
+
"metadata": {},
|
1413 |
+
"output_type": "display_data"
|
1414 |
+
},
|
1415 |
+
{
|
1416 |
+
"data": {
|
1417 |
+
"text/plain": [
|
1418 |
+
"{'Recall@0.95': 0.932661365327381,\n",
|
1419 |
+
" 'Precision@0.95': 0.3758523827747189,\n",
|
1420 |
+
" 'Top-1 Accuracy': 0.87109375}"
|
1421 |
+
]
|
1422 |
+
},
|
1423 |
+
"execution_count": 102,
|
1424 |
+
"metadata": {},
|
1425 |
+
"output_type": "execute_result"
|
1426 |
+
}
|
1427 |
+
],
|
1428 |
+
"source": [
|
1429 |
+
"calc_metrics(model, dataset['test'], abstract_proba=1.0)"
|
1430 |
+
]
|
1431 |
+
},
|
1432 |
+
{
|
1433 |
+
"cell_type": "markdown",
|
1434 |
+
"id": "7d13b1f7-e2a9-418e-8d0c-d5e1f1cf6ce4",
|
1435 |
+
"metadata": {},
|
1436 |
+
"source": [
|
1437 |
+
"* С наличием abstract качество ожидаемо выше;\n",
|
1438 |
+
"* Модели с 4k и 5k итераций очень близки по метрикам, но с 5k точность чуть лучше — на инференс возьмем ее."
|
1439 |
+
]
|
1440 |
+
}
|
1441 |
+
],
|
1442 |
+
"metadata": {
|
1443 |
+
"kernelspec": {
|
1444 |
+
"display_name": "Python 3 (ipykernel)",
|
1445 |
+
"language": "python",
|
1446 |
+
"name": "python3"
|
1447 |
+
},
|
1448 |
+
"language_info": {
|
1449 |
+
"codemirror_mode": {
|
1450 |
+
"name": "ipython",
|
1451 |
+
"version": 3
|
1452 |
+
},
|
1453 |
+
"file_extension": ".py",
|
1454 |
+
"mimetype": "text/x-python",
|
1455 |
+
"name": "python",
|
1456 |
+
"nbconvert_exporter": "python",
|
1457 |
+
"pygments_lexer": "ipython3",
|
1458 |
+
"version": "3.8.10"
|
1459 |
+
}
|
1460 |
+
},
|
1461 |
+
"nbformat": 4,
|
1462 |
+
"nbformat_minor": 5
|
1463 |
+
}
|