diff --git "a/aclanthology_visualization.ipynb" "b/aclanthology_visualization.ipynb"
new file mode 100644--- /dev/null
+++ "b/aclanthology_visualization.ipynb"
@@ -0,0 +1,1570 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "1619b254fcbb4cb880d1be5685c74dbc": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_607c048fb1634a7689e355036c144984",
+ "IPY_MODEL_869501d4d38e46f184a66423d93a2745",
+ "IPY_MODEL_c03dc1381a0c430182fe86d8a100b249"
+ ],
+ "layout": "IPY_MODEL_b6d31f4cebc84ef0a563d41482b14cc2"
+ }
+ },
+ "607c048fb1634a7689e355036c144984": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9ebec18dbdff4913a4902429a726b9e0",
+ "placeholder": "",
+ "style": "IPY_MODEL_c9dc6fbcf53a4c9fb53716a18db6ffbe",
+ "value": "Map: 100%"
+ }
+ },
+ "869501d4d38e46f184a66423d93a2745": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c9ef5bf8ff3e44358c4557f74c3e379e",
+ "max": 1249,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_0cc5f439950e49eaa4d417396e21e2c4",
+ "value": 1249
+ }
+ },
+ "c03dc1381a0c430182fe86d8a100b249": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1479cc60b4ac4864b46b592dc1050157",
+ "placeholder": "",
+ "style": "IPY_MODEL_a9e75caedfbf46e0bd0effe1e60065cd",
+ "value": " 1249/1249 [00:02<00:00, 454.98 examples/s]"
+ }
+ },
+ "b6d31f4cebc84ef0a563d41482b14cc2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9ebec18dbdff4913a4902429a726b9e0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c9dc6fbcf53a4c9fb53716a18db6ffbe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c9ef5bf8ff3e44358c4557f74c3e379e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0cc5f439950e49eaa4d417396e21e2c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "1479cc60b4ac4864b46b592dc1050157": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a9e75caedfbf46e0bd0effe1e60065cd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4c7b67b7151e4c9fb47eaae2f39a21b8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_5bd552c8824e407c934978e35e7de980",
+ "IPY_MODEL_d052c01440db4dafb5d699eb57a9d613",
+ "IPY_MODEL_e1bbc114c9054a28a48762831a44ef11"
+ ],
+ "layout": "IPY_MODEL_591bf12de23c41c6aa510f6d6702b30e"
+ }
+ },
+ "5bd552c8824e407c934978e35e7de980": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5b379ade011143b9bf21c2aedaaf9149",
+ "placeholder": "",
+ "style": "IPY_MODEL_26062d5edbee4879a66829962199ca43",
+ "value": "encoding: 100%"
+ }
+ },
+ "d052c01440db4dafb5d699eb57a9d613": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c65d6c4a6d0a44d2a9fb8ca75cc5f790",
+ "max": 20,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_3bb8adf35cf74c3cbd3d2c58912041a3",
+ "value": 20
+ }
+ },
+ "e1bbc114c9054a28a48762831a44ef11": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_cc2125fcf9ab49eb9e2be054a4c3fc18",
+ "placeholder": "",
+ "style": "IPY_MODEL_2ff93c21f097436f9ccd61a8c9c8010d",
+ "value": " 20/20 [00:32<00:00, 1.47s/it]"
+ }
+ },
+ "591bf12de23c41c6aa510f6d6702b30e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5b379ade011143b9bf21c2aedaaf9149": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "26062d5edbee4879a66829962199ca43": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c65d6c4a6d0a44d2a9fb8ca75cc5f790": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3bb8adf35cf74c3cbd3d2c58912041a3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "cc2125fcf9ab49eb9e2be054a4c3fc18": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2ff93c21f097436f9ccd61a8c9c8010d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In this notebook, we provide the steps to reproduce a plot similar to https://huggingface.co/spaces/gwf-uwaterloo/aclscatter2d\n",
+ "\n",
+ "**Before running this colab, make sure the runtime type is set to GPU.** You can double check this in the \"Checks\" section.\n",
+ "\n",
+ "The plot will be generated using [plotly](https://plotly.com/python/getting-started/)."
+ ],
+ "metadata": {
+ "id": "AeaHYgzwgyOF"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title XML file name to download from acl-anthology github page\n",
+ "FILE_NAME = '2023.acl.xml' # @param {type:\"string\"}"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "mQ31dArhTOmd"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title Model name from huggingface\n",
+ "MODEL_NAME = 'allenai/specter2_base' # @param {type:\"string\"}\n",
+ "\n",
+ "ADAPTER_NAME = \"\" # @param {type:\"string\"}"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "jSt0Jpueanvn"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title Inference args\n",
+ "BATCH_SIZE = 64 # @param {type:\"integer\"}"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "HryCbmPBcw5V"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title Visualization args\n",
+ "NUM_CLUSTERS = 50 # @param {type:\"integer\"}"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "qyedQTz5ezl4"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup"
+ ],
+ "metadata": {
+ "id": "jXbz3X1sUHcr"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Install dependencies"
+ ],
+ "metadata": {
+ "id": "O9n1VhtvUQxS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install datasets\n",
+ "!pip install transformers\n",
+ "!pip install adapter-transformers==3.0.1"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "d0XchP9jUOhb",
+ "outputId": "133dcd54-f647-44bf-e5d4-f91383be6640"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.3.0)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
+ "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n",
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.17.2)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.2)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
+ "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.33.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.2)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (2023.6.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.5.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.2.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
+ "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: adapter-transformers==3.0.1 in /usr/local/lib/python3.10/dist-packages (3.0.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (3.12.2)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.17.2)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (23.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (6.0.1)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (2023.6.3)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (2.31.0)\n",
+ "Requirement already satisfied: sacremoses in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.0.53)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (0.13.3)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from adapter-transformers==3.0.1) (4.66.1)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->adapter-transformers==3.0.1) (2023.6.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.1.0->adapter-transformers==3.0.1) (4.5.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (3.2.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (2.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->adapter-transformers==3.0.1) (2023.7.22)\n",
+ "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (1.16.0)\n",
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (8.1.7)\n",
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from sacremoses->adapter-transformers==3.0.1) (1.3.2)\n",
+ "\u001b[33mWARNING: Ignoring invalid distribution -lotly (/usr/local/lib/python3.10/dist-packages)\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Imports"
+ ],
+ "metadata": {
+ "id": "c0MMhYc_UKfG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "import os\n",
+ "import re\n",
+ "from functools import partial\n",
+ "from tqdm.auto import tqdm\n",
+ "from typing import Any, Iterable, Mapping\n",
+ "\n",
+ "import datasets\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "from transformers import DataCollatorWithPadding, AutoModel, AutoTokenizer, AutoConfig\n",
+ "from sklearn.cluster import KMeans\n",
+ "from sklearn.manifold import TSNE\n",
+ "\n",
+ "import plotly.express as px"
+ ],
+ "metadata": {
+ "id": "AJULv3wPUG0z"
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Checks"
+ ],
+ "metadata": {
+ "id": "BY2W1tBTUVWN"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@markdown **Check GPU type**\n",
+ "!nvidia-smi -L\n",
+ "\n",
+ "#@markdown **Check PyTorch version**\n",
+ "print(\"PyTorch version:\", torch.__version__)\n",
+ "print(\"CUDA version:\", torch.version.cuda)\n",
+ "print(\"#GPUs:\", torch.cuda.device_count())"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "cellView": "form",
+ "id": "jtYjxTfuUXUb",
+ "outputId": "4f62a4ba-8b8b-462d-caa6-e002ec2d7b1b"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "GPU 0: Tesla T4 (UUID: GPU-5e2802f0-3a72-ee6b-56ce-fc17d7e725c4)\n",
+ "PyTorch version: 2.0.1+cu118\n",
+ "CUDA version: 11.8\n",
+ "#GPUs: 1\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Load Huggingface Stuff"
+ ],
+ "metadata": {
+ "id": "osH8mbM4aCw0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
+ "\n",
+ "config = AutoConfig.from_pretrained(MODEL_NAME, return_dict=True, output_hidden_states=True)\n",
+ "\n",
+ "model = AutoModel.from_pretrained(MODEL_NAME, config=config)\n",
+ "if ADAPTER_NAME:\n",
+ " model.load_adapter(\n",
+ " ADAPTER_NAME,\n",
+ " source=\"hf\",\n",
+ " set_active=True,\n",
+ " )\n",
+ "\n",
+ "model.eval()\n",
+ "model.to(\"cuda\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "6j9EGcCSZ8Z_",
+ "outputId": "1edfabc5-35b0-47d6-8c58-8cf1e35ca5fe"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "BertModel(\n",
+ " (shared_parameters): ModuleDict()\n",
+ " (invertible_adapters): ModuleDict()\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(31090, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (prefix_tuning): PrefixTuningShim(\n",
+ " (pool): PrefixTuningPool(\n",
+ " (prefix_tunings): ModuleDict()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (adapters): ModuleDict()\n",
+ " (adapter_fusion_layer): ModuleDict()\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (adapters): ModuleDict()\n",
+ " (adapter_fusion_layer): ModuleDict()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pooler): BertPooler(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (activation): Tanh()\n",
+ " )\n",
+ " (prefix_tuning): PrefixTuningPool(\n",
+ " (prefix_tunings): ModuleDict()\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Preparing Data"
+ ],
+ "metadata": {
+ "id": "v9olGFFaP6Un"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Downloading from acl-anthology github"
+ ],
+ "metadata": {
+ "id": "YvFxyYEpP_wj"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The paper information can be downloaded from `acl-anthology` github page in the XML format: https://github.com/acl-org/acl-anthology/tree/master/data/xml/"
+ ],
+ "metadata": {
+ "id": "Vm022cIzSorc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!rm -f $FILE_NAME\n",
+ "!wget \"https://raw.githubusercontent.com/acl-org/acl-anthology/master/data/xml/$FILE_NAME\"\n",
+ "\n",
+ "assert os.path.exists(FILE_NAME), \"Downloaded file exists\""
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "knMDRgK8Sfl_",
+ "outputId": "ea0abab7-fe9f-4ffa-e627-1b3d4f5a8953"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2023-09-20 03:28:48-- https://raw.githubusercontent.com/acl-org/acl-anthology/master/data/xml/2023.acl.xml\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 2597735 (2.5M) [text/plain]\n",
+ "Saving to: ‘2023.acl.xml’\n",
+ "\n",
+ "2023.acl.xml 100%[===================>] 2.48M --.-KB/s in 0.02s \n",
+ "\n",
+ "2023-09-20 03:28:49 (142 MB/s) - ‘2023.acl.xml’ saved [2597735/2597735]\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "download the xml file from this [link](https://github.com/acl-org/acl-anthology/tree/006c7247a6bf0ff859bfd3aab6ea6a19452580ad/data/xml). \n",
+ "Convert the xml files to jsonl files by running the following code"
+ ],
+ "metadata": {
+ "id": "2KFobPmUbu7j"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Parsing"
+ ],
+ "metadata": {
+ "id": "CUD4LOJlUmMj"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "WXQgTZQ103g7",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "4edc4fd1-0a7f-4419-ffa3-e1d9f259a139"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "#papers founds in 2023.acl.xml: 1249\n"
+ ]
+ }
+ ],
+ "source": [
+ "import xml.etree.ElementTree as ET\n",
+ "\n",
+ "URL_MAPPINGS = dict(\n",
+ " D=\"emnlp\",\n",
+ " N=\"naacl\",\n",
+ " P=\"acl\",\n",
+ " Q=\"tacl\",\n",
+ ")\n",
+ "\n",
+ "def xml_to_jsonl(xml_file: os.PathLike) -> Iterable[Mapping[str, Any]]:\n",
+ " tree = ET.parse(xml_file)\n",
+ " root = tree.getroot()\n",
+ " papers = root.findall(\".//paper\")\n",
+ "\n",
+ " for paper in papers:\n",
+ " paper_dict = {}\n",
+ " paper_dict[\"title\"] = \"\".join(paper.find(\"title\").itertext())\n",
+ "\n",
+ " authors = []\n",
+ " for author in paper.findall(\"author\"):\n",
+ " first_name = author.findtext(\"first\")\n",
+ " last_name = author.findtext(\"last\")\n",
+ " authors.append(f\"{first_name} {last_name}\")\n",
+ " paper_dict[\"authors\"] = authors\n",
+ "\n",
+ " paper_dict[\"abstract\"] = \"\" if paper.find(\"abstract\")==None else \"\".join(paper.find(\"abstract\").itertext())\n",
+ " paper_dict[\"pages\"] = paper.findtext(\"pages\")\n",
+ " paper_dict[\"url\"] = paper.findtext(\"url\")\n",
+ " paper_dict[\"bibkey\"] = paper.findtext(\"bibkey\")\n",
+ " paper_dict[\"doi\"] = paper.findtext(\"doi\")\n",
+ "\n",
+ " conference, paper_type = None, None\n",
+ " matched = re.match(r\"(\\d+)\\.(\\w+)-(\\w+)\\.\\d+\", paper_dict[\"url\"])\n",
+ " if matched:\n",
+ " year = int(matched.group(1))\n",
+ " conference = matched.group(2)\n",
+ " paper_type = matched.group(3)\n",
+ " else:\n",
+ " bibs = paper_dict[\"bibkey\"].split(\"-\")\n",
+ " for b in range(len(bibs) - 1, -1, -1):\n",
+ " try:\n",
+ " year = int(bibs[b])\n",
+ " break\n",
+ " except ValueError:\n",
+ " pass\n",
+ "\n",
+ " conference = URL_MAPPINGS.get(paper_dict[\"url\"][0], None)\n",
+ "\n",
+ " paper_dict[\"source\"] = conference\n",
+ " paper_dict[\"year\"] = year\n",
+ " paper_dict[\"publication_type\"] = paper_type\n",
+ "\n",
+ " yield paper_dict\n",
+ "\n",
+ "papers = list(xml_to_jsonl(FILE_NAME))\n",
+ "\n",
+ "print(f\"#papers founds in {FILE_NAME}: {len(papers)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Encode"
+ ],
+ "metadata": {
+ "id": "3yXoFyHhdd25"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Creating DataLoader"
+ ],
+ "metadata": {
+ "id": "ml0g17tYX2jP"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset = datasets.Dataset.from_list(\n",
+ " [{\"text\": p[\"title\"] + tokenizer.sep_token + (p[\"abstract\"] or \"\"), \"idx\": i + 1} for i, p in enumerate(papers)]\n",
+ ")\n",
+ "\n",
+ "tokenize_fn = lambda batch: tokenizer(batch[\"text\"], padding=True, truncation=True, max_length=512)\n",
+ "dataset = dataset.map(tokenize_fn, batched=True)\n",
+ "\n",
+ "columns = [\"idx\", \"input_ids\", \"attention_mask\"]\n",
+ "if \"token_type_ids\" in dataset.column_names:\n",
+ " columns.append(\"token_type_ids\")\n",
+ "\n",
+ "data_loader = DataLoader(\n",
+ " dataset.with_format(\"torch\", columns=columns),\n",
+ " collate_fn=DataCollatorWithPadding(tokenizer),\n",
+ " batch_size=BATCH_SIZE,\n",
+ ")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 153,
+ "referenced_widgets": [
+ "1619b254fcbb4cb880d1be5685c74dbc",
+ "607c048fb1634a7689e355036c144984",
+ "869501d4d38e46f184a66423d93a2745",
+ "c03dc1381a0c430182fe86d8a100b249",
+ "b6d31f4cebc84ef0a563d41482b14cc2",
+ "9ebec18dbdff4913a4902429a726b9e0",
+ "c9dc6fbcf53a4c9fb53716a18db6ffbe",
+ "c9ef5bf8ff3e44358c4557f74c3e379e",
+ "0cc5f439950e49eaa4d417396e21e2c4",
+ "1479cc60b4ac4864b46b592dc1050157",
+ "a9e75caedfbf46e0bd0effe1e60065cd"
+ ]
+ },
+ "id": "sCG1iVa4X7ye",
+ "outputId": "a287df82-3448-4b24-9e26-582bd7b4b180"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Map: 0%| | 0/1249 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "1619b254fcbb4cb880d1be5685c74dbc"
+ }
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Running Inference"
+ ],
+ "metadata": {
+ "id": "1KtBlNdMdnQQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "embeds = []\n",
+ "for batch in tqdm(data_loader, desc=\"encoding\"):\n",
+ " indices = batch.pop(\"idx\", None)\n",
+ " if isinstance(indices, torch.Tensor):\n",
+ " indices = indices.cpu().tolist()\n",
+ "\n",
+ " batch = {k: v.to(\"cuda\") if v is not None else v for k, v in batch.items()}\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " output = model(**batch)\n",
+ " encoded = output.last_hidden_state[:, 0].cpu().numpy()\n",
+ "\n",
+ " embeds.append(encoded)\n",
+ "\n",
+ "embeds = np.concatenate(embeds, axis=0)\n",
+ "\n",
+ "print(f\"Embeddings size:\", embeds.shape)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 143,
+ "referenced_widgets": [
+ "4c7b67b7151e4c9fb47eaae2f39a21b8",
+ "5bd552c8824e407c934978e35e7de980",
+ "d052c01440db4dafb5d699eb57a9d613",
+ "e1bbc114c9054a28a48762831a44ef11",
+ "591bf12de23c41c6aa510f6d6702b30e",
+ "5b379ade011143b9bf21c2aedaaf9149",
+ "26062d5edbee4879a66829962199ca43",
+ "c65d6c4a6d0a44d2a9fb8ca75cc5f790",
+ "3bb8adf35cf74c3cbd3d2c58912041a3",
+ "cc2125fcf9ab49eb9e2be054a4c3fc18",
+ "2ff93c21f097436f9ccd61a8c9c8010d"
+ ]
+ },
+ "id": "QbqEJIgWdr2o",
+ "outputId": "644df647-2880-417e-be90-e12492a5c3b7"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "encoding: 0%| | 0/20 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "4c7b67b7151e4c9fb47eaae2f39a21b8"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Embeddings size: (1249, 768)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Housekeeping prior to Visualization\n",
+ "\n",
+ "To plot the embeddings, we first cluster the points and then reduce the number of dimensions to 2-d using t-SNE."
+ ],
+ "metadata": {
+ "id": "agDmw5DPefij"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Clustering"
+ ],
+ "metadata": {
+ "id": "bbLVfaIufuSu"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "clusterer = KMeans(n_clusters=NUM_CLUSTERS, n_init=\"auto\")\n",
+ "clusters = clusterer.fit(embeds).labels_\n",
+ "\n",
+ "print(\"Clustering done\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "iUEvI_OaeoIf",
+ "outputId": "6bc3c072-7fdc-4349-cd57-7e9205f77c01"
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Clustering done\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Applying t-SNE\n",
+ "\n",
+ "We changed perplexity and number of iterations from their default value because the scatter plot would look nicer."
+ ],
+ "metadata": {
+ "id": "sqfCHjTAfwXF"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "reducer = TSNE(n_jobs=12, perplexity=10, n_iter=3000)\n",
+ "reduced_embeds = reducer.fit_transform(embeds)"
+ ],
+ "metadata": {
+ "id": "AT2fWBc4fyFl"
+ },
+ "execution_count": 14,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Visualize"
+ ],
+ "metadata": {
+ "id": "t28XwXvNgrBo"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title\n",
+ "def to_string_authors(list_of_authors):\n",
+ " if len(list_of_authors) > 5:\n",
+ " return \", \".join(list_of_authors[:5]) + \", et al.\"\n",
+ " elif len(list_of_authors) > 2:\n",
+ " return \", \".join(list_of_authors[:-1]) + \", and \" + list_of_authors[-1]\n",
+ " else:\n",
+ " return \" and \".join(list_of_authors)\n",
+ "\n",
+ "\n",
+ "for i, (point, c, p) in enumerate(zip(reduced_embeds, clusters, papers)):\n",
+ " p[\"x\"] = point[0]\n",
+ " p[\"y\"] = point[1]\n",
+ " p[\"cluster\"] = c\n",
+ " p[\"authors_trimmed\"] = [(x[x.index(\",\") + 1 :].strip() + \" \" + x.split(\",\")[0].strip()) if \",\" in x else x for x in p[\"authors\"]]\n",
+ " if \"publication_type\" in p:\n",
+ " p[\"type\"] = p.pop(\"publication_type\")\n",
+ "\n",
+ "df = pd.DataFrame(papers)\n",
+ "\n",
+ "fig = px.scatter(\n",
+ " df,\n",
+ " x=\"x\",\n",
+ " y=\"y\",\n",
+ " color=\"cluster\",\n",
+ " width=1000,\n",
+ " height=800,\n",
+ " custom_data=(\"title\", \"authors_trimmed\", \"year\", \"source\", \"type\"),\n",
+ " color_continuous_scale=\"fall\",\n",
+ ")\n",
+ "fig.update_traces(\n",
+ " hovertemplate=\"%{customdata[0]}
%{customdata[1]}
%{customdata[2]}
%{customdata[3]}\"\n",
+ ")\n",
+ "fig.update_layout(\n",
+ " showlegend=False,\n",
+ " font=dict(\n",
+ " family=\"Times New Roman\",\n",
+ " size=30,\n",
+ " ),\n",
+ " hoverlabel=dict(\n",
+ " align=\"left\",\n",
+ " font_size=14,\n",
+ " font_family=\"Rockwell\",\n",
+ " namelength=-1,\n",
+ " ),\n",
+ ")\n",
+ "fig.update_xaxes(title=\"\")\n",
+ "fig.update_yaxes(title=\"\")\n",
+ "\n",
+ "a = fig.show()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 817
+ },
+ "cellView": "form",
+ "id": "B-TwYJM5gtF-",
+ "outputId": "99a5d7d7-2e49-43af-be93-7677c50effba"
+ },
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/html": [
+ "\n",
+ "