{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "d564a252f0994e50a82e49477311aaaf": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_e1281f85ea8e48ac82e6cdf68530b700", "IPY_MODEL_e95c7910fde4463190dbace7fc4bbd43", "IPY_MODEL_fcdbe6f02f30423687d2262a1c260609" ], "layout": "IPY_MODEL_643fa99220814aaea83ef3cf9001971c" } }, "e1281f85ea8e48ac82e6cdf68530b700": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9618ee7122a44226b57b836136f94df5", "placeholder": "​", "style": "IPY_MODEL_5bd6ab3d866540caa45b9e82fcc7f081", "value": "Fetching 30 files: 100%" } }, "e95c7910fde4463190dbace7fc4bbd43": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", 
"description_tooltip": null, "layout": "IPY_MODEL_87b0deadc3d54484b21b7aab1ed67379", "max": 30, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_6c0747f6032f4788a4c476aa43246cd3", "value": 30 } }, "fcdbe6f02f30423687d2262a1c260609": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5001265687cd4a2da0be38685f18d003", "placeholder": "​", "style": "IPY_MODEL_77a09a218d2a421e9b3fe33f8310a37e", "value": " 30/30 [00:00<00:00, 862.80it/s]" } }, "643fa99220814aaea83ef3cf9001971c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"9618ee7122a44226b57b836136f94df5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5bd6ab3d866540caa45b9e82fcc7f081": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "87b0deadc3d54484b21b7aab1ed67379": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": 
null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6c0747f6032f4788a4c476aa43246cd3": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "5001265687cd4a2da0be38685f18d003": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": 
null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "77a09a218d2a421e9b3fe33f8310a37e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "Evaluation Pipeline for XGBoost with Embedding features\n" ], "metadata": { "id": "Q1KbUyBa9AOW" } }, { "cell_type": "markdown", "source": [ "Install the FlagEmbedding package and load the BGE-M3 embedding model, then change the data and model paths to match your environment." ], "metadata": { "id": "zSikCRSxAh8a" } }, { "cell_type": "code", "source": [ "!pip install -U FlagEmbedding\n", "from FlagEmbedding import BGEM3FlagModel\n", "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)" ], "metadata": { "id": "mKfYf2gyOAYS", "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "d564a252f0994e50a82e49477311aaaf", "e1281f85ea8e48ac82e6cdf68530b700", "e95c7910fde4463190dbace7fc4bbd43", "fcdbe6f02f30423687d2262a1c260609", "643fa99220814aaea83ef3cf9001971c", "9618ee7122a44226b57b836136f94df5", "5bd6ab3d866540caa45b9e82fcc7f081", "87b0deadc3d54484b21b7aab1ed67379", "6c0747f6032f4788a4c476aa43246cd3", "5001265687cd4a2da0be38685f18d003", "77a09a218d2a421e9b3fe33f8310a37e" ] }, "outputId": "a4863acc-785a-44cb-d126-0cc9867ead3e" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: FlagEmbedding in /usr/local/lib/python3.10/dist-packages 
(1.3.3)\n", "Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (2.5.1+cu121)\n", "Requirement already satisfied: transformers==4.44.2 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (4.44.2)\n", "Requirement already satisfied: datasets==2.19.0 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (2.19.0)\n", "Requirement already satisfied: accelerate>=0.20.1 in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (1.1.1)\n", "Requirement already satisfied: sentence-transformers in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (3.2.1)\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.13.2)\n", "Requirement already satisfied: ir-datasets in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.5.9)\n", "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (0.2.0)\n", "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from FlagEmbedding) (4.25.5)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (1.26.4)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (17.0.0)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (2.2.2)\n", "Requirement already satisfied: 
requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (4.66.6)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.5.0)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0->FlagEmbedding) (2024.3.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (3.11.9)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (0.26.3)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets==2.19.0->FlagEmbedding) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (2024.9.11)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (0.4.5)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers==4.44.2->FlagEmbedding) (0.19.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.1->FlagEmbedding) (5.9.5)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.6.0->FlagEmbedding) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (3.1.4)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->FlagEmbedding) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=1.6.0->FlagEmbedding) (1.3.0)\n", "Requirement already satisfied: beautifulsoup4>=4.4.1 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (4.12.3)\n", "Requirement already satisfied: inscriptis>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (2.5.0)\n", "Requirement already satisfied: lxml>=4.5.2 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (5.3.0)\n", "Requirement already satisfied: trec-car-tools>=2.5.4 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (2.6)\n", "Requirement already satisfied: lz4>=3.1.10 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (4.3.3)\n", "Requirement already satisfied: warc3-wet>=0.2.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n", "Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n", "Requirement already satisfied: zlib-state>=0.1.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.1.9)\n", "Requirement already satisfied: ijson>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (3.3.0)\n", "Requirement already satisfied: unlzw3>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from ir-datasets->FlagEmbedding) (0.2.2)\n", "Requirement already 
satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (1.5.2)\n", "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (1.13.1)\n", "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from sentence-transformers->FlagEmbedding) (11.0.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4>=4.4.1->ir-datasets->FlagEmbedding) (2.6)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (2.4.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.3.1)\n", "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (4.0.3)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.5.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (6.1.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (0.2.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.18.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.4.0)\n", "Requirement already satisfied: 
idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2024.8.30)\n", "Requirement already satisfied: cbor>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from trec-car-tools>=2.5.4->ir-datasets->FlagEmbedding) (1.0.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->FlagEmbedding) (3.0.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.2)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (1.4.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (3.5.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.0->FlagEmbedding) (1.16.0)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Fetching 30 files: 0%| | 0/30 [00:00= 0.5).astype(int)\n", "\n", "print(\"\\nModel Performance:\")\n", "accuracy = accuracy_score(y, y_pred)\n", "print(f\"Accuracy: {accuracy:.4f}\")\n", "\n", "print(\"\\nClassification 
Report:\\n\", classification_report(y, y_pred))\n", "\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O5L-7-7c-6B2", "outputId": "39441b83-0302-40e8-aaff-bbb32ce0474a" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Encoding titles...\n", "Processed 20/20 titles\n", "Loading XGBoost model...\n", "Making predictions...\n", "\n", "Model Performance:\n", "Accuracy: 0.7500\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 0.73 0.80 0.76 10\n", " 1 0.78 0.70 0.74 10\n", "\n", " accuracy 0.75 20\n", " macro avg 0.75 0.75 0.75 20\n", "weighted avg 0.75 0.75 0.75 20\n", "\n" ] } ] } ] }