ksvmuralidhar committed on
Commit
df9c590
1 Parent(s): 1325bee

Upload 2 files

Browse files
insert_into_db_sent_tran_tsdae.ipynb ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "1aafbf18-de38-4fcf-8245-e2e9a584971f",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "# ! pip install pymilvus==2.3.4\n",
13
+ "# ! pip install pyarrow==12.0.0\n",
14
+ "# !pip install -U sentence-transformers"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "id": "f1d8f101-f51b-4a50-b150-86e87c50c453",
21
+ "metadata": {
22
+ "tags": []
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "import numpy as np\n",
27
+ "import tensorflow as tf\n",
28
+ "from tqdm import tqdm\n",
29
+ "from dotenv import load_dotenv\n",
30
+ "import os\n",
31
+ "import pandas as pd\n",
32
+ "from pymilvus import connections, utility\n",
33
+ "from pymilvus import Collection, DataType, FieldSchema, CollectionSchema\n",
34
+ "import multiprocessing\n",
35
+ "from sentence_transformers import SentenceTransformer"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 4,
41
+ "id": "4ad4e3ac-9685-4f12-8043-5fbcc373d3e1",
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]"
48
+ ]
49
+ },
50
+ "execution_count": 4,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ }
54
+ ],
55
+ "source": [
56
+ "tf.config.list_physical_devices('GPU')"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "id": "da71d832-b8a7-452b-b736-538a3c069b54",
63
+ "metadata": {
64
+ "tags": []
65
+ },
66
+ "outputs": [
67
+ {
68
+ "data": {
69
+ "text/html": [
70
+ "<div>\n",
71
+ "<style scoped>\n",
72
+ " .dataframe tbody tr th:only-of-type {\n",
73
+ " vertical-align: middle;\n",
74
+ " }\n",
75
+ "\n",
76
+ " .dataframe tbody tr th {\n",
77
+ " vertical-align: top;\n",
78
+ " }\n",
79
+ "\n",
80
+ " .dataframe thead th {\n",
81
+ " text-align: right;\n",
82
+ " }\n",
83
+ "</style>\n",
84
+ "<table border=\"1\" class=\"dataframe\">\n",
85
+ " <thead>\n",
86
+ " <tr style=\"text-align: right;\">\n",
87
+ " <th></th>\n",
88
+ " <th>index</th>\n",
89
+ " <th>category</th>\n",
90
+ " <th>short_description</th>\n",
91
+ " </tr>\n",
92
+ " </thead>\n",
93
+ " <tbody>\n",
94
+ " <tr>\n",
95
+ " <th>0</th>\n",
96
+ " <td>0</td>\n",
97
+ " <td>SCIENCE</td>\n",
98
+ " <td>A closer look at water-splitting's solar fuel ...</td>\n",
99
+ " </tr>\n",
100
+ " <tr>\n",
101
+ " <th>1</th>\n",
102
+ " <td>1</td>\n",
103
+ " <td>SCIENCE</td>\n",
104
+ " <td>An irresistible scent makes locusts swarm, stu...</td>\n",
105
+ " </tr>\n",
106
+ " <tr>\n",
107
+ " <th>2</th>\n",
108
+ " <td>2</td>\n",
109
+ " <td>SCIENCE</td>\n",
110
+ " <td>Artificial intelligence warning: AI will know ...</td>\n",
111
+ " </tr>\n",
112
+ " <tr>\n",
113
+ " <th>3</th>\n",
114
+ " <td>3</td>\n",
115
+ " <td>SCIENCE</td>\n",
116
+ " <td>Glaciers Could Have Sculpted Mars Valleys: Study</td>\n",
117
+ " </tr>\n",
118
+ " <tr>\n",
119
+ " <th>4</th>\n",
120
+ " <td>4</td>\n",
121
+ " <td>SCIENCE</td>\n",
122
+ " <td>Perseid meteor shower 2020: What time and how ...</td>\n",
123
+ " </tr>\n",
124
+ " <tr>\n",
125
+ " <th>...</th>\n",
126
+ " <td>...</td>\n",
127
+ " <td>...</td>\n",
128
+ " <td>...</td>\n",
129
+ " </tr>\n",
130
+ " <tr>\n",
131
+ " <th>311171</th>\n",
132
+ " <td>311171</td>\n",
133
+ " <td>TECH</td>\n",
134
+ " <td>RIM CEO Thorsten Heins' 'Significant' Plans Fo...</td>\n",
135
+ " </tr>\n",
136
+ " <tr>\n",
137
+ " <th>311172</th>\n",
138
+ " <td>311172</td>\n",
139
+ " <td>SPORTS</td>\n",
140
+ " <td>Maria Sharapova Stunned By Victoria Azarenka I...</td>\n",
141
+ " </tr>\n",
142
+ " <tr>\n",
143
+ " <th>311173</th>\n",
144
+ " <td>311173</td>\n",
145
+ " <td>SPORTS</td>\n",
146
+ " <td>Giants Over Patriots, Jets Over Colts Among M...</td>\n",
147
+ " </tr>\n",
148
+ " <tr>\n",
149
+ " <th>311174</th>\n",
150
+ " <td>311174</td>\n",
151
+ " <td>SPORTS</td>\n",
152
+ " <td>Aldon Smith Arrested: 49ers Linebacker Busted ...</td>\n",
153
+ " </tr>\n",
154
+ " <tr>\n",
155
+ " <th>311175</th>\n",
156
+ " <td>311175</td>\n",
157
+ " <td>SPORTS</td>\n",
158
+ " <td>Dwight Howard Rips Teammates After Magic Loss ...</td>\n",
159
+ " </tr>\n",
160
+ " </tbody>\n",
161
+ "</table>\n",
162
+ "<p>311176 rows × 3 columns</p>\n",
163
+ "</div>"
164
+ ],
165
+ "text/plain": [
166
+ " index category short_description\n",
167
+ "0 0 SCIENCE A closer look at water-splitting's solar fuel ...\n",
168
+ "1 1 SCIENCE An irresistible scent makes locusts swarm, stu...\n",
169
+ "2 2 SCIENCE Artificial intelligence warning: AI will know ...\n",
170
+ "3 3 SCIENCE Glaciers Could Have Sculpted Mars Valleys: Study\n",
171
+ "4 4 SCIENCE Perseid meteor shower 2020: What time and how ...\n",
172
+ "... ... ... ...\n",
173
+ "311171 311171 TECH RIM CEO Thorsten Heins' 'Significant' Plans Fo...\n",
174
+ "311172 311172 SPORTS Maria Sharapova Stunned By Victoria Azarenka I...\n",
175
+ "311173 311173 SPORTS Giants Over Patriots, Jets Over Colts Among M...\n",
176
+ "311174 311174 SPORTS Aldon Smith Arrested: 49ers Linebacker Busted ...\n",
177
+ "311175 311175 SPORTS Dwight Howard Rips Teammates After Magic Loss ...\n",
178
+ "\n",
179
+ "[311176 rows x 3 columns]"
180
+ ]
181
+ },
182
+ "execution_count": 5,
183
+ "metadata": {},
184
+ "output_type": "execute_result"
185
+ }
186
+ ],
187
+ "source": [
188
+ "data = pd.read_csv('labelled_newscatcher_dataset.csv', sep=\";\", usecols=['title', 'topic'])\n",
189
+ "json_data=pd.read_json('News_Category_Dataset_v3.json', lines=True)\n",
190
+ "data.drop_duplicates(subset=['title'], inplace=True)\n",
191
+ "json_data.drop_duplicates(subset=['headline'], inplace=True)\n",
192
+ "json_data = json_data[['headline', 'category']].copy()\n",
193
+ "json_data.rename(columns={'headline': 'title'}, inplace=True)\n",
194
+ "data.rename(columns={'topic': 'category'}, inplace=True)\n",
195
+ "data = pd.concat([data, json_data], axis=0)\n",
196
+ "data.drop_duplicates(subset=['title'], inplace=True)\n",
197
+ "data.reset_index(drop=True, inplace=True)\n",
198
+ "data.reset_index(inplace=True)\n",
199
+ "data.rename(columns={'title': 'short_description'}, inplace=True)\n",
200
+ "data"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 6,
206
+ "id": "796f85b1-12dc-42cb-b431-65c88738b607",
207
+ "metadata": {},
208
+ "outputs": [
209
+ {
210
+ "data": {
211
+ "text/plain": [
212
+ "False"
213
+ ]
214
+ },
215
+ "execution_count": 6,
216
+ "metadata": {},
217
+ "output_type": "execute_result"
218
+ }
219
+ ],
220
+ "source": [
221
+ "any(data['short_description'].duplicated())"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 7,
227
+ "id": "5f46251b-156a-4a72-ab89-6abb6d810006",
228
+ "metadata": {
229
+ "tags": []
230
+ },
231
+ "outputs": [],
232
+ "source": [
233
+ "data.to_csv('news_processed.csv', index=False)"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 8,
239
+ "id": "1463ea34-4447-464a-b0a3-da5e09892a09",
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "class TextVectorizer:\n",
244
+ " '''\n",
245
+ " Uses a fine-tuned sentence-transformers model to extract sentence embeddings\n",
246
+ " '''\n",
247
+ " def vectorize(self, x):\n",
248
+ " sent_model = SentenceTransformer('multi-qa-distilbert-cos-v1_finetuned')\n",
249
+ " sen_embeddings = sent_model.encode(x)\n",
250
+ " return sen_embeddings"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 9,
256
+ "id": "47a714f3-8948-470b-9caf-93ed2bbf4894",
257
+ "metadata": {
258
+ "tags": []
259
+ },
260
+ "outputs": [],
261
+ "source": [
262
+ "vectorizer = TextVectorizer()"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 10,
268
+ "id": "8b1586e5-2923-4632-a3db-fd2364124d6f",
269
+ "metadata": {
270
+ "tags": []
271
+ },
272
+ "outputs": [
273
+ {
274
+ "data": {
275
+ "text/plain": [
276
+ "320"
277
+ ]
278
+ },
279
+ "execution_count": 10,
280
+ "metadata": {},
281
+ "output_type": "execute_result"
282
+ }
283
+ ],
284
+ "source": [
285
+ "# getting max length of article descriptions to be used for VARCHAR while defining schema\n",
286
+ "max_desc_len = max([len(s) for s in data['short_description']])\n",
287
+ "max_desc_len"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 11,
293
+ "id": "debe0ef4-b877-495a-872e-47f720b758a9",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "data": {
298
+ "text/plain": [
299
+ "14"
300
+ ]
301
+ },
302
+ "execution_count": 11,
303
+ "metadata": {},
304
+ "output_type": "execute_result"
305
+ }
306
+ ],
307
+ "source": [
308
+ "# getting max length of article categories to be used for VARCHAR while defining schema\n",
309
+ "max_cat_len = max([len(s) for s in data['category']])\n",
310
+ "max_cat_len"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 12,
316
+ "id": "80489f00-e59f-46ab-a933-97145928176c",
317
+ "metadata": {
318
+ "tags": []
319
+ },
320
+ "outputs": [],
321
+ "source": [
322
+ "# Reading Milvus URI & API token from secrets.env\n",
323
+ "load_dotenv('secrets.env')\n",
324
+ "uri = os.environ.get(\"URI\")\n",
325
+ "token = os.environ.get(\"TOKEN\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": 13,
331
+ "id": "0bf69f22-e113-43a5-be81-77224cafd856",
332
+ "metadata": {
333
+ "tags": []
334
+ },
335
+ "outputs": [
336
+ {
337
+ "name": "stdout",
338
+ "output_type": "stream",
339
+ "text": [
340
+ "Connected to DB\n"
341
+ ]
342
+ }
343
+ ],
344
+ "source": [
345
+ "connections.connect(\"default\", uri=uri, token=token)\n",
346
+ "print(f\"Connected to DB\")"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 14,
352
+ "id": "8da06a3b-2005-4c02-a168-dc84bcde7064",
353
+ "metadata": {
354
+ "tags": []
355
+ },
356
+ "outputs": [],
357
+ "source": [
358
+ "collection_name = 'news_collection_sent_tran_finetuned'\n",
359
+ "check_collection = utility.has_collection(collection_name)"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 15,
365
+ "id": "33342612-1380-4d1a-a8e7-931476e07979",
366
+ "metadata": {
367
+ "tags": []
368
+ },
369
+ "outputs": [
370
+ {
371
+ "name": "stdout",
372
+ "output_type": "stream",
373
+ "text": [
374
+ "Droped Existing collection\n"
375
+ ]
376
+ }
377
+ ],
378
+ "source": [
379
+ "if check_collection:\n",
380
+ " drop_result = utility.drop_collection(collection_name)\n",
381
+ " print(\"Droped Existing collection\")"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": 16,
387
+ "id": "fc8ae048-d586-41e7-9678-75e1752c1693",
388
+ "metadata": {
389
+ "tags": []
390
+ },
391
+ "outputs": [
392
+ {
393
+ "name": "stdout",
394
+ "output_type": "stream",
395
+ "text": [
396
+ "Creating the collection\n",
397
+ "Schema: {'auto_id': False, 'description': 'collection of news articles', 'fields': [{'name': 'article_id', 'description': 'primary id', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'article_embed', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 768}}, {'name': 'article_desc', 'description': 'short description of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 370}}, {'name': 'article_category', 'description': 'category of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 64}}]}\n",
398
+ "Success!\n"
399
+ ]
400
+ }
401
+ ],
402
+ "source": [
403
+ "# Creating collection schema\n",
404
+ "dim = 768 # embeddings dim\n",
405
+ "article_id = FieldSchema(name=\"article_id\", dtype=DataType.INT64, is_primary=True, description=\"primary id\") # primary key\n",
406
+ "article_embed_field = FieldSchema(name=\"article_embed\", dtype=DataType.FLOAT_VECTOR, dim=dim) # description embeddings\n",
407
+ "article_desc = FieldSchema(name=\"article_desc\", dtype=DataType.VARCHAR, max_length=(max_desc_len + 50), # using max_desc_len to specify VARCHAR len \n",
408
+ " is_primary=False, description=\"short description of the article\") # short description of article\n",
409
+ "article_cat = FieldSchema(name=\"article_category\", dtype=DataType.VARCHAR, max_length=(max_cat_len + 50), # using max_cat_len to specify VARCHAR len \n",
410
+ " is_primary=False, description=\"category of the article\") # category of article\n",
411
+ "schema = CollectionSchema(fields=[article_id, article_embed_field, article_desc, article_cat], \n",
412
+ " auto_id=False, description=\"collection of news articles\")\n",
413
+ "print(f\"Creating the collection\")\n",
414
+ "collection = Collection(name=collection_name, schema=schema)\n",
415
+ "print(f\"Schema: {schema}\")\n",
416
+ "print(\"Success!\")"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": 17,
422
+ "id": "cca82380-98f6-4c44-aac6-86d4ae3484d0",
423
+ "metadata": {},
424
+ "outputs": [
425
+ {
426
+ "name": "stdout",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000, 31000, 32000, 33000, 34000, 35000, 36000, 37000, 38000, 39000, 40000, 41000, 42000, 43000, 44000, 45000, 46000, 47000, 48000, 49000, 50000, 51000, 52000, 53000, 54000, 55000, 56000, 57000, 58000, 59000, 60000, 61000, 62000, 63000, 64000, 65000, 66000, 67000, 68000, 69000, 70000, 71000, 72000, 73000, 74000, 75000, 76000, 77000, 78000, 79000, 80000, 81000, 82000, 83000, 84000, 85000, 86000, 87000, 88000, 89000, 90000, 91000, 92000, 93000, 94000, 95000, 96000, 97000, 98000, 99000, 100000, 101000, 102000, 103000, 104000, 105000, 106000, 107000, 108000, 109000, 110000, 111000, 112000, 113000, 114000, 115000, 116000, 117000, 118000, 119000, 120000, 121000, 122000, 123000, 124000, 125000, 126000, 127000, 128000, 129000, 130000, 131000, 132000, 133000, 134000, 135000, 136000, 137000, 138000, 139000, 140000, 141000, 142000, 143000, 144000, 145000, 146000, 147000, 148000, 149000, 150000, 151000, 152000, 153000, 154000, 155000, 156000, 157000, 158000, 159000, 160000, 161000, 162000, 163000, 164000, 165000, 166000, 167000, 168000, 169000, 170000, 171000, 172000, 173000, 174000, 175000, 176000, 177000, 178000, 179000, 180000, 181000, 182000, 183000, 184000, 185000, 186000, 187000, 188000, 189000, 190000, 191000, 192000, 193000, 194000, 195000, 196000, 197000, 198000, 199000, 200000, 201000, 202000, 203000, 204000, 205000, 206000, 207000, 208000, 209000, 210000, 211000, 212000, 213000, 214000, 215000, 216000, 217000, 218000, 219000, 220000, 221000, 222000, 223000, 224000, 225000, 226000, 227000, 228000, 229000, 230000, 231000, 232000, 233000, 234000, 235000, 236000, 237000, 238000, 239000, 240000, 241000, 242000, 243000, 244000, 245000, 246000, 247000, 248000, 249000, 250000, 251000, 252000, 253000, 254000, 255000, 256000, 257000, 258000, 259000, 260000, 261000, 262000, 
263000, 264000, 265000, 266000, 267000, 268000, 269000, 270000, 271000, 272000, 273000, 274000, 275000, 276000, 277000, 278000, 279000, 280000, 281000, 282000, 283000, 284000, 285000, 286000, 287000, 288000, 289000, 290000, 291000, 292000, 293000, 294000, 295000, 296000, 297000, 298000, 299000, 300000, 301000, 302000, 303000, 304000, 305000, 306000, 307000, 308000, 309000, 310000, 311000, 311176]\n"
430
+ ]
431
+ }
432
+ ],
433
+ "source": [
434
+ "cuts = [*range(0, len(data), 1000)]\n",
435
+ "cuts.append(len(data))\n",
436
+ "print(cuts)"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "id": "066c67ac-01a6-4151-8e85-5869ddce1c0a",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "article_id = []\n",
447
+ "article_desc = []\n",
448
+ "article_embed = []\n",
449
+ "article_cat = []\n",
450
+ "try:\n",
451
+ " for i in tqdm(range(len(cuts)-1)):\n",
452
+ " df = data.iloc[cuts[i]: cuts[i+1]].copy()\n",
453
+ " article_id = [*df['index']]\n",
454
+ " article_desc = [*df['short_description']]\n",
455
+ " article_cat = [*df['category']]\n",
456
+ " results = []\n",
457
+ " article_embed = vectorizer.vectorize(article_desc)\n",
458
+ " docs = [article_id, article_embed, article_desc, article_cat]\n",
459
+ " ins_resp = collection.insert(docs)\n",
460
+ " print(ins_resp)\n",
461
+ " article_id = []\n",
462
+ " article_desc = []\n",
463
+ " article_embed = []\n",
464
+ " article_cat = []\n",
465
+ " if i == 0:\n",
466
+ " index_params = {\"index_type\": \"AUTOINDEX\", \"metric_type\": \"L2\", \"params\": {}} \n",
467
+ " collection.create_index(field_name='article_embed', index_params=index_params)\n",
468
+ " collection = Collection(name=collection_name)\n",
469
+ " collection.load()\n",
470
+ "except:\n",
471
+ " raise"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": null,
477
+ "id": "d50177fa-fd0c-48ad-bc9b-a7bdc826a628",
478
+ "metadata": {},
479
+ "outputs": [],
480
+ "source": []
481
+ }
482
+ ],
483
+ "metadata": {
484
+ "kernelspec": {
485
+ "display_name": "Python (tf_gpu)",
486
+ "language": "python",
487
+ "name": "tf_gpu"
488
+ },
489
+ "language_info": {
490
+ "codemirror_mode": {
491
+ "name": "ipython",
492
+ "version": 3
493
+ },
494
+ "file_extension": ".py",
495
+ "mimetype": "text/x-python",
496
+ "name": "python",
497
+ "nbconvert_exporter": "python",
498
+ "pygments_lexer": "ipython3",
499
+ "version": "3.9.18"
500
+ }
501
+ },
502
+ "nbformat": 4,
503
+ "nbformat_minor": 5
504
+ }
tsdae_finetune_sent_transformer.ipynb ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f0dc396f",
6
+ "metadata": {},
7
+ "source": [
8
+ "### TSDAE: Fine-tune sentence transformers using unsupervised learning with PyTorch\n",
9
+ "https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "id": "34329058",
16
+ "metadata": {
17
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
18
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
19
+ "execution": {
20
+ "iopub.execute_input": "2024-01-13T16:51:37.887254Z",
21
+ "iopub.status.busy": "2024-01-13T16:51:37.886390Z",
22
+ "iopub.status.idle": "2024-01-13T16:51:37.891827Z",
23
+ "shell.execute_reply": "2024-01-13T16:51:37.890706Z",
24
+ "shell.execute_reply.started": "2024-01-13T16:51:37.887212Z"
25
+ }
26
+ },
27
+ "outputs": [],
28
+ "source": [
29
+ "# !pip install sentence_transformers==2.2.2"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 4,
35
+ "id": "ebd138d8",
36
+ "metadata": {
37
+ "execution": {
38
+ "iopub.execute_input": "2024-01-13T16:51:50.310221Z",
39
+ "iopub.status.busy": "2024-01-13T16:51:50.309586Z",
40
+ "iopub.status.idle": "2024-01-13T16:51:50.315850Z",
41
+ "shell.execute_reply": "2024-01-13T16:51:50.314927Z",
42
+ "shell.execute_reply.started": "2024-01-13T16:51:50.310185Z"
43
+ }
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "import pandas as pd\n",
48
+ "import numpy as np\n",
49
+ "import string\n",
50
+ "from tqdm import tqdm\n",
51
+ "from numpy.linalg import norm\n",
52
+ "from sentence_transformers import SentenceTransformer, LoggingHandler\n",
53
+ "from sentence_transformers import models, util, datasets, evaluation, losses\n",
54
+ "from torch.utils.data import DataLoader"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 6,
60
+ "id": "7ef2d063",
61
+ "metadata": {
62
+ "execution": {
63
+ "iopub.execute_input": "2024-01-13T13:24:44.770211Z",
64
+ "iopub.status.busy": "2024-01-13T13:24:44.769806Z",
65
+ "iopub.status.idle": "2024-01-13T13:24:44.775042Z",
66
+ "shell.execute_reply": "2024-01-13T13:24:44.773860Z",
67
+ "shell.execute_reply.started": "2024-01-13T13:24:44.770177Z"
68
+ }
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "# import nltk\n",
73
+ "# nltk.download('punkt')"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "id": "453c1add",
80
+ "metadata": {
81
+ "execution": {
82
+ "iopub.execute_input": "2024-01-13T16:51:54.070689Z",
83
+ "iopub.status.busy": "2024-01-13T16:51:54.069945Z",
84
+ "iopub.status.idle": "2024-01-13T16:51:54.809726Z",
85
+ "shell.execute_reply": "2024-01-13T16:51:54.808920Z",
86
+ "shell.execute_reply.started": "2024-01-13T16:51:54.070657Z"
87
+ }
88
+ },
89
+ "outputs": [],
90
+ "source": [
91
+ "data = pd.read_csv('news_processed.csv', usecols=['short_description'])"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 6,
97
+ "id": "61629b79",
98
+ "metadata": {
99
+ "execution": {
100
+ "iopub.execute_input": "2024-01-13T16:51:55.125758Z",
101
+ "iopub.status.busy": "2024-01-13T16:51:55.124990Z",
102
+ "iopub.status.idle": "2024-01-13T16:51:55.180470Z",
103
+ "shell.execute_reply": "2024-01-13T16:51:55.179559Z",
104
+ "shell.execute_reply.started": "2024-01-13T16:51:55.125716Z"
105
+ }
106
+ },
107
+ "outputs": [],
108
+ "source": [
109
+ "data.dropna(inplace=True)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 7,
115
+ "id": "14162e2c",
116
+ "metadata": {
117
+ "execution": {
118
+ "iopub.execute_input": "2024-01-13T16:51:55.817535Z",
119
+ "iopub.status.busy": "2024-01-13T16:51:55.816764Z",
120
+ "iopub.status.idle": "2024-01-13T16:51:55.836578Z",
121
+ "shell.execute_reply": "2024-01-13T16:51:55.835697Z",
122
+ "shell.execute_reply.started": "2024-01-13T16:51:55.817499Z"
123
+ }
124
+ },
125
+ "outputs": [
126
+ {
127
+ "data": {
128
+ "text/plain": [
129
+ "'Experimental coronavirus vaccine prevents severe disease in mice'"
130
+ ]
131
+ },
132
+ "execution_count": 7,
133
+ "metadata": {},
134
+ "output_type": "execute_result"
135
+ }
136
+ ],
137
+ "source": [
138
+ "data['short_description'][1000]"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 14,
144
+ "id": "8fd1a801",
145
+ "metadata": {
146
+ "execution": {
147
+ "iopub.execute_input": "2024-01-13T17:08:09.906826Z",
148
+ "iopub.status.busy": "2024-01-13T17:08:09.906182Z",
149
+ "iopub.status.idle": "2024-01-13T17:08:09.914834Z",
150
+ "shell.execute_reply": "2024-01-13T17:08:09.913737Z",
151
+ "shell.execute_reply.started": "2024-01-13T17:08:09.906795Z"
152
+ }
153
+ },
154
+ "outputs": [],
155
+ "source": [
156
+ "def finetune_model(data: pd.DataFrame, col_to_use: str='short_description', \n",
157
+ " model_id: str=\"sentence-transformers/multi-qa-distilbert-cos-v1\", \n",
158
+ " batch_size: int=8, epochs: int=2):\n",
159
+ " \n",
160
+ "# https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html\n",
161
+ " \n",
162
+ " word_embedding_model = models.Transformer(model_id)\n",
163
+ " pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')\n",
164
+ " model = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n",
165
+ " \n",
166
+ " train_examples = data[col_to_use].tolist()\n",
167
+ " train_dataset = datasets.DenoisingAutoEncoderDataset(train_examples)\n",
168
+ " train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
169
+ " train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=\"sentence-transformers/paraphrase-distilroberta-base-v2\", tie_encoder_decoder=False)\n",
170
+ "# train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_id, tie_encoder_decoder=True)\n",
171
+ "\n",
172
+ " \n",
173
+ " model.fit(\n",
174
+ " train_objectives=[(train_dataloader, train_loss)],\n",
175
+ " epochs=epochs,\n",
176
+ " weight_decay=0,\n",
177
+ " scheduler='constantlr',\n",
178
+ " optimizer_params={'lr': 3e-5},\n",
179
+ " show_progress_bar=True\n",
180
+ " )\n",
181
+ " model_save_path = model_id + '_finetuned'\n",
182
+ " model.save(model_save_path)\n",
183
+ " return model_save_path"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "id": "b17a9779",
190
+ "metadata": {
191
+ "scrolled": true
192
+ },
193
+ "outputs": [],
194
+ "source": [
195
+ "# fine-tune sentence transformer\n",
196
+ "finetuned_model_id = finetune_model(data=data)\n",
197
+ "finetuned_model = SentenceTransformer(finetuned_model_id)"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "id": "4d87c128",
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": []
207
+ }
208
+ ],
209
+ "metadata": {
210
+ "kaggle": {
211
+ "accelerator": "gpu",
212
+ "dataSources": [
213
+ {
214
+ "datasetId": 4298708,
215
+ "sourceId": 7394110,
216
+ "sourceType": "datasetVersion"
217
+ }
218
+ ],
219
+ "dockerImageVersionId": 30636,
220
+ "isGpuEnabled": true,
221
+ "isInternetEnabled": true,
222
+ "language": "python",
223
+ "sourceType": "notebook"
224
+ },
225
+ "kernelspec": {
226
+ "display_name": "Python 3 (ipykernel)",
227
+ "language": "python",
228
+ "name": "python3"
229
+ },
230
+ "language_info": {
231
+ "codemirror_mode": {
232
+ "name": "ipython",
233
+ "version": 3
234
+ },
235
+ "file_extension": ".py",
236
+ "mimetype": "text/x-python",
237
+ "name": "python",
238
+ "nbconvert_exporter": "python",
239
+ "pygments_lexer": "ipython3",
240
+ "version": "3.10.12"
241
+ }
242
+ },
243
+ "nbformat": 4,
244
+ "nbformat_minor": 5
245
+ }