gefedya committed on
Commit
01e0066
1 Parent(s): 2b96aaf

Upload model_training_and_testing.ipynb

Browse files
Files changed (1) hide show
  1. model_training_and_testing.ipynb +1241 -0
model_training_and_testing.ipynb ADDED
@@ -0,0 +1,1241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "yViS84WpSENo"
7
+ },
8
+ "source": [
9
+ "### Sentiment analysis on a Twitter dataset\n",
10
+ "\n",
11
+ "Возьмем модель `siebert/sentiment-roberta-large-english` и дообучим ее на датасете твитов `carblacac/twitter-sentiment-analysis`."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "metadata": {
18
+ "id": "X-X7InJOGwT_"
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "#!streamlit run /home/theodore/anaconda3/envs/ds/lib/python3.10/site-packages/ipykernel_launcher.py"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 3,
28
+ "metadata": {
29
+ "colab": {
30
+ "base_uri": "https://localhost:8080/"
31
+ },
32
+ "id": "RKHGyhG9G6WF",
33
+ "outputId": "bbaeac00-caa0-4eec-b77e-e44dbdf78f94"
34
+ },
35
+ "outputs": [
36
+ {
37
+ "name": "stdout",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
41
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.9/dist-packages (4.28.1)\n",
42
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.9/dist-packages (2.11.0)\n",
43
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n",
44
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n",
45
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.11.0)\n",
46
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n",
47
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.0)\n",
48
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n",
49
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.4)\n",
50
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.3)\n",
51
+ "Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n",
52
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets) (3.8.4)\n",
53
+ "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (2023.4.0)\n",
54
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets) (3.2.0)\n",
55
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets) (1.5.3)\n",
56
+ "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.18.0)\n",
57
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (9.0.0)\n",
58
+ "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.3.6)\n",
59
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets) (0.70.14)\n",
60
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (2.0.12)\n",
61
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.8.2)\n",
62
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (4.0.2)\n",
63
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (22.2.0)\n",
64
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.1)\n",
65
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.3)\n",
66
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (6.0.4)\n",
67
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n",
68
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n",
69
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n",
70
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n",
71
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2.8.2)\n",
72
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2022.7.1)\n",
73
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n"
74
+ ]
75
+ }
76
+ ],
77
+ "source": [
78
+ "!pip install transformers datasets"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 1,
84
+ "metadata": {
85
+ "id": "ohaBadE4GwUA"
86
+ },
87
+ "outputs": [],
88
+ "source": [
89
+ "# import json\n",
90
+ "\n",
91
+ "# with open(\"arxivData.json\", 'r') as f:\n",
92
+ "# arxiv_data = json.load(f)\n",
93
+ "\n",
94
+ "# arxiv_data[0]['tag']"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 1,
100
+ "metadata": {
101
+ "id": "HAP_PN0iGwUB"
102
+ },
103
+ "outputs": [],
104
+ "source": [
105
+ "from datasets import load_dataset\n",
106
+ "import transformers\n",
107
+ "import pandas as pd\n",
108
+ "import numpy as np"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 2,
114
+ "metadata": {
115
+ "colab": {
116
+ "base_uri": "https://localhost:8080/",
117
+ "height": 86,
118
+ "referenced_widgets": [
119
+ "06de297475874730a793d3255cfe9a0d",
120
+ "1a192819128a429b96c571f48c7de552",
121
+ "4e50c449ee2047148e35c76fddf38e08",
122
+ "12483d6de04546feb5ee3b99ede894dc",
123
+ "6bc412e1b7c04160b58f83db5841d255",
124
+ "8e44d4f0218f4c43bb28a22a1b0472f7",
125
+ "9b9e621263e24472be473d87b2d54db5",
126
+ "f4947eda1c18457a80aa7c68df79d32a",
127
+ "11585997b0494e3db0f0121538e4f9f4",
128
+ "dc511eb900064df8a10177ba5ea1e7ff",
129
+ "2973a9eaa7f34f70bb02d4ffaea67462"
130
+ ]
131
+ },
132
+ "id": "ctTETRGcGwUC",
133
+ "outputId": "49c41fe7-acc0-4678-d257-aac1502edc72"
134
+ },
135
+ "outputs": [
136
+ {
137
+ "name": "stderr",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "Found cached dataset twitter-sentiment-analysis (/home/theodore/.cache/huggingface/datasets/carblacac___twitter-sentiment-analysis/default/1.0.0/cd65e23e456de6a4f7264e305380b0ffe804d6f5bfd361c0ec0f68d8d1fab95b)\n"
141
+ ]
142
+ },
143
+ {
144
+ "data": {
145
+ "application/vnd.jupyter.widget-view+json": {
146
+ "model_id": "ebd17ca95bce4cc3b9d0f9597a5d9a91",
147
+ "version_major": 2,
148
+ "version_minor": 0
149
+ },
150
+ "text/plain": [
151
+ " 0%| | 0/3 [00:00<?, ?it/s]"
152
+ ]
153
+ },
154
+ "metadata": {},
155
+ "output_type": "display_data"
156
+ }
157
+ ],
158
+ "source": [
159
+ "data = load_dataset(\"carblacac/twitter-sentiment-analysis\")"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 3,
165
+ "metadata": {
166
+ "colab": {
167
+ "base_uri": "https://localhost:8080/"
168
+ },
169
+ "id": "HRevJqXIGwUC",
170
+ "outputId": "87eb5fc6-817f-4abe-c34d-57320c91229f"
171
+ },
172
+ "outputs": [
173
+ {
174
+ "data": {
175
+ "text/plain": [
176
+ "DatasetDict({\n",
177
+ " train: Dataset({\n",
178
+ " features: ['text', 'feeling'],\n",
179
+ " num_rows: 119988\n",
180
+ " })\n",
181
+ " validation: Dataset({\n",
182
+ " features: ['text', 'feeling'],\n",
183
+ " num_rows: 29997\n",
184
+ " })\n",
185
+ " test: Dataset({\n",
186
+ " features: ['text', 'feeling'],\n",
187
+ " num_rows: 61998\n",
188
+ " })\n",
189
+ "})"
190
+ ]
191
+ },
192
+ "execution_count": 3,
193
+ "metadata": {},
194
+ "output_type": "execute_result"
195
+ }
196
+ ],
197
+ "source": [
198
+ "data"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 4,
204
+ "metadata": {
205
+ "id": "CAUZ4zWgGwUD"
206
+ },
207
+ "outputs": [],
208
+ "source": [
209
+ "# tweets_df = pd.read_csv(\"twitter_sentiment.csv\", names=[\"target\", \"ids\", \"date\", \"flag\", \"user\", \"text\"],\n",
210
+ "# encoding='utf-8', encoding_errors='ignore')\n",
211
+ "# tweets_df.head()\n",
212
+ "# #tweets_df.shape"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 5,
218
+ "metadata": {
219
+ "colab": {
220
+ "base_uri": "https://localhost:8080/",
221
+ "height": 72,
222
+ "referenced_widgets": [
223
+ "5e4bdf9e5c8347159c83772d892f2e4f",
224
+ "dfc20d27805d4ba9a388cfa24384285e",
225
+ "de9e434a5192455ab2882dff5f0ed3d9",
226
+ "c41e383cd66949e7979a9f83c932c3bf",
227
+ "af522a70d4e04cfa8bdf5807cf30153b",
228
+ "330be1100e454f9da15ad9553bfaf4a2",
229
+ "f815ae8495f44970b648c942d2ac11c0",
230
+ "6018eca5deaf41858ce694efc6b3624d",
231
+ "51291a2ef41f4e9fbe68d1e4151fb4e2",
232
+ "ec6d9ab9eda24231b577c75136f902f8",
233
+ "7fa05036dfb345898c06ecb09dca0363"
234
+ ]
235
+ },
236
+ "id": "utD3wZZLGwUE",
237
+ "outputId": "8c7db1ba-21f4-48cb-d89f-95b286bc8ee7"
238
+ },
239
+ "outputs": [
240
+ {
241
+ "name": "stderr",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "Loading cached processed dataset at /home/theodore/.cache/huggingface/datasets/carblacac___twitter-sentiment-analysis/default/1.0.0/cd65e23e456de6a4f7264e305380b0ffe804d6f5bfd361c0ec0f68d8d1fab95b/cache-798641a24e025595.arrow\n",
245
+ "Loading cached processed dataset at /home/theodore/.cache/huggingface/datasets/carblacac___twitter-sentiment-analysis/default/1.0.0/cd65e23e456de6a4f7264e305380b0ffe804d6f5bfd361c0ec0f68d8d1fab95b/cache-5ac250acc1a0a444.arrow\n",
246
+ "Loading cached processed dataset at /home/theodore/.cache/huggingface/datasets/carblacac___twitter-sentiment-analysis/default/1.0.0/cd65e23e456de6a4f7264e305380b0ffe804d6f5bfd361c0ec0f68d8d1fab95b/cache-6ee382dff31af9fb.arrow\n"
247
+ ]
248
+ }
249
+ ],
250
+ "source": [
251
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, Trainer, TrainingArguments, LineByLineTextDataset\n",
252
+ "from transformers import pipeline\n",
253
+ "\n",
254
+ "# classifier = pipeline('sentiment-analysis', model=\"siebert/sentiment-roberta-large-english\")\n",
255
+ "\n",
256
+ "\n",
257
+ "# print(\"Preparing the training data...\")\n",
258
+ "# # dataset = LineByLineTextDataset(\n",
259
+ "# # file_path=<MAKE_YOUR_DATA_HERE>, tokenizer=tokenizer, block_size=128)\n",
260
+ "\n",
261
+ "tokenizer = AutoTokenizer.from_pretrained(\"siebert/sentiment-roberta-large-english\")\n",
262
+ "model = AutoModelForSequenceClassification.from_pretrained(\"siebert/sentiment-roberta-large-english\", num_labels=2)\n",
263
+ "\n",
264
+ "# print(\"Dataset ready!\")\n",
265
+ "dataset = data.map(lambda xs: tokenizer(xs[\"text\"], truncation=True, padding='max_length'))"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": 10,
271
+ "metadata": {
272
+ "id": "J2Q492BsOedY"
273
+ },
274
+ "outputs": [],
275
+ "source": [
276
+ "# model.to('cpu');"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 7,
282
+ "metadata": {
283
+ "id": "xRYsJ1KTGwUF"
284
+ },
285
+ "outputs": [],
286
+ "source": [
287
+ "from datasets import ClassLabel\n",
288
+ "\n",
289
+ "dataset = dataset.rename_column(\"feeling\", \"labels\")\n",
290
+ "dataset = dataset.cast_column(\"labels\", ClassLabel(names=['0', '1']))\n",
291
+ "dataset = dataset.align_labels_with_mapping({'0': 1, '1': 0}, \"labels\")"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 13,
297
+ "metadata": {
298
+ "colab": {
299
+ "base_uri": "https://localhost:8080/",
300
+ "height": 339
301
+ },
302
+ "id": "BsGX2H5tGwUG",
303
+ "outputId": "dbcb10ea-a7e8-431c-b96e-9fad8b448eb0"
304
+ },
305
+ "outputs": [
306
+ {
307
+ "name": "stderr",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "/usr/local/lib/python3.9/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
311
+ " warnings.warn(\n"
312
+ ]
313
+ },
314
+ {
315
+ "data": {
316
+ "text/html": [
317
+ "\n",
318
+ " <div>\n",
319
+ " \n",
320
+ " <progress value='2500' max='2500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
321
+ " [2500/2500 54:00, Epoch 1/1]\n",
322
+ " </div>\n",
323
+ " <table border=\"1\" class=\"dataframe\">\n",
324
+ " <thead>\n",
325
+ " <tr style=\"text-align: left;\">\n",
326
+ " <th>Step</th>\n",
327
+ " <th>Training Loss</th>\n",
328
+ " </tr>\n",
329
+ " </thead>\n",
330
+ " <tbody>\n",
331
+ " <tr>\n",
332
+ " <td>500</td>\n",
333
+ " <td>0.721100</td>\n",
334
+ " </tr>\n",
335
+ " <tr>\n",
336
+ " <td>1000</td>\n",
337
+ " <td>0.709100</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <td>1500</td>\n",
341
+ " <td>0.708400</td>\n",
342
+ " </tr>\n",
343
+ " <tr>\n",
344
+ " <td>2000</td>\n",
345
+ " <td>0.703900</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <td>2500</td>\n",
349
+ " <td>0.701600</td>\n",
350
+ " </tr>\n",
351
+ " </tbody>\n",
352
+ "</table><p>"
353
+ ],
354
+ "text/plain": [
355
+ "<IPython.core.display.HTML object>"
356
+ ]
357
+ },
358
+ "metadata": {},
359
+ "output_type": "display_data"
360
+ },
361
+ {
362
+ "data": {
363
+ "text/plain": [
364
+ "TrainOutput(global_step=2500, training_loss=0.7088180541992187, metrics={'train_runtime': 3241.3296, 'train_samples_per_second': 3.085, 'train_steps_per_second': 0.771, 'total_flos': 9319313633280000.0, 'train_loss': 0.7088180541992187, 'epoch': 1.0})"
365
+ ]
366
+ },
367
+ "execution_count": 13,
368
+ "metadata": {},
369
+ "output_type": "execute_result"
370
+ }
371
+ ],
372
+ "source": [
373
+ "trainer = Trainer(\n",
374
+ " model=model, train_dataset=dataset[\"train\"].shuffle().select(range(10000)),\n",
375
+ " eval_dataset = dataset['test'].select(range(5000)),\n",
376
+ " args=TrainingArguments(\n",
377
+ " output_dir=\"./my_saved_model\", overwrite_output_dir=True,\n",
378
+ " num_train_epochs=1, per_device_train_batch_size=4,\n",
379
+ " save_steps=10_000, save_total_limit=2),\n",
380
+ ")\n",
381
+ "\n",
382
+ "trainer.train()"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 14,
388
+ "metadata": {
389
+ "id": "DZLeB7tSWoKO"
390
+ },
391
+ "outputs": [],
392
+ "source": [
393
+ "import torch\n",
394
+ "torch.save(model.state_dict(), \"./model_cached.pth\")"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "metadata": {
401
+ "id": "PRFXZltKGwUH"
402
+ },
403
+ "outputs": [],
404
+ "source": [
405
+ "import pandas as pd\n",
406
+ "df = pd.read_json(\"./arxivData.json\")\n",
407
+ "df.loc[0]"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": 18,
413
+ "metadata": {
414
+ "colab": {
415
+ "base_uri": "https://localhost:8080/"
416
+ },
417
+ "id": "tBObhO5PfOGZ",
418
+ "outputId": "3771e041-dac7-43ac-ff4e-fc202155a004"
419
+ },
420
+ "outputs": [
421
+ {
422
+ "name": "stdout",
423
+ "output_type": "stream",
424
+ "text": [
425
+ "Mounted at /content/gdrive\n"
426
+ ]
427
+ }
428
+ ],
429
+ "source": [
430
+ "from google.colab import drive\n",
431
+ "drive.mount('/content/gdrive')"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": 19,
437
+ "metadata": {
438
+ "id": "6O7BNuwXGwUI"
439
+ },
440
+ "outputs": [],
441
+ "source": [
442
+ "model_save_name = 'model_cached.pth'\n",
443
+ "path = F\"/content/gdrive/My Drive/cached_model.pth\" \n",
444
+ "torch.save(model.state_dict(), path)"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": 6,
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "import numpy as np\n",
454
+ "import torch\n",
455
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, Trainer, TrainingArguments, LineByLineTextDataset\n",
456
+ "import evaluate\n",
457
+ "\n",
458
+ "\n",
459
+ "model_enhanced = AutoModelForSequenceClassification.from_pretrained(\"siebert/sentiment-roberta-large-english\", num_labels=2)\n",
460
+ "model_enhanced.load_state_dict(torch.load('model_cached_2.pth', map_location=torch.device('cpu')))\n",
461
+ "\n",
462
+ "metric = evaluate.load(\"accuracy\")\n",
463
+ "\n",
464
+ "def compute_metrics(eval_pred):\n",
465
+ " logits, labels = eval_pred\n",
466
+ " predictions = np.argmax(logits, axis=-1)\n",
467
+ " return metric.compute(predictions=predictions, references=labels)\n"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": 8,
473
+ "metadata": {},
474
+ "outputs": [],
475
+ "source": [
476
+ "task_evaluator = evaluate.evaluator(\"text-classification\")\n",
477
+ "data_to_evaluate = dataset[\"test\"].select(range(500))\n",
478
+ "data_to_evaluate = data_to_evaluate.rename_column(\"feeling\", \"label\")\n",
479
+ "\n",
480
+ "results = task_evaluator.compute(\n",
481
+ " model_or_pipeline=model_enhanced,\n",
482
+ " tokenizer=tokenizer,\n",
483
+ " data=data_to_evaluate,\n",
484
+ " label_mapping={\"POSITIVE\": 1.0, \"NEGATIVE\": 0.0},\n",
485
+ " metric=\"accuracy\",\n",
486
+ " strategy=\"simple\"\n",
487
+ ")"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "code",
492
+ "execution_count": 9,
493
+ "metadata": {},
494
+ "outputs": [
495
+ {
496
+ "data": {
497
+ "text/plain": [
498
+ "{'accuracy': 0.758,\n",
499
+ " 'total_time_in_seconds': 123.58850028699817,\n",
500
+ " 'samples_per_second': 4.045683852776724,\n",
501
+ " 'latency_in_seconds': 0.2471770005739963}"
502
+ ]
503
+ },
504
+ "execution_count": 9,
505
+ "metadata": {},
506
+ "output_type": "execute_result"
507
+ }
508
+ ],
509
+ "source": [
510
+ "results"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "metadata": {},
517
+ "outputs": [],
518
+ "source": []
519
+ }
520
+ ],
521
+ "metadata": {
522
+ "accelerator": "GPU",
523
+ "colab": {
524
+ "provenance": []
525
+ },
526
+ "gpuClass": "standard",
527
+ "kernelspec": {
528
+ "display_name": "ds",
529
+ "language": "python",
530
+ "name": "python3"
531
+ },
532
+ "language_info": {
533
+ "codemirror_mode": {
534
+ "name": "ipython",
535
+ "version": 3
536
+ },
537
+ "file_extension": ".py",
538
+ "mimetype": "text/x-python",
539
+ "name": "python",
540
+ "nbconvert_exporter": "python",
541
+ "pygments_lexer": "ipython3",
542
+ "version": "3.10.8"
543
+ },
544
+ "orig_nbformat": 4,
545
+ "vscode": {
546
+ "interpreter": {
547
+ "hash": "8743941f33f095367e79a82533efd22bcd3b3a1e7031447532076cc1f09e6391"
548
+ }
549
+ },
550
+ "widgets": {
551
+ "application/vnd.jupyter.widget-state+json": {
552
+ "06de297475874730a793d3255cfe9a0d": {
553
+ "model_module": "@jupyter-widgets/controls",
554
+ "model_module_version": "1.5.0",
555
+ "model_name": "HBoxModel",
556
+ "state": {
557
+ "_dom_classes": [],
558
+ "_model_module": "@jupyter-widgets/controls",
559
+ "_model_module_version": "1.5.0",
560
+ "_model_name": "HBoxModel",
561
+ "_view_count": null,
562
+ "_view_module": "@jupyter-widgets/controls",
563
+ "_view_module_version": "1.5.0",
564
+ "_view_name": "HBoxView",
565
+ "box_style": "",
566
+ "children": [
567
+ "IPY_MODEL_1a192819128a429b96c571f48c7de552",
568
+ "IPY_MODEL_4e50c449ee2047148e35c76fddf38e08",
569
+ "IPY_MODEL_12483d6de04546feb5ee3b99ede894dc"
570
+ ],
571
+ "layout": "IPY_MODEL_6bc412e1b7c04160b58f83db5841d255"
572
+ }
573
+ },
574
+ "11585997b0494e3db0f0121538e4f9f4": {
575
+ "model_module": "@jupyter-widgets/controls",
576
+ "model_module_version": "1.5.0",
577
+ "model_name": "ProgressStyleModel",
578
+ "state": {
579
+ "_model_module": "@jupyter-widgets/controls",
580
+ "_model_module_version": "1.5.0",
581
+ "_model_name": "ProgressStyleModel",
582
+ "_view_count": null,
583
+ "_view_module": "@jupyter-widgets/base",
584
+ "_view_module_version": "1.2.0",
585
+ "_view_name": "StyleView",
586
+ "bar_color": null,
587
+ "description_width": ""
588
+ }
589
+ },
590
+ "12483d6de04546feb5ee3b99ede894dc": {
591
+ "model_module": "@jupyter-widgets/controls",
592
+ "model_module_version": "1.5.0",
593
+ "model_name": "HTMLModel",
594
+ "state": {
595
+ "_dom_classes": [],
596
+ "_model_module": "@jupyter-widgets/controls",
597
+ "_model_module_version": "1.5.0",
598
+ "_model_name": "HTMLModel",
599
+ "_view_count": null,
600
+ "_view_module": "@jupyter-widgets/controls",
601
+ "_view_module_version": "1.5.0",
602
+ "_view_name": "HTMLView",
603
+ "description": "",
604
+ "description_tooltip": null,
605
+ "layout": "IPY_MODEL_dc511eb900064df8a10177ba5ea1e7ff",
606
+ "placeholder": "​",
607
+ "style": "IPY_MODEL_2973a9eaa7f34f70bb02d4ffaea67462",
608
+ "value": " 3/3 [00:00&lt;00:00, 55.23it/s]"
609
+ }
610
+ },
611
+ "1a192819128a429b96c571f48c7de552": {
612
+ "model_module": "@jupyter-widgets/controls",
613
+ "model_module_version": "1.5.0",
614
+ "model_name": "HTMLModel",
615
+ "state": {
616
+ "_dom_classes": [],
617
+ "_model_module": "@jupyter-widgets/controls",
618
+ "_model_module_version": "1.5.0",
619
+ "_model_name": "HTMLModel",
620
+ "_view_count": null,
621
+ "_view_module": "@jupyter-widgets/controls",
622
+ "_view_module_version": "1.5.0",
623
+ "_view_name": "HTMLView",
624
+ "description": "",
625
+ "description_tooltip": null,
626
+ "layout": "IPY_MODEL_8e44d4f0218f4c43bb28a22a1b0472f7",
627
+ "placeholder": "​",
628
+ "style": "IPY_MODEL_9b9e621263e24472be473d87b2d54db5",
629
+ "value": "100%"
630
+ }
631
+ },
632
+ "2973a9eaa7f34f70bb02d4ffaea67462": {
633
+ "model_module": "@jupyter-widgets/controls",
634
+ "model_module_version": "1.5.0",
635
+ "model_name": "DescriptionStyleModel",
636
+ "state": {
637
+ "_model_module": "@jupyter-widgets/controls",
638
+ "_model_module_version": "1.5.0",
639
+ "_model_name": "DescriptionStyleModel",
640
+ "_view_count": null,
641
+ "_view_module": "@jupyter-widgets/base",
642
+ "_view_module_version": "1.2.0",
643
+ "_view_name": "StyleView",
644
+ "description_width": ""
645
+ }
646
+ },
647
+ "330be1100e454f9da15ad9553bfaf4a2": {
648
+ "model_module": "@jupyter-widgets/base",
649
+ "model_module_version": "1.2.0",
650
+ "model_name": "LayoutModel",
651
+ "state": {
652
+ "_model_module": "@jupyter-widgets/base",
653
+ "_model_module_version": "1.2.0",
654
+ "_model_name": "LayoutModel",
655
+ "_view_count": null,
656
+ "_view_module": "@jupyter-widgets/base",
657
+ "_view_module_version": "1.2.0",
658
+ "_view_name": "LayoutView",
659
+ "align_content": null,
660
+ "align_items": null,
661
+ "align_self": null,
662
+ "border": null,
663
+ "bottom": null,
664
+ "display": null,
665
+ "flex": null,
666
+ "flex_flow": null,
667
+ "grid_area": null,
668
+ "grid_auto_columns": null,
669
+ "grid_auto_flow": null,
670
+ "grid_auto_rows": null,
671
+ "grid_column": null,
672
+ "grid_gap": null,
673
+ "grid_row": null,
674
+ "grid_template_areas": null,
675
+ "grid_template_columns": null,
676
+ "grid_template_rows": null,
677
+ "height": null,
678
+ "justify_content": null,
679
+ "justify_items": null,
680
+ "left": null,
681
+ "margin": null,
682
+ "max_height": null,
683
+ "max_width": null,
684
+ "min_height": null,
685
+ "min_width": null,
686
+ "object_fit": null,
687
+ "object_position": null,
688
+ "order": null,
689
+ "overflow": null,
690
+ "overflow_x": null,
691
+ "overflow_y": null,
692
+ "padding": null,
693
+ "right": null,
694
+ "top": null,
695
+ "visibility": null,
696
+ "width": null
697
+ }
698
+ },
699
+ "4e50c449ee2047148e35c76fddf38e08": {
700
+ "model_module": "@jupyter-widgets/controls",
701
+ "model_module_version": "1.5.0",
702
+ "model_name": "FloatProgressModel",
703
+ "state": {
704
+ "_dom_classes": [],
705
+ "_model_module": "@jupyter-widgets/controls",
706
+ "_model_module_version": "1.5.0",
707
+ "_model_name": "FloatProgressModel",
708
+ "_view_count": null,
709
+ "_view_module": "@jupyter-widgets/controls",
710
+ "_view_module_version": "1.5.0",
711
+ "_view_name": "ProgressView",
712
+ "bar_style": "success",
713
+ "description": "",
714
+ "description_tooltip": null,
715
+ "layout": "IPY_MODEL_f4947eda1c18457a80aa7c68df79d32a",
716
+ "max": 3,
717
+ "min": 0,
718
+ "orientation": "horizontal",
719
+ "style": "IPY_MODEL_11585997b0494e3db0f0121538e4f9f4",
720
+ "value": 3
721
+ }
722
+ },
723
+ "51291a2ef41f4e9fbe68d1e4151fb4e2": {
724
+ "model_module": "@jupyter-widgets/controls",
725
+ "model_module_version": "1.5.0",
726
+ "model_name": "ProgressStyleModel",
727
+ "state": {
728
+ "_model_module": "@jupyter-widgets/controls",
729
+ "_model_module_version": "1.5.0",
730
+ "_model_name": "ProgressStyleModel",
731
+ "_view_count": null,
732
+ "_view_module": "@jupyter-widgets/base",
733
+ "_view_module_version": "1.2.0",
734
+ "_view_name": "StyleView",
735
+ "bar_color": null,
736
+ "description_width": ""
737
+ }
738
+ },
739
+ "5e4bdf9e5c8347159c83772d892f2e4f": {
740
+ "model_module": "@jupyter-widgets/controls",
741
+ "model_module_version": "1.5.0",
742
+ "model_name": "HBoxModel",
743
+ "state": {
744
+ "_dom_classes": [],
745
+ "_model_module": "@jupyter-widgets/controls",
746
+ "_model_module_version": "1.5.0",
747
+ "_model_name": "HBoxModel",
748
+ "_view_count": null,
749
+ "_view_module": "@jupyter-widgets/controls",
750
+ "_view_module_version": "1.5.0",
751
+ "_view_name": "HBoxView",
752
+ "box_style": "",
753
+ "children": [
754
+ "IPY_MODEL_dfc20d27805d4ba9a388cfa24384285e",
755
+ "IPY_MODEL_de9e434a5192455ab2882dff5f0ed3d9",
756
+ "IPY_MODEL_c41e383cd66949e7979a9f83c932c3bf"
757
+ ],
758
+ "layout": "IPY_MODEL_af522a70d4e04cfa8bdf5807cf30153b"
759
+ }
760
+ },
761
+ "6018eca5deaf41858ce694efc6b3624d": {
762
+ "model_module": "@jupyter-widgets/base",
763
+ "model_module_version": "1.2.0",
764
+ "model_name": "LayoutModel",
765
+ "state": {
766
+ "_model_module": "@jupyter-widgets/base",
767
+ "_model_module_version": "1.2.0",
768
+ "_model_name": "LayoutModel",
769
+ "_view_count": null,
770
+ "_view_module": "@jupyter-widgets/base",
771
+ "_view_module_version": "1.2.0",
772
+ "_view_name": "LayoutView",
773
+ "align_content": null,
774
+ "align_items": null,
775
+ "align_self": null,
776
+ "border": null,
777
+ "bottom": null,
778
+ "display": null,
779
+ "flex": null,
780
+ "flex_flow": null,
781
+ "grid_area": null,
782
+ "grid_auto_columns": null,
783
+ "grid_auto_flow": null,
784
+ "grid_auto_rows": null,
785
+ "grid_column": null,
786
+ "grid_gap": null,
787
+ "grid_row": null,
788
+ "grid_template_areas": null,
789
+ "grid_template_columns": null,
790
+ "grid_template_rows": null,
791
+ "height": null,
792
+ "justify_content": null,
793
+ "justify_items": null,
794
+ "left": null,
795
+ "margin": null,
796
+ "max_height": null,
797
+ "max_width": null,
798
+ "min_height": null,
799
+ "min_width": null,
800
+ "object_fit": null,
801
+ "object_position": null,
802
+ "order": null,
803
+ "overflow": null,
804
+ "overflow_x": null,
805
+ "overflow_y": null,
806
+ "padding": null,
807
+ "right": null,
808
+ "top": null,
809
+ "visibility": null,
810
+ "width": null
811
+ }
812
+ },
813
+ "6bc412e1b7c04160b58f83db5841d255": {
814
+ "model_module": "@jupyter-widgets/base",
815
+ "model_module_version": "1.2.0",
816
+ "model_name": "LayoutModel",
817
+ "state": {
818
+ "_model_module": "@jupyter-widgets/base",
819
+ "_model_module_version": "1.2.0",
820
+ "_model_name": "LayoutModel",
821
+ "_view_count": null,
822
+ "_view_module": "@jupyter-widgets/base",
823
+ "_view_module_version": "1.2.0",
824
+ "_view_name": "LayoutView",
825
+ "align_content": null,
826
+ "align_items": null,
827
+ "align_self": null,
828
+ "border": null,
829
+ "bottom": null,
830
+ "display": null,
831
+ "flex": null,
832
+ "flex_flow": null,
833
+ "grid_area": null,
834
+ "grid_auto_columns": null,
835
+ "grid_auto_flow": null,
836
+ "grid_auto_rows": null,
837
+ "grid_column": null,
838
+ "grid_gap": null,
839
+ "grid_row": null,
840
+ "grid_template_areas": null,
841
+ "grid_template_columns": null,
842
+ "grid_template_rows": null,
843
+ "height": null,
844
+ "justify_content": null,
845
+ "justify_items": null,
846
+ "left": null,
847
+ "margin": null,
848
+ "max_height": null,
849
+ "max_width": null,
850
+ "min_height": null,
851
+ "min_width": null,
852
+ "object_fit": null,
853
+ "object_position": null,
854
+ "order": null,
855
+ "overflow": null,
856
+ "overflow_x": null,
857
+ "overflow_y": null,
858
+ "padding": null,
859
+ "right": null,
860
+ "top": null,
861
+ "visibility": null,
862
+ "width": null
863
+ }
864
+ },
865
+ "7fa05036dfb345898c06ecb09dca0363": {
866
+ "model_module": "@jupyter-widgets/controls",
867
+ "model_module_version": "1.5.0",
868
+ "model_name": "DescriptionStyleModel",
869
+ "state": {
870
+ "_model_module": "@jupyter-widgets/controls",
871
+ "_model_module_version": "1.5.0",
872
+ "_model_name": "DescriptionStyleModel",
873
+ "_view_count": null,
874
+ "_view_module": "@jupyter-widgets/base",
875
+ "_view_module_version": "1.2.0",
876
+ "_view_name": "StyleView",
877
+ "description_width": ""
878
+ }
879
+ },
880
+ "8e44d4f0218f4c43bb28a22a1b0472f7": {
881
+ "model_module": "@jupyter-widgets/base",
882
+ "model_module_version": "1.2.0",
883
+ "model_name": "LayoutModel",
884
+ "state": {
885
+ "_model_module": "@jupyter-widgets/base",
886
+ "_model_module_version": "1.2.0",
887
+ "_model_name": "LayoutModel",
888
+ "_view_count": null,
889
+ "_view_module": "@jupyter-widgets/base",
890
+ "_view_module_version": "1.2.0",
891
+ "_view_name": "LayoutView",
892
+ "align_content": null,
893
+ "align_items": null,
894
+ "align_self": null,
895
+ "border": null,
896
+ "bottom": null,
897
+ "display": null,
898
+ "flex": null,
899
+ "flex_flow": null,
900
+ "grid_area": null,
901
+ "grid_auto_columns": null,
902
+ "grid_auto_flow": null,
903
+ "grid_auto_rows": null,
904
+ "grid_column": null,
905
+ "grid_gap": null,
906
+ "grid_row": null,
907
+ "grid_template_areas": null,
908
+ "grid_template_columns": null,
909
+ "grid_template_rows": null,
910
+ "height": null,
911
+ "justify_content": null,
912
+ "justify_items": null,
913
+ "left": null,
914
+ "margin": null,
915
+ "max_height": null,
916
+ "max_width": null,
917
+ "min_height": null,
918
+ "min_width": null,
919
+ "object_fit": null,
920
+ "object_position": null,
921
+ "order": null,
922
+ "overflow": null,
923
+ "overflow_x": null,
924
+ "overflow_y": null,
925
+ "padding": null,
926
+ "right": null,
927
+ "top": null,
928
+ "visibility": null,
929
+ "width": null
930
+ }
931
+ },
932
+ "9b9e621263e24472be473d87b2d54db5": {
933
+ "model_module": "@jupyter-widgets/controls",
934
+ "model_module_version": "1.5.0",
935
+ "model_name": "DescriptionStyleModel",
936
+ "state": {
937
+ "_model_module": "@jupyter-widgets/controls",
938
+ "_model_module_version": "1.5.0",
939
+ "_model_name": "DescriptionStyleModel",
940
+ "_view_count": null,
941
+ "_view_module": "@jupyter-widgets/base",
942
+ "_view_module_version": "1.2.0",
943
+ "_view_name": "StyleView",
944
+ "description_width": ""
945
+ }
946
+ },
947
+ "af522a70d4e04cfa8bdf5807cf30153b": {
948
+ "model_module": "@jupyter-widgets/base",
949
+ "model_module_version": "1.2.0",
950
+ "model_name": "LayoutModel",
951
+ "state": {
952
+ "_model_module": "@jupyter-widgets/base",
953
+ "_model_module_version": "1.2.0",
954
+ "_model_name": "LayoutModel",
955
+ "_view_count": null,
956
+ "_view_module": "@jupyter-widgets/base",
957
+ "_view_module_version": "1.2.0",
958
+ "_view_name": "LayoutView",
959
+ "align_content": null,
960
+ "align_items": null,
961
+ "align_self": null,
962
+ "border": null,
963
+ "bottom": null,
964
+ "display": null,
965
+ "flex": null,
966
+ "flex_flow": null,
967
+ "grid_area": null,
968
+ "grid_auto_columns": null,
969
+ "grid_auto_flow": null,
970
+ "grid_auto_rows": null,
971
+ "grid_column": null,
972
+ "grid_gap": null,
973
+ "grid_row": null,
974
+ "grid_template_areas": null,
975
+ "grid_template_columns": null,
976
+ "grid_template_rows": null,
977
+ "height": null,
978
+ "justify_content": null,
979
+ "justify_items": null,
980
+ "left": null,
981
+ "margin": null,
982
+ "max_height": null,
983
+ "max_width": null,
984
+ "min_height": null,
985
+ "min_width": null,
986
+ "object_fit": null,
987
+ "object_position": null,
988
+ "order": null,
989
+ "overflow": null,
990
+ "overflow_x": null,
991
+ "overflow_y": null,
992
+ "padding": null,
993
+ "right": null,
994
+ "top": null,
995
+ "visibility": "hidden",
996
+ "width": null
997
+ }
998
+ },
999
+ "c41e383cd66949e7979a9f83c932c3bf": {
1000
+ "model_module": "@jupyter-widgets/controls",
1001
+ "model_module_version": "1.5.0",
1002
+ "model_name": "HTMLModel",
1003
+ "state": {
1004
+ "_dom_classes": [],
1005
+ "_model_module": "@jupyter-widgets/controls",
1006
+ "_model_module_version": "1.5.0",
1007
+ "_model_name": "HTMLModel",
1008
+ "_view_count": null,
1009
+ "_view_module": "@jupyter-widgets/controls",
1010
+ "_view_module_version": "1.5.0",
1011
+ "_view_name": "HTMLView",
1012
+ "description": "",
1013
+ "description_tooltip": null,
1014
+ "layout": "IPY_MODEL_ec6d9ab9eda24231b577c75136f902f8",
1015
+ "placeholder": "​",
1016
+ "style": "IPY_MODEL_7fa05036dfb345898c06ecb09dca0363",
1017
+ "value": " 61869/61998 [00:48&lt;00:00, 1766.87 examples/s]"
1018
+ }
1019
+ },
1020
+ "dc511eb900064df8a10177ba5ea1e7ff": {
1021
+ "model_module": "@jupyter-widgets/base",
1022
+ "model_module_version": "1.2.0",
1023
+ "model_name": "LayoutModel",
1024
+ "state": {
1025
+ "_model_module": "@jupyter-widgets/base",
1026
+ "_model_module_version": "1.2.0",
1027
+ "_model_name": "LayoutModel",
1028
+ "_view_count": null,
1029
+ "_view_module": "@jupyter-widgets/base",
1030
+ "_view_module_version": "1.2.0",
1031
+ "_view_name": "LayoutView",
1032
+ "align_content": null,
1033
+ "align_items": null,
1034
+ "align_self": null,
1035
+ "border": null,
1036
+ "bottom": null,
1037
+ "display": null,
1038
+ "flex": null,
1039
+ "flex_flow": null,
1040
+ "grid_area": null,
1041
+ "grid_auto_columns": null,
1042
+ "grid_auto_flow": null,
1043
+ "grid_auto_rows": null,
1044
+ "grid_column": null,
1045
+ "grid_gap": null,
1046
+ "grid_row": null,
1047
+ "grid_template_areas": null,
1048
+ "grid_template_columns": null,
1049
+ "grid_template_rows": null,
1050
+ "height": null,
1051
+ "justify_content": null,
1052
+ "justify_items": null,
1053
+ "left": null,
1054
+ "margin": null,
1055
+ "max_height": null,
1056
+ "max_width": null,
1057
+ "min_height": null,
1058
+ "min_width": null,
1059
+ "object_fit": null,
1060
+ "object_position": null,
1061
+ "order": null,
1062
+ "overflow": null,
1063
+ "overflow_x": null,
1064
+ "overflow_y": null,
1065
+ "padding": null,
1066
+ "right": null,
1067
+ "top": null,
1068
+ "visibility": null,
1069
+ "width": null
1070
+ }
1071
+ },
1072
+ "de9e434a5192455ab2882dff5f0ed3d9": {
1073
+ "model_module": "@jupyter-widgets/controls",
1074
+ "model_module_version": "1.5.0",
1075
+ "model_name": "FloatProgressModel",
1076
+ "state": {
1077
+ "_dom_classes": [],
1078
+ "_model_module": "@jupyter-widgets/controls",
1079
+ "_model_module_version": "1.5.0",
1080
+ "_model_name": "FloatProgressModel",
1081
+ "_view_count": null,
1082
+ "_view_module": "@jupyter-widgets/controls",
1083
+ "_view_module_version": "1.5.0",
1084
+ "_view_name": "ProgressView",
1085
+ "bar_style": "",
1086
+ "description": "",
1087
+ "description_tooltip": null,
1088
+ "layout": "IPY_MODEL_6018eca5deaf41858ce694efc6b3624d",
1089
+ "max": 61998,
1090
+ "min": 0,
1091
+ "orientation": "horizontal",
1092
+ "style": "IPY_MODEL_51291a2ef41f4e9fbe68d1e4151fb4e2",
1093
+ "value": 61998
1094
+ }
1095
+ },
1096
+ "dfc20d27805d4ba9a388cfa24384285e": {
1097
+ "model_module": "@jupyter-widgets/controls",
1098
+ "model_module_version": "1.5.0",
1099
+ "model_name": "HTMLModel",
1100
+ "state": {
1101
+ "_dom_classes": [],
1102
+ "_model_module": "@jupyter-widgets/controls",
1103
+ "_model_module_version": "1.5.0",
1104
+ "_model_name": "HTMLModel",
1105
+ "_view_count": null,
1106
+ "_view_module": "@jupyter-widgets/controls",
1107
+ "_view_module_version": "1.5.0",
1108
+ "_view_name": "HTMLView",
1109
+ "description": "",
1110
+ "description_tooltip": null,
1111
+ "layout": "IPY_MODEL_330be1100e454f9da15ad9553bfaf4a2",
1112
+ "placeholder": "​",
1113
+ "style": "IPY_MODEL_f815ae8495f44970b648c942d2ac11c0",
1114
+ "value": "Map: 100%"
1115
+ }
1116
+ },
1117
+ "ec6d9ab9eda24231b577c75136f902f8": {
1118
+ "model_module": "@jupyter-widgets/base",
1119
+ "model_module_version": "1.2.0",
1120
+ "model_name": "LayoutModel",
1121
+ "state": {
1122
+ "_model_module": "@jupyter-widgets/base",
1123
+ "_model_module_version": "1.2.0",
1124
+ "_model_name": "LayoutModel",
1125
+ "_view_count": null,
1126
+ "_view_module": "@jupyter-widgets/base",
1127
+ "_view_module_version": "1.2.0",
1128
+ "_view_name": "LayoutView",
1129
+ "align_content": null,
1130
+ "align_items": null,
1131
+ "align_self": null,
1132
+ "border": null,
1133
+ "bottom": null,
1134
+ "display": null,
1135
+ "flex": null,
1136
+ "flex_flow": null,
1137
+ "grid_area": null,
1138
+ "grid_auto_columns": null,
1139
+ "grid_auto_flow": null,
1140
+ "grid_auto_rows": null,
1141
+ "grid_column": null,
1142
+ "grid_gap": null,
1143
+ "grid_row": null,
1144
+ "grid_template_areas": null,
1145
+ "grid_template_columns": null,
1146
+ "grid_template_rows": null,
1147
+ "height": null,
1148
+ "justify_content": null,
1149
+ "justify_items": null,
1150
+ "left": null,
1151
+ "margin": null,
1152
+ "max_height": null,
1153
+ "max_width": null,
1154
+ "min_height": null,
1155
+ "min_width": null,
1156
+ "object_fit": null,
1157
+ "object_position": null,
1158
+ "order": null,
1159
+ "overflow": null,
1160
+ "overflow_x": null,
1161
+ "overflow_y": null,
1162
+ "padding": null,
1163
+ "right": null,
1164
+ "top": null,
1165
+ "visibility": null,
1166
+ "width": null
1167
+ }
1168
+ },
1169
+ "f4947eda1c18457a80aa7c68df79d32a": {
1170
+ "model_module": "@jupyter-widgets/base",
1171
+ "model_module_version": "1.2.0",
1172
+ "model_name": "LayoutModel",
1173
+ "state": {
1174
+ "_model_module": "@jupyter-widgets/base",
1175
+ "_model_module_version": "1.2.0",
1176
+ "_model_name": "LayoutModel",
1177
+ "_view_count": null,
1178
+ "_view_module": "@jupyter-widgets/base",
1179
+ "_view_module_version": "1.2.0",
1180
+ "_view_name": "LayoutView",
1181
+ "align_content": null,
1182
+ "align_items": null,
1183
+ "align_self": null,
1184
+ "border": null,
1185
+ "bottom": null,
1186
+ "display": null,
1187
+ "flex": null,
1188
+ "flex_flow": null,
1189
+ "grid_area": null,
1190
+ "grid_auto_columns": null,
1191
+ "grid_auto_flow": null,
1192
+ "grid_auto_rows": null,
1193
+ "grid_column": null,
1194
+ "grid_gap": null,
1195
+ "grid_row": null,
1196
+ "grid_template_areas": null,
1197
+ "grid_template_columns": null,
1198
+ "grid_template_rows": null,
1199
+ "height": null,
1200
+ "justify_content": null,
1201
+ "justify_items": null,
1202
+ "left": null,
1203
+ "margin": null,
1204
+ "max_height": null,
1205
+ "max_width": null,
1206
+ "min_height": null,
1207
+ "min_width": null,
1208
+ "object_fit": null,
1209
+ "object_position": null,
1210
+ "order": null,
1211
+ "overflow": null,
1212
+ "overflow_x": null,
1213
+ "overflow_y": null,
1214
+ "padding": null,
1215
+ "right": null,
1216
+ "top": null,
1217
+ "visibility": null,
1218
+ "width": null
1219
+ }
1220
+ },
1221
+ "f815ae8495f44970b648c942d2ac11c0": {
1222
+ "model_module": "@jupyter-widgets/controls",
1223
+ "model_module_version": "1.5.0",
1224
+ "model_name": "DescriptionStyleModel",
1225
+ "state": {
1226
+ "_model_module": "@jupyter-widgets/controls",
1227
+ "_model_module_version": "1.5.0",
1228
+ "_model_name": "DescriptionStyleModel",
1229
+ "_view_count": null,
1230
+ "_view_module": "@jupyter-widgets/base",
1231
+ "_view_module_version": "1.2.0",
1232
+ "_view_name": "StyleView",
1233
+ "description_width": ""
1234
+ }
1235
+ }
1236
+ }
1237
+ }
1238
+ },
1239
+ "nbformat": 4,
1240
+ "nbformat_minor": 0
1241
+ }