ksvmuralidhar
commited on
Commit
•
df9c590
1
Parent(s):
1325bee
Upload 2 files
Browse files
insert_into_db_sent_tran_tsdae.ipynb
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "1aafbf18-de38-4fcf-8245-e2e9a584971f",
|
7 |
+
"metadata": {
|
8 |
+
"tags": []
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"# ! pip install pymilvus==2.3.4\n",
|
13 |
+
"# ! pip install pyarrow==12.0.0\n",
|
14 |
+
"# !pip install -U sentence-transformers"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 2,
|
20 |
+
"id": "f1d8f101-f51b-4a50-b150-86e87c50c453",
|
21 |
+
"metadata": {
|
22 |
+
"tags": []
|
23 |
+
},
|
24 |
+
"outputs": [],
|
25 |
+
"source": [
|
26 |
+
"import numpy as np\n",
|
27 |
+
"import tensorflow as tf\n",
|
28 |
+
"from tqdm import tqdm\n",
|
29 |
+
"from dotenv import load_dotenv\n",
|
30 |
+
"import os\n",
|
31 |
+
"import pandas as pd\n",
|
32 |
+
"from pymilvus import connections, utility\n",
|
33 |
+
"from pymilvus import Collection, DataType, FieldSchema, CollectionSchema\n",
|
34 |
+
"import multiprocessing\n",
|
35 |
+
"from sentence_transformers import SentenceTransformer"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": 4,
|
41 |
+
"id": "4ad4e3ac-9685-4f12-8043-5fbcc373d3e1",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"data": {
|
46 |
+
"text/plain": [
|
47 |
+
"[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]"
|
48 |
+
]
|
49 |
+
},
|
50 |
+
"execution_count": 4,
|
51 |
+
"metadata": {},
|
52 |
+
"output_type": "execute_result"
|
53 |
+
}
|
54 |
+
],
|
55 |
+
"source": [
|
56 |
+
"tf.config.list_physical_devices('GPU')"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 5,
|
62 |
+
"id": "da71d832-b8a7-452b-b736-538a3c069b54",
|
63 |
+
"metadata": {
|
64 |
+
"tags": []
|
65 |
+
},
|
66 |
+
"outputs": [
|
67 |
+
{
|
68 |
+
"data": {
|
69 |
+
"text/html": [
|
70 |
+
"<div>\n",
|
71 |
+
"<style scoped>\n",
|
72 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
73 |
+
" vertical-align: middle;\n",
|
74 |
+
" }\n",
|
75 |
+
"\n",
|
76 |
+
" .dataframe tbody tr th {\n",
|
77 |
+
" vertical-align: top;\n",
|
78 |
+
" }\n",
|
79 |
+
"\n",
|
80 |
+
" .dataframe thead th {\n",
|
81 |
+
" text-align: right;\n",
|
82 |
+
" }\n",
|
83 |
+
"</style>\n",
|
84 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
85 |
+
" <thead>\n",
|
86 |
+
" <tr style=\"text-align: right;\">\n",
|
87 |
+
" <th></th>\n",
|
88 |
+
" <th>index</th>\n",
|
89 |
+
" <th>category</th>\n",
|
90 |
+
" <th>short_description</th>\n",
|
91 |
+
" </tr>\n",
|
92 |
+
" </thead>\n",
|
93 |
+
" <tbody>\n",
|
94 |
+
" <tr>\n",
|
95 |
+
" <th>0</th>\n",
|
96 |
+
" <td>0</td>\n",
|
97 |
+
" <td>SCIENCE</td>\n",
|
98 |
+
" <td>A closer look at water-splitting's solar fuel ...</td>\n",
|
99 |
+
" </tr>\n",
|
100 |
+
" <tr>\n",
|
101 |
+
" <th>1</th>\n",
|
102 |
+
" <td>1</td>\n",
|
103 |
+
" <td>SCIENCE</td>\n",
|
104 |
+
" <td>An irresistible scent makes locusts swarm, stu...</td>\n",
|
105 |
+
" </tr>\n",
|
106 |
+
" <tr>\n",
|
107 |
+
" <th>2</th>\n",
|
108 |
+
" <td>2</td>\n",
|
109 |
+
" <td>SCIENCE</td>\n",
|
110 |
+
" <td>Artificial intelligence warning: AI will know ...</td>\n",
|
111 |
+
" </tr>\n",
|
112 |
+
" <tr>\n",
|
113 |
+
" <th>3</th>\n",
|
114 |
+
" <td>3</td>\n",
|
115 |
+
" <td>SCIENCE</td>\n",
|
116 |
+
" <td>Glaciers Could Have Sculpted Mars Valleys: Study</td>\n",
|
117 |
+
" </tr>\n",
|
118 |
+
" <tr>\n",
|
119 |
+
" <th>4</th>\n",
|
120 |
+
" <td>4</td>\n",
|
121 |
+
" <td>SCIENCE</td>\n",
|
122 |
+
" <td>Perseid meteor shower 2020: What time and how ...</td>\n",
|
123 |
+
" </tr>\n",
|
124 |
+
" <tr>\n",
|
125 |
+
" <th>...</th>\n",
|
126 |
+
" <td>...</td>\n",
|
127 |
+
" <td>...</td>\n",
|
128 |
+
" <td>...</td>\n",
|
129 |
+
" </tr>\n",
|
130 |
+
" <tr>\n",
|
131 |
+
" <th>311171</th>\n",
|
132 |
+
" <td>311171</td>\n",
|
133 |
+
" <td>TECH</td>\n",
|
134 |
+
" <td>RIM CEO Thorsten Heins' 'Significant' Plans Fo...</td>\n",
|
135 |
+
" </tr>\n",
|
136 |
+
" <tr>\n",
|
137 |
+
" <th>311172</th>\n",
|
138 |
+
" <td>311172</td>\n",
|
139 |
+
" <td>SPORTS</td>\n",
|
140 |
+
" <td>Maria Sharapova Stunned By Victoria Azarenka I...</td>\n",
|
141 |
+
" </tr>\n",
|
142 |
+
" <tr>\n",
|
143 |
+
" <th>311173</th>\n",
|
144 |
+
" <td>311173</td>\n",
|
145 |
+
" <td>SPORTS</td>\n",
|
146 |
+
" <td>Giants Over Patriots, Jets Over Colts Among M...</td>\n",
|
147 |
+
" </tr>\n",
|
148 |
+
" <tr>\n",
|
149 |
+
" <th>311174</th>\n",
|
150 |
+
" <td>311174</td>\n",
|
151 |
+
" <td>SPORTS</td>\n",
|
152 |
+
" <td>Aldon Smith Arrested: 49ers Linebacker Busted ...</td>\n",
|
153 |
+
" </tr>\n",
|
154 |
+
" <tr>\n",
|
155 |
+
" <th>311175</th>\n",
|
156 |
+
" <td>311175</td>\n",
|
157 |
+
" <td>SPORTS</td>\n",
|
158 |
+
" <td>Dwight Howard Rips Teammates After Magic Loss ...</td>\n",
|
159 |
+
" </tr>\n",
|
160 |
+
" </tbody>\n",
|
161 |
+
"</table>\n",
|
162 |
+
"<p>311176 rows × 3 columns</p>\n",
|
163 |
+
"</div>"
|
164 |
+
],
|
165 |
+
"text/plain": [
|
166 |
+
" index category short_description\n",
|
167 |
+
"0 0 SCIENCE A closer look at water-splitting's solar fuel ...\n",
|
168 |
+
"1 1 SCIENCE An irresistible scent makes locusts swarm, stu...\n",
|
169 |
+
"2 2 SCIENCE Artificial intelligence warning: AI will know ...\n",
|
170 |
+
"3 3 SCIENCE Glaciers Could Have Sculpted Mars Valleys: Study\n",
|
171 |
+
"4 4 SCIENCE Perseid meteor shower 2020: What time and how ...\n",
|
172 |
+
"... ... ... ...\n",
|
173 |
+
"311171 311171 TECH RIM CEO Thorsten Heins' 'Significant' Plans Fo...\n",
|
174 |
+
"311172 311172 SPORTS Maria Sharapova Stunned By Victoria Azarenka I...\n",
|
175 |
+
"311173 311173 SPORTS Giants Over Patriots, Jets Over Colts Among M...\n",
|
176 |
+
"311174 311174 SPORTS Aldon Smith Arrested: 49ers Linebacker Busted ...\n",
|
177 |
+
"311175 311175 SPORTS Dwight Howard Rips Teammates After Magic Loss ...\n",
|
178 |
+
"\n",
|
179 |
+
"[311176 rows x 3 columns]"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
"execution_count": 5,
|
183 |
+
"metadata": {},
|
184 |
+
"output_type": "execute_result"
|
185 |
+
}
|
186 |
+
],
|
187 |
+
"source": [
|
188 |
+
"data = pd.read_csv('labelled_newscatcher_dataset.csv', sep=\";\", usecols=['title', 'topic'])\n",
|
189 |
+
"json_data=pd.read_json('News_Category_Dataset_v3.json', lines=True)\n",
|
190 |
+
"data.drop_duplicates(subset=['title'], inplace=True)\n",
|
191 |
+
"json_data.drop_duplicates(subset=['headline'], inplace=True)\n",
|
192 |
+
"json_data = json_data[['headline', 'category']].copy()\n",
|
193 |
+
"json_data.rename(columns={'headline': 'title'}, inplace=True)\n",
|
194 |
+
"data.rename(columns={'topic': 'category'}, inplace=True)\n",
|
195 |
+
"data = pd.concat([data, json_data], axis=0)\n",
|
196 |
+
"data.drop_duplicates(subset=['title'], inplace=True)\n",
|
197 |
+
"data.reset_index(drop=True, inplace=True)\n",
|
198 |
+
"data.reset_index(inplace=True)\n",
|
199 |
+
"data.rename(columns={'title': 'short_description'}, inplace=True)\n",
|
200 |
+
"data"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 6,
|
206 |
+
"id": "796f85b1-12dc-42cb-b431-65c88738b607",
|
207 |
+
"metadata": {},
|
208 |
+
"outputs": [
|
209 |
+
{
|
210 |
+
"data": {
|
211 |
+
"text/plain": [
|
212 |
+
"False"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
"execution_count": 6,
|
216 |
+
"metadata": {},
|
217 |
+
"output_type": "execute_result"
|
218 |
+
}
|
219 |
+
],
|
220 |
+
"source": [
|
221 |
+
"any(data['short_description'].duplicated())"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": 7,
|
227 |
+
"id": "5f46251b-156a-4a72-ab89-6abb6d810006",
|
228 |
+
"metadata": {
|
229 |
+
"tags": []
|
230 |
+
},
|
231 |
+
"outputs": [],
|
232 |
+
"source": [
|
233 |
+
"data.to_csv('news_processed.csv', index=False)"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 8,
|
239 |
+
"id": "1463ea34-4447-464a-b0a3-da5e09892a09",
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [],
|
242 |
+
"source": [
|
243 |
+
"class TextVectorizer:\n",
|
244 |
+
" '''\n",
|
245 |
+
" sentence transformers to extract sentence embeddings\n",
|
246 |
+
" '''\n",
|
247 |
+
" def vectorize(self, x):\n",
|
248 |
+
" sent_model = SentenceTransformer('multi-qa-distilbert-cos-v1_finetuned')\n",
|
249 |
+
" sen_embeddings = sent_model.encode(x)\n",
|
250 |
+
" return sen_embeddings"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 9,
|
256 |
+
"id": "47a714f3-8948-470b-9caf-93ed2bbf4894",
|
257 |
+
"metadata": {
|
258 |
+
"tags": []
|
259 |
+
},
|
260 |
+
"outputs": [],
|
261 |
+
"source": [
|
262 |
+
"vectorizer = TextVectorizer()"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": 10,
|
268 |
+
"id": "8b1586e5-2923-4632-a3db-fd2364124d6f",
|
269 |
+
"metadata": {
|
270 |
+
"tags": []
|
271 |
+
},
|
272 |
+
"outputs": [
|
273 |
+
{
|
274 |
+
"data": {
|
275 |
+
"text/plain": [
|
276 |
+
"320"
|
277 |
+
]
|
278 |
+
},
|
279 |
+
"execution_count": 10,
|
280 |
+
"metadata": {},
|
281 |
+
"output_type": "execute_result"
|
282 |
+
}
|
283 |
+
],
|
284 |
+
"source": [
|
285 |
+
"# getting max length of article descriptions to be used for VARCHAR while defining schema\n",
|
286 |
+
"max_desc_len = max([len(s) for s in data['short_description']])\n",
|
287 |
+
"max_desc_len"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"cell_type": "code",
|
292 |
+
"execution_count": 11,
|
293 |
+
"id": "debe0ef4-b877-495a-872e-47f720b758a9",
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"data": {
|
298 |
+
"text/plain": [
|
299 |
+
"14"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
"execution_count": 11,
|
303 |
+
"metadata": {},
|
304 |
+
"output_type": "execute_result"
|
305 |
+
}
|
306 |
+
],
|
307 |
+
"source": [
|
308 |
+
"# getting max length of article categories to be used for VARCHAR while defining schema\n",
|
309 |
+
"max_cat_len = max([len(s) for s in data['category']])\n",
|
310 |
+
"max_cat_len"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": 12,
|
316 |
+
"id": "80489f00-e59f-46ab-a933-97145928176c",
|
317 |
+
"metadata": {
|
318 |
+
"tags": []
|
319 |
+
},
|
320 |
+
"outputs": [],
|
321 |
+
"source": [
|
322 |
+
"# # Reading milvus URI & API token from secrets.env\n",
|
323 |
+
"load_dotenv('secrets.env')\n",
|
324 |
+
"uri = os.environ.get(\"URI\")\n",
|
325 |
+
"token = os.environ.get(\"TOKEN\")"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "code",
|
330 |
+
"execution_count": 13,
|
331 |
+
"id": "0bf69f22-e113-43a5-be81-77224cafd856",
|
332 |
+
"metadata": {
|
333 |
+
"tags": []
|
334 |
+
},
|
335 |
+
"outputs": [
|
336 |
+
{
|
337 |
+
"name": "stdout",
|
338 |
+
"output_type": "stream",
|
339 |
+
"text": [
|
340 |
+
"Connected to DB\n"
|
341 |
+
]
|
342 |
+
}
|
343 |
+
],
|
344 |
+
"source": [
|
345 |
+
"connections.connect(\"default\", uri=uri, token=token)\n",
|
346 |
+
"print(f\"Connected to DB\")"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"cell_type": "code",
|
351 |
+
"execution_count": 14,
|
352 |
+
"id": "8da06a3b-2005-4c02-a168-dc84bcde7064",
|
353 |
+
"metadata": {
|
354 |
+
"tags": []
|
355 |
+
},
|
356 |
+
"outputs": [],
|
357 |
+
"source": [
|
358 |
+
"collection_name = 'news_collection_sent_tran_finetuned'\n",
|
359 |
+
"check_collection = utility.has_collection(collection_name)"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 15,
|
365 |
+
"id": "33342612-1380-4d1a-a8e7-931476e07979",
|
366 |
+
"metadata": {
|
367 |
+
"tags": []
|
368 |
+
},
|
369 |
+
"outputs": [
|
370 |
+
{
|
371 |
+
"name": "stdout",
|
372 |
+
"output_type": "stream",
|
373 |
+
"text": [
|
374 |
+
"Droped Existing collection\n"
|
375 |
+
]
|
376 |
+
}
|
377 |
+
],
|
378 |
+
"source": [
|
379 |
+
"if check_collection:\n",
|
380 |
+
" drop_result = utility.drop_collection(collection_name)\n",
|
381 |
+
" print(\"Droped Existing collection\")"
|
382 |
+
]
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"cell_type": "code",
|
386 |
+
"execution_count": 16,
|
387 |
+
"id": "fc8ae048-d586-41e7-9678-75e1752c1693",
|
388 |
+
"metadata": {
|
389 |
+
"tags": []
|
390 |
+
},
|
391 |
+
"outputs": [
|
392 |
+
{
|
393 |
+
"name": "stdout",
|
394 |
+
"output_type": "stream",
|
395 |
+
"text": [
|
396 |
+
"Creating the collection\n",
|
397 |
+
"Schema: {'auto_id': False, 'description': 'collection of news articles', 'fields': [{'name': 'article_id', 'description': 'primary id', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'article_embed', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 768}}, {'name': 'article_desc', 'description': 'short description of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 370}}, {'name': 'article_category', 'description': 'category of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 64}}]}\n",
|
398 |
+
"Success!\n"
|
399 |
+
]
|
400 |
+
}
|
401 |
+
],
|
402 |
+
"source": [
|
403 |
+
"# Creating collection schema\n",
|
404 |
+
"dim = 768 # embeddings dim\n",
|
405 |
+
"article_id = FieldSchema(name=\"article_id\", dtype=DataType.INT64, is_primary=True, description=\"primary id\") # primary key\n",
|
406 |
+
"article_embed_field = FieldSchema(name=\"article_embed\", dtype=DataType.FLOAT_VECTOR, dim=dim) # description embeddings\n",
|
407 |
+
"article_desc = FieldSchema(name=\"article_desc\", dtype=DataType.VARCHAR, max_length=(max_desc_len + 50), # using max_desc_len to specify VARCHAR len \n",
|
408 |
+
" is_primary=False, description=\"short description of the article\") # short description of article\n",
|
409 |
+
"article_cat = FieldSchema(name=\"article_category\", dtype=DataType.VARCHAR, max_length=(max_cat_len + 50), # using max_desc_len to specify VARCHAR len \n",
|
410 |
+
" is_primary=False, description=\"category of the article\") # category of article\n",
|
411 |
+
"schema = CollectionSchema(fields=[article_id, article_embed_field, article_desc, article_cat], \n",
|
412 |
+
" auto_id=False, description=\"collection of news articles\")\n",
|
413 |
+
"print(f\"Creating the collection\")\n",
|
414 |
+
"collection = Collection(name=collection_name, schema=schema)\n",
|
415 |
+
"print(f\"Schema: {schema}\")\n",
|
416 |
+
"print(\"Success!\")"
|
417 |
+
]
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"cell_type": "code",
|
421 |
+
"execution_count": 17,
|
422 |
+
"id": "cca82380-98f6-4c44-aac6-86d4ae3484d0",
|
423 |
+
"metadata": {},
|
424 |
+
"outputs": [
|
425 |
+
{
|
426 |
+
"name": "stdout",
|
427 |
+
"output_type": "stream",
|
428 |
+
"text": [
|
429 |
+
"[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000, 31000, 32000, 33000, 34000, 35000, 36000, 37000, 38000, 39000, 40000, 41000, 42000, 43000, 44000, 45000, 46000, 47000, 48000, 49000, 50000, 51000, 52000, 53000, 54000, 55000, 56000, 57000, 58000, 59000, 60000, 61000, 62000, 63000, 64000, 65000, 66000, 67000, 68000, 69000, 70000, 71000, 72000, 73000, 74000, 75000, 76000, 77000, 78000, 79000, 80000, 81000, 82000, 83000, 84000, 85000, 86000, 87000, 88000, 89000, 90000, 91000, 92000, 93000, 94000, 95000, 96000, 97000, 98000, 99000, 100000, 101000, 102000, 103000, 104000, 105000, 106000, 107000, 108000, 109000, 110000, 111000, 112000, 113000, 114000, 115000, 116000, 117000, 118000, 119000, 120000, 121000, 122000, 123000, 124000, 125000, 126000, 127000, 128000, 129000, 130000, 131000, 132000, 133000, 134000, 135000, 136000, 137000, 138000, 139000, 140000, 141000, 142000, 143000, 144000, 145000, 146000, 147000, 148000, 149000, 150000, 151000, 152000, 153000, 154000, 155000, 156000, 157000, 158000, 159000, 160000, 161000, 162000, 163000, 164000, 165000, 166000, 167000, 168000, 169000, 170000, 171000, 172000, 173000, 174000, 175000, 176000, 177000, 178000, 179000, 180000, 181000, 182000, 183000, 184000, 185000, 186000, 187000, 188000, 189000, 190000, 191000, 192000, 193000, 194000, 195000, 196000, 197000, 198000, 199000, 200000, 201000, 202000, 203000, 204000, 205000, 206000, 207000, 208000, 209000, 210000, 211000, 212000, 213000, 214000, 215000, 216000, 217000, 218000, 219000, 220000, 221000, 222000, 223000, 224000, 225000, 226000, 227000, 228000, 229000, 230000, 231000, 232000, 233000, 234000, 235000, 236000, 237000, 238000, 239000, 240000, 241000, 242000, 243000, 244000, 245000, 246000, 247000, 248000, 249000, 250000, 251000, 252000, 253000, 254000, 255000, 256000, 257000, 258000, 259000, 260000, 261000, 262000, 263000, 264000, 265000, 266000, 267000, 268000, 269000, 270000, 271000, 272000, 273000, 274000, 275000, 276000, 277000, 278000, 279000, 280000, 281000, 282000, 283000, 284000, 285000, 286000, 287000, 288000, 289000, 290000, 291000, 292000, 293000, 294000, 295000, 296000, 297000, 298000, 299000, 300000, 301000, 302000, 303000, 304000, 305000, 306000, 307000, 308000, 309000, 310000, 311000, 311176]\n"
|
430 |
+
]
|
431 |
+
}
|
432 |
+
],
|
433 |
+
"source": [
|
434 |
+
"cuts = [*range(0, len(data), 1000)]\n",
|
435 |
+
"cuts.append(len(data))\n",
|
436 |
+
"print(cuts)"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"cell_type": "code",
|
441 |
+
"execution_count": null,
|
442 |
+
"id": "066c67ac-01a6-4151-8e85-5869ddce1c0a",
|
443 |
+
"metadata": {},
|
444 |
+
"outputs": [],
|
445 |
+
"source": [
|
446 |
+
"article_id = []\n",
|
447 |
+
"article_desc = []\n",
|
448 |
+
"article_embed = []\n",
|
449 |
+
"article_cat = []\n",
|
450 |
+
"try:\n",
|
451 |
+
" for i in tqdm(range(len(cuts)-1)):\n",
|
452 |
+
" df = data.iloc[cuts[i]: cuts[i+1]].copy()\n",
|
453 |
+
" article_id = [*df['index']]\n",
|
454 |
+
" article_desc = [*df['short_description']]\n",
|
455 |
+
" article_cat = [*df['category']]\n",
|
456 |
+
" results = []\n",
|
457 |
+
" article_embed = vectorizer.vectorize(article_desc)\n",
|
458 |
+
" docs = [article_id, article_embed, article_desc, article_cat]\n",
|
459 |
+
" ins_resp = collection.insert(docs)\n",
|
460 |
+
" print(ins_resp)\n",
|
461 |
+
" article_id = []\n",
|
462 |
+
" article_desc = []\n",
|
463 |
+
" article_embed = []\n",
|
464 |
+
" article_cat = []\n",
|
465 |
+
" if i == 0:\n",
|
466 |
+
" index_params = {\"index_type\": \"AUTOINDEX\", \"metric_type\": \"L2\", \"params\": {}} \n",
|
467 |
+
" collection.create_index(field_name='article_embed', index_params=index_params)\n",
|
468 |
+
" collection = Collection(name=collection_name)\n",
|
469 |
+
" collection.load()\n",
|
470 |
+
"except:\n",
|
471 |
+
" raise"
|
472 |
+
]
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"cell_type": "code",
|
476 |
+
"execution_count": null,
|
477 |
+
"id": "d50177fa-fd0c-48ad-bc9b-a7bdc826a628",
|
478 |
+
"metadata": {},
|
479 |
+
"outputs": [],
|
480 |
+
"source": []
|
481 |
+
}
|
482 |
+
],
|
483 |
+
"metadata": {
|
484 |
+
"kernelspec": {
|
485 |
+
"display_name": "Python (tf_gpu)",
|
486 |
+
"language": "python",
|
487 |
+
"name": "tf_gpu"
|
488 |
+
},
|
489 |
+
"language_info": {
|
490 |
+
"codemirror_mode": {
|
491 |
+
"name": "ipython",
|
492 |
+
"version": 3
|
493 |
+
},
|
494 |
+
"file_extension": ".py",
|
495 |
+
"mimetype": "text/x-python",
|
496 |
+
"name": "python",
|
497 |
+
"nbconvert_exporter": "python",
|
498 |
+
"pygments_lexer": "ipython3",
|
499 |
+
"version": "3.9.18"
|
500 |
+
}
|
501 |
+
},
|
502 |
+
"nbformat": 4,
|
503 |
+
"nbformat_minor": 5
|
504 |
+
}
|
tsdae_finetune_sent_transformer.ipynb
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "f0dc396f",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"### TSDAE: Fine-tune sentence transformers using unsupervised learning with Pytorch\n",
|
9 |
+
"https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 2,
|
15 |
+
"id": "34329058",
|
16 |
+
"metadata": {
|
17 |
+
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
|
18 |
+
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
|
19 |
+
"execution": {
|
20 |
+
"iopub.execute_input": "2024-01-13T16:51:37.887254Z",
|
21 |
+
"iopub.status.busy": "2024-01-13T16:51:37.886390Z",
|
22 |
+
"iopub.status.idle": "2024-01-13T16:51:37.891827Z",
|
23 |
+
"shell.execute_reply": "2024-01-13T16:51:37.890706Z",
|
24 |
+
"shell.execute_reply.started": "2024-01-13T16:51:37.887212Z"
|
25 |
+
}
|
26 |
+
},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"# !pip install sentence_transformers==2.2.2"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 4,
|
35 |
+
"id": "ebd138d8",
|
36 |
+
"metadata": {
|
37 |
+
"execution": {
|
38 |
+
"iopub.execute_input": "2024-01-13T16:51:50.310221Z",
|
39 |
+
"iopub.status.busy": "2024-01-13T16:51:50.309586Z",
|
40 |
+
"iopub.status.idle": "2024-01-13T16:51:50.315850Z",
|
41 |
+
"shell.execute_reply": "2024-01-13T16:51:50.314927Z",
|
42 |
+
"shell.execute_reply.started": "2024-01-13T16:51:50.310185Z"
|
43 |
+
}
|
44 |
+
},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import pandas as pd\n",
|
48 |
+
"import numpy as np\n",
|
49 |
+
"import string\n",
|
50 |
+
"from tqdm import tqdm\n",
|
51 |
+
"from numpy.linalg import norm\n",
|
52 |
+
"from sentence_transformers import SentenceTransformer, LoggingHandler\n",
|
53 |
+
"from sentence_transformers import models, util, datasets, evaluation, losses\n",
|
54 |
+
"from torch.utils.data import DataLoader"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": 6,
|
60 |
+
"id": "7ef2d063",
|
61 |
+
"metadata": {
|
62 |
+
"execution": {
|
63 |
+
"iopub.execute_input": "2024-01-13T13:24:44.770211Z",
|
64 |
+
"iopub.status.busy": "2024-01-13T13:24:44.769806Z",
|
65 |
+
"iopub.status.idle": "2024-01-13T13:24:44.775042Z",
|
66 |
+
"shell.execute_reply": "2024-01-13T13:24:44.773860Z",
|
67 |
+
"shell.execute_reply.started": "2024-01-13T13:24:44.770177Z"
|
68 |
+
}
|
69 |
+
},
|
70 |
+
"outputs": [],
|
71 |
+
"source": [
|
72 |
+
"# import nltk\n",
|
73 |
+
"# nltk.download('punkt')"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 5,
|
79 |
+
"id": "453c1add",
|
80 |
+
"metadata": {
|
81 |
+
"execution": {
|
82 |
+
"iopub.execute_input": "2024-01-13T16:51:54.070689Z",
|
83 |
+
"iopub.status.busy": "2024-01-13T16:51:54.069945Z",
|
84 |
+
"iopub.status.idle": "2024-01-13T16:51:54.809726Z",
|
85 |
+
"shell.execute_reply": "2024-01-13T16:51:54.808920Z",
|
86 |
+
"shell.execute_reply.started": "2024-01-13T16:51:54.070657Z"
|
87 |
+
}
|
88 |
+
},
|
89 |
+
"outputs": [],
|
90 |
+
"source": [
|
91 |
+
"data = pd.read_csv('news_processed.csv', usecols=['short_description'])"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": 6,
|
97 |
+
"id": "61629b79",
|
98 |
+
"metadata": {
|
99 |
+
"execution": {
|
100 |
+
"iopub.execute_input": "2024-01-13T16:51:55.125758Z",
|
101 |
+
"iopub.status.busy": "2024-01-13T16:51:55.124990Z",
|
102 |
+
"iopub.status.idle": "2024-01-13T16:51:55.180470Z",
|
103 |
+
"shell.execute_reply": "2024-01-13T16:51:55.179559Z",
|
104 |
+
"shell.execute_reply.started": "2024-01-13T16:51:55.125716Z"
|
105 |
+
}
|
106 |
+
},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"data.dropna(inplace=True)"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": 7,
|
115 |
+
"id": "14162e2c",
|
116 |
+
"metadata": {
|
117 |
+
"execution": {
|
118 |
+
"iopub.execute_input": "2024-01-13T16:51:55.817535Z",
|
119 |
+
"iopub.status.busy": "2024-01-13T16:51:55.816764Z",
|
120 |
+
"iopub.status.idle": "2024-01-13T16:51:55.836578Z",
|
121 |
+
"shell.execute_reply": "2024-01-13T16:51:55.835697Z",
|
122 |
+
"shell.execute_reply.started": "2024-01-13T16:51:55.817499Z"
|
123 |
+
}
|
124 |
+
},
|
125 |
+
"outputs": [
|
126 |
+
{
|
127 |
+
"data": {
|
128 |
+
"text/plain": [
|
129 |
+
"'Experimental coronavirus vaccine prevents severe disease in mice'"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
"execution_count": 7,
|
133 |
+
"metadata": {},
|
134 |
+
"output_type": "execute_result"
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"source": [
|
138 |
+
"data['short_description'][1000]"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "code",
|
143 |
+
"execution_count": 14,
|
144 |
+
"id": "8fd1a801",
|
145 |
+
"metadata": {
|
146 |
+
"execution": {
|
147 |
+
"iopub.execute_input": "2024-01-13T17:08:09.906826Z",
|
148 |
+
"iopub.status.busy": "2024-01-13T17:08:09.906182Z",
|
149 |
+
"iopub.status.idle": "2024-01-13T17:08:09.914834Z",
|
150 |
+
"shell.execute_reply": "2024-01-13T17:08:09.913737Z",
|
151 |
+
"shell.execute_reply.started": "2024-01-13T17:08:09.906795Z"
|
152 |
+
}
|
153 |
+
},
|
154 |
+
"outputs": [],
|
155 |
+
"source": [
|
156 |
+
"def finetune_model(data: pd.DataFrame, col_to_use: str='short_description', \n",
|
157 |
+
" model_id: str=\"sentence-transformers/multi-qa-distilbert-cos-v1\", \n",
|
158 |
+
" batch_size: int=8, epochs: int=2):\n",
|
159 |
+
" \n",
|
160 |
+
"# https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html\n",
|
161 |
+
" \n",
|
162 |
+
" word_embedding_model = models.Transformer(model_id)\n",
|
163 |
+
" pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')\n",
|
164 |
+
" model = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n",
|
165 |
+
" \n",
|
166 |
+
" train_examples = data[col_to_use].tolist()\n",
|
167 |
+
" train_dataset = datasets.DenoisingAutoEncoderDataset(train_examples)\n",
|
168 |
+
" train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
169 |
+
" train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=\"sentence-transformers/paraphrase-distilroberta-base-v2\", tie_encoder_decoder=False)\n",
|
170 |
+
"# train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_id, tie_encoder_decoder=True)\n",
|
171 |
+
"\n",
|
172 |
+
" \n",
|
173 |
+
" model.fit(\n",
|
174 |
+
" train_objectives=[(train_dataloader, train_loss)],\n",
|
175 |
+
" epochs=epochs,\n",
|
176 |
+
" weight_decay=0,\n",
|
177 |
+
" scheduler='constantlr',\n",
|
178 |
+
" optimizer_params={'lr': 3e-5},\n",
|
179 |
+
" show_progress_bar=True\n",
|
180 |
+
" )\n",
|
181 |
+
" model_save_path = model_id + '_finetuned'\n",
|
182 |
+
" model.save(model_save_path)\n",
|
183 |
+
" return model_save_path"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"id": "b17a9779",
|
190 |
+
"metadata": {
|
191 |
+
"scrolled": true
|
192 |
+
},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"# fine-tune sentence transformer\n",
|
196 |
+
"finetuned_model_id = finetune_model(data=data)\n",
|
197 |
+
"finetuned_model = SentenceTransformer(finetuned_model_id)"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": null,
|
203 |
+
"id": "4d87c128",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [],
|
206 |
+
"source": []
|
207 |
+
}
|
208 |
+
],
|
209 |
+
"metadata": {
|
210 |
+
"kaggle": {
|
211 |
+
"accelerator": "gpu",
|
212 |
+
"dataSources": [
|
213 |
+
{
|
214 |
+
"datasetId": 4298708,
|
215 |
+
"sourceId": 7394110,
|
216 |
+
"sourceType": "datasetVersion"
|
217 |
+
}
|
218 |
+
],
|
219 |
+
"dockerImageVersionId": 30636,
|
220 |
+
"isGpuEnabled": true,
|
221 |
+
"isInternetEnabled": true,
|
222 |
+
"language": "python",
|
223 |
+
"sourceType": "notebook"
|
224 |
+
},
|
225 |
+
"kernelspec": {
|
226 |
+
"display_name": "Python 3 (ipykernel)",
|
227 |
+
"language": "python",
|
228 |
+
"name": "python3"
|
229 |
+
},
|
230 |
+
"language_info": {
|
231 |
+
"codemirror_mode": {
|
232 |
+
"name": "ipython",
|
233 |
+
"version": 3
|
234 |
+
},
|
235 |
+
"file_extension": ".py",
|
236 |
+
"mimetype": "text/x-python",
|
237 |
+
"name": "python",
|
238 |
+
"nbconvert_exporter": "python",
|
239 |
+
"pygments_lexer": "ipython3",
|
240 |
+
"version": "3.10.12"
|
241 |
+
}
|
242 |
+
},
|
243 |
+
"nbformat": 4,
|
244 |
+
"nbformat_minor": 5
|
245 |
+
}
|