Jingyuan-Zhu committed on
Commit
4f23a2b
1 Parent(s): 7f56829

Upload XGBoost_embedding.ipynb

Browse files
Files changed (1) hide show
  1. XGBoost_embedding.ipynb +593 -0
XGBoost_embedding.ipynb ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ },
15
+ "widgets": {
16
+ "application/vnd.jupyter.widget-state+json": {
17
+ "d564a252f0994e50a82e49477311aaaf": {
18
+ "model_module": "@jupyter-widgets/controls",
19
+ "model_name": "HBoxModel",
20
+ "model_module_version": "1.5.0",
21
+ "state": {
22
+ "_dom_classes": [],
23
+ "_model_module": "@jupyter-widgets/controls",
24
+ "_model_module_version": "1.5.0",
25
+ "_model_name": "HBoxModel",
26
+ "_view_count": null,
27
+ "_view_module": "@jupyter-widgets/controls",
28
+ "_view_module_version": "1.5.0",
29
+ "_view_name": "HBoxView",
30
+ "box_style": "",
31
+ "children": [
32
+ "IPY_MODEL_e1281f85ea8e48ac82e6cdf68530b700",
33
+ "IPY_MODEL_e95c7910fde4463190dbace7fc4bbd43",
34
+ "IPY_MODEL_fcdbe6f02f30423687d2262a1c260609"
35
+ ],
36
+ "layout": "IPY_MODEL_643fa99220814aaea83ef3cf9001971c"
37
+ }
38
+ },
39
+ "e1281f85ea8e48ac82e6cdf68530b700": {
40
+ "model_module": "@jupyter-widgets/controls",
41
+ "model_name": "HTMLModel",
42
+ "model_module_version": "1.5.0",
43
+ "state": {
44
+ "_dom_classes": [],
45
+ "_model_module": "@jupyter-widgets/controls",
46
+ "_model_module_version": "1.5.0",
47
+ "_model_name": "HTMLModel",
48
+ "_view_count": null,
49
+ "_view_module": "@jupyter-widgets/controls",
50
+ "_view_module_version": "1.5.0",
51
+ "_view_name": "HTMLView",
52
+ "description": "",
53
+ "description_tooltip": null,
54
+ "layout": "IPY_MODEL_9618ee7122a44226b57b836136f94df5",
55
+ "placeholder": "​",
56
+ "style": "IPY_MODEL_5bd6ab3d866540caa45b9e82fcc7f081",
57
+ "value": "Fetching 30 files: 100%"
58
+ }
59
+ },
60
+ "e95c7910fde4463190dbace7fc4bbd43": {
61
+ "model_module": "@jupyter-widgets/controls",
62
+ "model_name": "FloatProgressModel",
63
+ "model_module_version": "1.5.0",
64
+ "state": {
65
+ "_dom_classes": [],
66
+ "_model_module": "@jupyter-widgets/controls",
67
+ "_model_module_version": "1.5.0",
68
+ "_model_name": "FloatProgressModel",
69
+ "_view_count": null,
70
+ "_view_module": "@jupyter-widgets/controls",
71
+ "_view_module_version": "1.5.0",
72
+ "_view_name": "ProgressView",
73
+ "bar_style": "success",
74
+ "description": "",
75
+ "description_tooltip": null,
76
+ "layout": "IPY_MODEL_87b0deadc3d54484b21b7aab1ed67379",
77
+ "max": 30,
78
+ "min": 0,
79
+ "orientation": "horizontal",
80
+ "style": "IPY_MODEL_6c0747f6032f4788a4c476aa43246cd3",
81
+ "value": 30
82
+ }
83
+ },
84
+ "fcdbe6f02f30423687d2262a1c260609": {
85
+ "model_module": "@jupyter-widgets/controls",
86
+ "model_name": "HTMLModel",
87
+ "model_module_version": "1.5.0",
88
+ "state": {
89
+ "_dom_classes": [],
90
+ "_model_module": "@jupyter-widgets/controls",
91
+ "_model_module_version": "1.5.0",
92
+ "_model_name": "HTMLModel",
93
+ "_view_count": null,
94
+ "_view_module": "@jupyter-widgets/controls",
95
+ "_view_module_version": "1.5.0",
96
+ "_view_name": "HTMLView",
97
+ "description": "",
98
+ "description_tooltip": null,
99
+ "layout": "IPY_MODEL_5001265687cd4a2da0be38685f18d003",
100
+ "placeholder": "​",
101
+ "style": "IPY_MODEL_77a09a218d2a421e9b3fe33f8310a37e",
102
+ "value": " 30/30 [00:00<00:00, 862.80it/s]"
103
+ }
104
+ },
105
+ "643fa99220814aaea83ef3cf9001971c": {
106
+ "model_module": "@jupyter-widgets/base",
107
+ "model_name": "LayoutModel",
108
+ "model_module_version": "1.2.0",
109
+ "state": {
110
+ "_model_module": "@jupyter-widgets/base",
111
+ "_model_module_version": "1.2.0",
112
+ "_model_name": "LayoutModel",
113
+ "_view_count": null,
114
+ "_view_module": "@jupyter-widgets/base",
115
+ "_view_module_version": "1.2.0",
116
+ "_view_name": "LayoutView",
117
+ "align_content": null,
118
+ "align_items": null,
119
+ "align_self": null,
120
+ "border": null,
121
+ "bottom": null,
122
+ "display": null,
123
+ "flex": null,
124
+ "flex_flow": null,
125
+ "grid_area": null,
126
+ "grid_auto_columns": null,
127
+ "grid_auto_flow": null,
128
+ "grid_auto_rows": null,
129
+ "grid_column": null,
130
+ "grid_gap": null,
131
+ "grid_row": null,
132
+ "grid_template_areas": null,
133
+ "grid_template_columns": null,
134
+ "grid_template_rows": null,
135
+ "height": null,
136
+ "justify_content": null,
137
+ "justify_items": null,
138
+ "left": null,
139
+ "margin": null,
140
+ "max_height": null,
141
+ "max_width": null,
142
+ "min_height": null,
143
+ "min_width": null,
144
+ "object_fit": null,
145
+ "object_position": null,
146
+ "order": null,
147
+ "overflow": null,
148
+ "overflow_x": null,
149
+ "overflow_y": null,
150
+ "padding": null,
151
+ "right": null,
152
+ "top": null,
153
+ "visibility": null,
154
+ "width": null
155
+ }
156
+ },
157
+ "9618ee7122a44226b57b836136f94df5": {
158
+ "model_module": "@jupyter-widgets/base",
159
+ "model_name": "LayoutModel",
160
+ "model_module_version": "1.2.0",
161
+ "state": {
162
+ "_model_module": "@jupyter-widgets/base",
163
+ "_model_module_version": "1.2.0",
164
+ "_model_name": "LayoutModel",
165
+ "_view_count": null,
166
+ "_view_module": "@jupyter-widgets/base",
167
+ "_view_module_version": "1.2.0",
168
+ "_view_name": "LayoutView",
169
+ "align_content": null,
170
+ "align_items": null,
171
+ "align_self": null,
172
+ "border": null,
173
+ "bottom": null,
174
+ "display": null,
175
+ "flex": null,
176
+ "flex_flow": null,
177
+ "grid_area": null,
178
+ "grid_auto_columns": null,
179
+ "grid_auto_flow": null,
180
+ "grid_auto_rows": null,
181
+ "grid_column": null,
182
+ "grid_gap": null,
183
+ "grid_row": null,
184
+ "grid_template_areas": null,
185
+ "grid_template_columns": null,
186
+ "grid_template_rows": null,
187
+ "height": null,
188
+ "justify_content": null,
189
+ "justify_items": null,
190
+ "left": null,
191
+ "margin": null,
192
+ "max_height": null,
193
+ "max_width": null,
194
+ "min_height": null,
195
+ "min_width": null,
196
+ "object_fit": null,
197
+ "object_position": null,
198
+ "order": null,
199
+ "overflow": null,
200
+ "overflow_x": null,
201
+ "overflow_y": null,
202
+ "padding": null,
203
+ "right": null,
204
+ "top": null,
205
+ "visibility": null,
206
+ "width": null
207
+ }
208
+ },
209
+ "5bd6ab3d866540caa45b9e82fcc7f081": {
210
+ "model_module": "@jupyter-widgets/controls",
211
+ "model_name": "DescriptionStyleModel",
212
+ "model_module_version": "1.5.0",
213
+ "state": {
214
+ "_model_module": "@jupyter-widgets/controls",
215
+ "_model_module_version": "1.5.0",
216
+ "_model_name": "DescriptionStyleModel",
217
+ "_view_count": null,
218
+ "_view_module": "@jupyter-widgets/base",
219
+ "_view_module_version": "1.2.0",
220
+ "_view_name": "StyleView",
221
+ "description_width": ""
222
+ }
223
+ },
224
+ "87b0deadc3d54484b21b7aab1ed67379": {
225
+ "model_module": "@jupyter-widgets/base",
226
+ "model_name": "LayoutModel",
227
+ "model_module_version": "1.2.0",
228
+ "state": {
229
+ "_model_module": "@jupyter-widgets/base",
230
+ "_model_module_version": "1.2.0",
231
+ "_model_name": "LayoutModel",
232
+ "_view_count": null,
233
+ "_view_module": "@jupyter-widgets/base",
234
+ "_view_module_version": "1.2.0",
235
+ "_view_name": "LayoutView",
236
+ "align_content": null,
237
+ "align_items": null,
238
+ "align_self": null,
239
+ "border": null,
240
+ "bottom": null,
241
+ "display": null,
242
+ "flex": null,
243
+ "flex_flow": null,
244
+ "grid_area": null,
245
+ "grid_auto_columns": null,
246
+ "grid_auto_flow": null,
247
+ "grid_auto_rows": null,
248
+ "grid_column": null,
249
+ "grid_gap": null,
250
+ "grid_row": null,
251
+ "grid_template_areas": null,
252
+ "grid_template_columns": null,
253
+ "grid_template_rows": null,
254
+ "height": null,
255
+ "justify_content": null,
256
+ "justify_items": null,
257
+ "left": null,
258
+ "margin": null,
259
+ "max_height": null,
260
+ "max_width": null,
261
+ "min_height": null,
262
+ "min_width": null,
263
+ "object_fit": null,
264
+ "object_position": null,
265
+ "order": null,
266
+ "overflow": null,
267
+ "overflow_x": null,
268
+ "overflow_y": null,
269
+ "padding": null,
270
+ "right": null,
271
+ "top": null,
272
+ "visibility": null,
273
+ "width": null
274
+ }
275
+ },
276
+ "6c0747f6032f4788a4c476aa43246cd3": {
277
+ "model_module": "@jupyter-widgets/controls",
278
+ "model_name": "ProgressStyleModel",
279
+ "model_module_version": "1.5.0",
280
+ "state": {
281
+ "_model_module": "@jupyter-widgets/controls",
282
+ "_model_module_version": "1.5.0",
283
+ "_model_name": "ProgressStyleModel",
284
+ "_view_count": null,
285
+ "_view_module": "@jupyter-widgets/base",
286
+ "_view_module_version": "1.2.0",
287
+ "_view_name": "StyleView",
288
+ "bar_color": null,
289
+ "description_width": ""
290
+ }
291
+ },
292
+ "5001265687cd4a2da0be38685f18d003": {
293
+ "model_module": "@jupyter-widgets/base",
294
+ "model_name": "LayoutModel",
295
+ "model_module_version": "1.2.0",
296
+ "state": {
297
+ "_model_module": "@jupyter-widgets/base",
298
+ "_model_module_version": "1.2.0",
299
+ "_model_name": "LayoutModel",
300
+ "_view_count": null,
301
+ "_view_module": "@jupyter-widgets/base",
302
+ "_view_module_version": "1.2.0",
303
+ "_view_name": "LayoutView",
304
+ "align_content": null,
305
+ "align_items": null,
306
+ "align_self": null,
307
+ "border": null,
308
+ "bottom": null,
309
+ "display": null,
310
+ "flex": null,
311
+ "flex_flow": null,
312
+ "grid_area": null,
313
+ "grid_auto_columns": null,
314
+ "grid_auto_flow": null,
315
+ "grid_auto_rows": null,
316
+ "grid_column": null,
317
+ "grid_gap": null,
318
+ "grid_row": null,
319
+ "grid_template_areas": null,
320
+ "grid_template_columns": null,
321
+ "grid_template_rows": null,
322
+ "height": null,
323
+ "justify_content": null,
324
+ "justify_items": null,
325
+ "left": null,
326
+ "margin": null,
327
+ "max_height": null,
328
+ "max_width": null,
329
+ "min_height": null,
330
+ "min_width": null,
331
+ "object_fit": null,
332
+ "object_position": null,
333
+ "order": null,
334
+ "overflow": null,
335
+ "overflow_x": null,
336
+ "overflow_y": null,
337
+ "padding": null,
338
+ "right": null,
339
+ "top": null,
340
+ "visibility": null,
341
+ "width": null
342
+ }
343
+ },
344
+ "77a09a218d2a421e9b3fe33f8310a37e": {
345
+ "model_module": "@jupyter-widgets/controls",
346
+ "model_name": "DescriptionStyleModel",
347
+ "model_module_version": "1.5.0",
348
+ "state": {
349
+ "_model_module": "@jupyter-widgets/controls",
350
+ "_model_module_version": "1.5.0",
351
+ "_model_name": "DescriptionStyleModel",
352
+ "_view_count": null,
353
+ "_view_module": "@jupyter-widgets/base",
354
+ "_view_module_version": "1.2.0",
355
+ "_view_name": "StyleView",
356
+ "description_width": ""
357
+ }
358
+ }
359
+ }
360
+ }
361
+ },
362
+ "cells": [
363
+ {
364
+ "cell_type": "markdown",
365
+ "source": [
366
+ "Evaluation Pipeline for XGBoost with Embedding Features\n"
367
+ ],
368
+ "metadata": {
369
+ "id": "Q1KbUyBa9AOW"
370
+ }
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "source": [
375
+ "Install FlagEmbedding and load the BGE-M3 embedding model, then set the data and model paths for your environment"
376
+ ],
377
+ "metadata": {
378
+ "id": "zSikCRSxAh8a"
379
+ }
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "!pip install -U FlagEmbedding\n",
385
+ "from FlagEmbedding import BGEM3FlagModel\n",
386
+ "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)"
387
+ ],
388
+ "metadata": {
389
+ "id": "mKfYf2gyOAYS",
390
+ "colab": {
391
+ "base_uri": "https://localhost:8080/",
392
+ "height": 1000,
393
+ "referenced_widgets": [
394
+ "d564a252f0994e50a82e49477311aaaf",
395
+ "e1281f85ea8e48ac82e6cdf68530b700",
396
+ "e95c7910fde4463190dbace7fc4bbd43",
397
+ "fcdbe6f02f30423687d2262a1c260609",
398
+ "643fa99220814aaea83ef3cf9001971c",
399
+ "9618ee7122a44226b57b836136f94df5",
400
+ "5bd6ab3d866540caa45b9e82fcc7f081",
401
+ "87b0deadc3d54484b21b7aab1ed67379",
402
+ "6c0747f6032f4788a4c476aa43246cd3",
403
+ "5001265687cd4a2da0be38685f18d003",
404
+ "77a09a218d2a421e9b3fe33f8310a37e"
405
+ ]
406
+ },
407
+ "outputId": "a4863acc-785a-44cb-d126-0cc9867ead3e"
408
+ },
409
+ "execution_count": 6,
410
+ "outputs": [
411
+ {
412
+ "output_type": "stream",
413
+ "name": "stdout",
414
+ "text": [
415
+ "Requirement already satisfied: FlagEmbedding in /usr/local/lib/python3.10/dist-packages (1.3.3)\n",
416
+ "Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (2.5.1+cu121)\n",
417
+ "Requirement already satisfied: transformers==4.44.2 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (4.44.2)\n",
418
+ "Requirement already satisfied: datasets==2.19.0 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (2.19.0)\n",
419
+ "Requirement already satisfied: accelerate>=0.20.1 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (1.1.1)\n",
420
+ "Requirement already satisfied: sentence-transformers in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (3.2.1)\n",
421
+ "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.13.2)\n",
422
+ "Requirement already satisfied: ir-datasets in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.5.9)\n",
423
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.2.0)\n",
424
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (4.25.5)\n",
425
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.16.1)\n",
426
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (1.26.4)\n",
427
+ "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (17.0.0)\n",
428
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.6)\n",
429
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.3.8)\n",
430
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (2.2.2)\n",
431
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (2.32.3)\n",
432
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (4.66.6)\n",
433
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.5.0)\n",
434
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.70.16)\n",
435
+ "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0->FlagEmbedding) (2024.3.1)\n",
436
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.11.9)\n",
437
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.26.3)\n",
438
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (24.2)\n",
439
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (6.0.2)\n",
440
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (2024.9.11)\n",
441
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (0.4.5)\n",
442
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (0.19.1)\n",
443
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.1->FlagEmbedding) (5.9.5)\n",
444
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (4.12.2)\n",
445
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (3.4.2)\n",
446
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (3.1.4)\n",
447
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (1.13.1)\n",
448
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.6.0->FlagEmbedding) (1.3.0)\n",
449
+ "Requirement already satisfied: beautifulsoup4>=4.4.1 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (4.12.3)\n",
450
+ "Requirement already satisfied: inscriptis>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (2.5.0)\n",
451
+ "Requirement already satisfied: lxml>=4.5.2 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (5.3.0)\n",
452
+ "Requirement already satisfied: trec-car-tools>=2.5.4 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (2.6)\n",
453
+ "Requirement already satisfied: lz4>=3.1.10 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (4.3.3)\n",
454
+ "Requirement already satisfied: warc3-wet>=0.2.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
455
+ "Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
456
+ "Requirement already satisfied: zlib-state>=0.1.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.1.9)\n",
457
+ "Requirement already satisfied: ijson>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (3.3.0)\n",
458
+ "Requirement already satisfied: unlzw3>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.2)\n",
459
+ "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (1.5.2)\n",
460
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (1.13.1)\n",
461
+ "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (11.0.0)\n",
462
+ "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4>=4.4.1->ir-datasets->FlagEmbedding) (2.6)\n",
463
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (2.4.4)\n",
464
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.3.1)\n",
465
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (4.0.3)\n",
466
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (24.2.0)\n",
467
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.5.0)\n",
468
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (6.1.0)\n",
469
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (0.2.1)\n",
470
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.18.3)\n",
471
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.4.0)\n",
472
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.10)\n",
473
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2.2.3)\n",
474
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2024.8.30)\n",
475
+ "Requirement already satisfied: cbor>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from trec-car-tools>=2.5.4->ir-datasets->FlagEmbedding) (1.0.0)\n",
476
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->FlagEmbedding) (3.0.2)\n",
477
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2.8.2)\n",
478
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.2)\n",
479
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.2)\n",
480
+ "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (1.4.2)\n",
481
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (3.5.0)\n",
482
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.0->FlagEmbedding) (1.16.0)\n"
483
+ ]
484
+ },
485
+ {
486
+ "output_type": "display_data",
487
+ "data": {
488
+ "text/plain": [
489
+ "Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]"
490
+ ],
491
+ "application/vnd.jupyter.widget-view+json": {
492
+ "version_major": 2,
493
+ "version_minor": 0,
494
+ "model_id": "d564a252f0994e50a82e49477311aaaf"
495
+ }
496
+ },
497
+ "metadata": {}
498
+ }
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "source": [
504
+ "import pandas as pd\n",
505
+ "import xgboost as xgb\n",
506
+ "from sklearn.metrics import accuracy_score, classification_report\n",
507
+ "\n",
508
+ "csv_file_path = '/content/drive/Shared drives/5190_NLP_Project/test_data_random_subset.csv'\n",
509
+ "model_path = '/content/drive/Shared drives/5190_NLP_Project/xgboost_model/final_xgboost_model.json'\n",
510
+ "\n",
511
+ "data = pd.read_csv(csv_file_path, index_col=0)\n",
512
+ "\n",
513
+ "titles = data['title'].tolist()\n",
514
+ "labels = data['labels'].tolist()\n",
515
+ "\n",
516
+ "labels = [1 if label == 0 else 0 for label in labels]\n",
517
+ "\n",
518
+ "batch_size = 32\n",
519
+ "embeddings = []\n",
520
+ "\n",
521
+ "print('Encoding titles...')\n",
522
+ "for i in range(0, len(titles), batch_size):\n",
523
+ " batch = titles[i:i + batch_size]\n",
524
+ " batch_embeddings = model.encode(batch, batch_size=batch_size, max_length=512)['dense_vecs']\n",
525
+ " embeddings.extend(batch_embeddings)\n",
526
+ " print(f\"Processed {i + len(batch)}/{len(titles)} titles\")\n",
527
+ "\n",
528
+ "embeddings_df = pd.DataFrame(embeddings)\n",
529
+ "embeddings_df['label'] = labels\n",
530
+ "\n",
531
+ "X = embeddings_df.iloc[:, :-1].values\n",
532
+ "y = embeddings_df['label'].values\n",
533
+ "\n",
534
+ "\n",
535
+ "print('Loading XGBoost model...')\n",
536
+ "xgboost_model = xgb.XGBClassifier()\n",
537
+ "xgboost_model.load_model(model_path)\n",
538
+ "\n",
539
+ "print('Making predictions...')\n",
540
+ "y_pred_prob = xgboost_model.predict_proba(X)[:, 1]\n",
541
+ "y_pred = (y_pred_prob >= 0.5).astype(int)\n",
542
+ "\n",
543
+ "print(\"\\nModel Performance:\")\n",
544
+ "accuracy = accuracy_score(y, y_pred)\n",
545
+ "print(f\"Accuracy: {accuracy:.4f}\")\n",
546
+ "\n",
547
+ "print(\"\\nClassification Report:\\n\", classification_report(y, y_pred))\n",
548
+ "\n"
549
+ ],
550
+ "metadata": {
551
+ "colab": {
552
+ "base_uri": "https://localhost:8080/"
553
+ },
554
+ "id": "O5L-7-7c-6B2",
555
+ "outputId": "39441b83-0302-40e8-aaff-bbb32ce0474a"
556
+ },
557
+ "execution_count": 4,
558
+ "outputs": [
559
+ {
560
+ "output_type": "stream",
561
+ "name": "stderr",
562
+ "text": [
563
+ "You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
564
+ ]
565
+ },
566
+ {
567
+ "output_type": "stream",
568
+ "name": "stdout",
569
+ "text": [
570
+ "Encoding titles...\n",
571
+ "Processed 20/20 titles\n",
572
+ "Loading XGBoost model...\n",
573
+ "Making predictions...\n",
574
+ "\n",
575
+ "Model Performance:\n",
576
+ "Accuracy: 0.7500\n",
577
+ "\n",
578
+ "Classification Report:\n",
579
+ " precision recall f1-score support\n",
580
+ "\n",
581
+ " 0 0.73 0.80 0.76 10\n",
582
+ " 1 0.78 0.70 0.74 10\n",
583
+ "\n",
584
+ " accuracy 0.75 20\n",
585
+ " macro avg 0.75 0.75 0.75 20\n",
586
+ "weighted avg 0.75 0.75 0.75 20\n",
587
+ "\n"
588
+ ]
589
+ }
590
+ ]
591
+ }
592
+ ]
593
+ }